
Rllib experiment


Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.

This is a US Government Work not subject to copyright protection in the US.

The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.


RllibExperiment (BaseExperiment)

The Rllib Experiment is an experiment for running configurable multi-agent environments with patchable settings.

Source code in corl/experiments/rllib_experiment.py
class RllibExperiment(BaseExperiment):
    """
    The Rllib Experiment is an experiment for running
    multiagent configurable environments with patchable settings
    """

    def __init__(self, **kwargs) -> None:
        self.config: RllibExperimentValidator
        self._logger = logging.getLogger(RllibExperiment.__name__)
        super().__init__(**kwargs)

    @property
    def get_validator(self) -> typing.Type[RllibExperimentValidator]:
        return RllibExperimentValidator

    @property
    def get_policy_validator(self) -> typing.Type[RllibPolicyValidator]:
        """Return validator"""
        return RllibPolicyValidator

    def run_experiment(self, args: argparse.Namespace) -> None:

        rllib_config = self._select_rllib_config(args.compute_platform)

        if args.compute_platform in ['ray']:
            self._update_ray_config_for_ray_platform()

        self._add_trial_creator()

        # This needs to be before the ray cluster is initialized
        if args.debug:
            self.config.ray_config['local_mode'] = True

        ray.init(**self.config.ray_config)

        ray_resources = ray.available_resources()

        auto_configure_rllib_config(rllib_config, self.config.auto_rllib_config_setup, ray_resources)

        self.config.env_config["agents"], self.config.env_config["agent_platforms"] = self.create_agents(
            args.platform_config, args.agent_config
        )

        self.config.env_config["horizon"] = rllib_config["horizon"]

        if args.output:
            self.config.env_config["output_path"] = args.output
            self.config.tune_config["local_dir"] = args.output

        if args.name:
            self.config.env_config["TrialName"] = args.name

        if args.other_platform:
            self.config.env_config["other_platforms"] = self.create_other_platforms(args.other_platform)

        if not self.config.ray_config['local_mode']:
            self.config.env_config['episode_parameter_provider'] = RemoteEpisodeParameterProvider.wrap_epp_factory(
                Factory(**self.config.env_config['episode_parameter_provider']),
                actor_name=ACT3MultiAgentEnv.episode_parameter_provider_name
            )

            for agent_name, agent_configs in self.config.env_config['agents'].items():
                agent_configs.class_config.config['episode_parameter_provider'] = RemoteEpisodeParameterProvider.wrap_epp_factory(
                    Factory(**agent_configs.class_config.config['episode_parameter_provider']), agent_name
                )

        self.config.env_config['epp_registry'] = ACT3MultiAgentEnvValidator(**self.config.env_config).epp_registry

        tmp = ACT3MultiAgentEnv(self.config.env_config)
        tmp_as = tmp.action_space
        tmp_os = tmp.observation_space
        tmp_ac = self.config.env_config['agents']

        policies = {
            policy_name: (
                tmp_ac[policy_name].policy_config["policy_class"],
                policy_obs,
                tmp_as[policy_name],
                tmp_ac[policy_name].policy_config["config"]
            )
            for policy_name,
            policy_obs in tmp_os.spaces.items()
            if tmp_ac[policy_name]
        }

        train_policies = [policy_name for policy_name in policies.keys() if tmp_ac[policy_name].policy_config["train"]]

        self._update_rllib_config(rllib_config, train_policies, policies, args)

        self._enable_episode_parameter_provider_checkpointing()

        if args.profile:
            if "stop" not in self.config.tune_config:
                self.config.tune_config["stop"] = {}
            self.config.tune_config["stop"]["training_iteration"] = args.profile_iterations

        search_class = None
        if self.config.hparam_search_class is not None:
            if self.config.hparam_search_config is not None:
                search_class = self.config.hparam_search_class(**self.config.hparam_search_config)
            else:
                search_class = self.config.hparam_search_class()
            search_class.add_algorithm_hparams(rllib_config, self.config.tune_config)

        tune.run(
            config=rllib_config,
            **self.config.tune_config,
        )

    def _update_rllib_config(self, rllib_config, train_policies, policies, args: argparse.Namespace) -> None:
        """
        Update several rllib config fields
        """

        rllib_config["multiagent"] = {
            "policies": policies, "policy_mapping_fn": lambda agent_id: agent_id, "policies_to_train": train_policies
        }

        rllib_config["env"] = ACT3MultiAgentEnv
        callback_list = [self.get_callbacks()]
        if self.config.extra_callbacks:
            callback_list.extend(self.config.extra_callbacks)  # type: ignore[arg-type]
        rllib_config["callbacks"] = MultiCallbacks(callback_list)
        rllib_config["env_config"] = self.config.env_config
        now = datetime.now()
        rllib_config["env_config"]["output_date_string"] = f"{now.strftime('%Y%m%d_%H%M%S')}_{socket.gethostname()}"
        rllib_config["create_env_on_driver"] = True
        rllib_config["batch_mode"] = "complete_episodes"

        self._add_git_hashes_to_config(rllib_config)

        if args.debug:
            rllib_config['num_workers'] = 0

    def _add_git_hashes_to_config(self, rllib_config) -> None:
        """adds git hashes (or package version information if git information
        is unavailable) of key modules to rllib_config["env_config"]["git_hash"].
        Key modules are the following:
          - corl,
          - whatever cwd is set to at the time of the function call
            (notionally /opt/project /)
          - any other modules listed in rllib_config["env_config"]["plugin_paths"]

        This information is not actually used by ACT3MultiAgentEnv;
        however, putting it in the env_config means that this
        information is saved to the params.pkl and thus is available
        for later inspection while seeking to understand the
        performance of a trained model.
        """
        try:
            # pattern used below to find root repository paths
            repo_pattern = r"(?P<repopath>.*)\/__init__.py"
            rp = re.compile(repo_pattern)

            corl_pattern = r"corl.*"
            cp0 = re.compile(corl_pattern)
            rllib_config["env_config"]["git_hash"] = dict()

            # store hash on cwd
            cwd = os.getcwd()
            try:
                githash = git.Repo(cwd, search_parent_directories=True).head.object.hexsha
                rllib_config["env_config"]["git_hash"]["cwd"] = githash
                self._logger.info(f"cwd hash: {githash}")
            except git.InvalidGitRepositoryError:
                self._logger.warning("cwd is not a git repo\n")

            # go ahead and strip out corl related things from plugin_path
            plugpath = []
            for item in rllib_config['env_config']['plugin_paths']:
                match0 = cp0.match(item)
                if match0 is None:
                    plugpath.append(item)

            plugpath.append('corl')

            # add git hashes to env_config dictionary
            for module0 in plugpath:
                env_hash_key = module0
                module1 = importlib.import_module(module0)
                modulefile = module1.__file__
                if modulefile is not None:
                    match0 = rp.match(modulefile)
                    if match0 is not None:
                        repo_path = match0.group('repopath')
                        try:
                            githash = git.Repo(repo_path, search_parent_directories=True).head.object.hexsha
                            rllib_config["env_config"]["git_hash"][env_hash_key] = githash
                            self._logger.info(f"{module0} hash: {githash}")
                        except git.InvalidGitRepositoryError:
                            # possibly installed in image but not a git repo
                            # look for version number
                            if hasattr(module1, 'version') and hasattr(module1.version, '__version__'):
                                githash = module1.version.__version__
                                rllib_config["env_config"]["git_hash"][env_hash_key] = githash
                                self._logger.info(f"{module0} hash: {githash}")
                            else:
                                self._logger.warning((f"module: {module0}, repopath: {repo_path}"
                                                      "is invalid git repo\n"))
                                sys.stderr.write((f"module: {module0}, repopath: {repo_path}"
                                                  "is invalid git repo\n"))
        except ValueError:
            warnings.warn("Unable to add the gitlab hash to experiment!!!")

    def get_callbacks(self) -> typing.Type[EnvironmentDefaultCallbacks]:
        """Get the environment callbacks"""
        return EnvironmentDefaultCallbacks

    def _select_rllib_config(self, platform: typing.Optional[str]) -> typing.Dict[str, typing.Any]:
        """Extract the rllib config for the proper computational platform

        Parameters
        ----------
        platform : typing.Optional[str]
            Specification of the computational platform to use, such as "local", "hpc", etc.  This must be present in the rllib_configs.
            If None, the rllib_configs must only have a single entry.

        Returns
        -------
        typing.Dict[str, typing.Any]
            Rllib configuration for the desired computational platform.

        Raises
        ------
        RuntimeError
            The requested computational platform does not exist or None was used when multiple platforms were defined.
        """
        if platform is not None:
            return self.config.rllib_configs[platform]

        if len(self.config.rllib_configs) == 1:
            return self.config.rllib_configs[next(iter(self.config.rllib_configs))]

        raise RuntimeError(f'Invalid rllib_config for platform "{platform}"')

    def _update_ray_config_for_ray_platform(self) -> None:
        """Update the ray configuration for ray platforms
        """
        self.config.ray_config['address'] = 'auto'
        self.config.ray_config['log_to_driver'] = False

    def _enable_episode_parameter_provider_checkpointing(self) -> None:

        base_trainer = self.config.tune_config["run_or_experiment"]

        trainer_class = get_trainable_cls(base_trainer)

        class EpisodeParameterProviderSavingTrainer(trainer_class):  # type: ignore[valid-type, misc]
            """
            Tune Trainer that adds capability to restore
            progress of the EpisodeParameterProvider on restoring training
            progress
            """

            def save_checkpoint(self, checkpoint_path):
                """
                adds additional checkpoint saving functionality
                by also saving any episode parameter providers
                currently running
                """
                tmp = super().save_checkpoint(checkpoint_path)

                checkpoint_folder = pathlib.Path(checkpoint_path)

                # Environment
                epp_name = ACT3MultiAgentEnv.episode_parameter_provider_name
                env = self.workers.local_worker().env
                epp: EpisodeParameterProvider = env.config.epp
                epp.save_checkpoint(checkpoint_folder / epp_name)

                # Agents
                for agent_name, agent_configs in env.agent_dict.items():
                    epp = agent_configs.config.epp
                    epp.save_checkpoint(checkpoint_folder / agent_name)

                return tmp

            def load_checkpoint(self, checkpoint_path):
                """
                adds additional checkpoint loading functionality
                by also loading any episode parameter providers
                with the checkpoint
                """
                super().load_checkpoint(checkpoint_path)

                checkpoint_folder = pathlib.Path(checkpoint_path).parent

                # Environment
                epp_name = ACT3MultiAgentEnv.episode_parameter_provider_name
                env = self.workers.local_worker().env
                epp: EpisodeParameterProvider = env.config.epp
                epp.load_checkpoint(checkpoint_folder / epp_name)

                # Agents
                for agent_name, agent_configs in env.agent_dict.items():
                    epp = agent_configs.config.epp
                    epp.load_checkpoint(checkpoint_folder / agent_name)

        self.config.tune_config["run_or_experiment"] = EpisodeParameterProviderSavingTrainer

    def _add_trial_creator(self):
        """Updates the trial name based on the HPC Job Number and the trial name in the configuration
        """
        if "trial_name_creator" not in self.config.tune_config:
            if self.config.trial_creator_function is None:

                def trial_name_prefix(trial):
                    """
                    Args:
                        trial (Trial): A generated trial object.

                    Returns:
                        trial_name (str): String representation of Trial prefixed
                            by the contents of the environment variable:
                            TRIAL_NAME_PREFIX
                            Or the prefix 'RUN' if none is set.
                    """
                    trial_prefix = os.environ.get('PBS_JOBID', os.environ.get('TRIAL_NAME_PREFIX', ""))
                    trial_name = ""

                    if "TrialName" in self.config.env_config.keys():
                        if trial_prefix:
                            trial_name = "-" + self.config.env_config["TrialName"]
                        else:
                            trial_name = self.config.env_config["TrialName"]
                    return f"{trial_prefix}{trial_name}-{trial}"

                self.config.tune_config["trial_name_creator"] = trial_name_prefix
            else:
                self.config.tune_config["trial_name_creator"] = self.config.trial_creator_function

get_policy_validator: Type[corl.experiments.rllib_experiment.RllibPolicyValidator] property readonly

Return validator

get_validator: Type[corl.experiments.rllib_experiment.RllibExperimentValidator] property readonly

Get the validator for this experiment class. The kwargs sent to the experiment class are validated using this object, and the result is stored as a self.config attribute on the experiment class.
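
A minimal sketch of how this hook is typically extended, assuming BaseExperiment builds self.config by passing its kwargs through the class returned here (MyExperimentValidator and its custom_flag field are hypothetical):

import typing

from corl.experiments.rllib_experiment import RllibExperiment, RllibExperimentValidator


class MyExperimentValidator(RllibExperimentValidator):
    custom_flag: bool = False  # hypothetical extra field


class MyExperiment(RllibExperiment):
    @property
    def get_validator(self) -> typing.Type[MyExperimentValidator]:
        # kwargs given to MyExperiment(...) are validated against this model
        # and the result is stored on self.config
        return MyExperimentValidator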

get_callbacks(self)

Get the environment callbacks

Source code in corl/experiments/rllib_experiment.py
def get_callbacks(self) -> typing.Type[EnvironmentDefaultCallbacks]:
    """Get the environment callbacks"""
    return EnvironmentDefaultCallbacks
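
A subclass can point get_callbacks at its own callbacks class to extend the defaults; a minimal sketch, assuming EnvironmentDefaultCallbacks follows the usual RLlib DefaultCallbacks hook signatures (the import path for EnvironmentDefaultCallbacks and the MyCallbacks name are illustrative):

import typing

from corl.experiments.rllib_experiment import RllibExperiment
# assumed import path; use wherever EnvironmentDefaultCallbacks lives in your tree
from corl.environment.default_env_rllib_callbacks import EnvironmentDefaultCallbacks


class MyCallbacks(EnvironmentDefaultCallbacks):
    """Hypothetical callbacks that record an extra custom metric."""

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        super().on_episode_end(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)
        episode.custom_metrics["my_episode_len"] = episode.length


class MyExperiment(RllibExperiment):
    def get_callbacks(self) -> typing.Type[EnvironmentDefaultCallbacks]:
        return MyCallbacks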

run_experiment(self, args)

Runs the experiment associated with this experiment class

Source code in corl/experiments/rllib_experiment.py
def run_experiment(self, args: argparse.Namespace) -> None:

    rllib_config = self._select_rllib_config(args.compute_platform)

    if args.compute_platform in ['ray']:
        self._update_ray_config_for_ray_platform()

    self._add_trial_creator()

    # This needs to be before the ray cluster is initialized
    if args.debug:
        self.config.ray_config['local_mode'] = True

    ray.init(**self.config.ray_config)

    ray_resources = ray.available_resources()

    auto_configure_rllib_config(rllib_config, self.config.auto_rllib_config_setup, ray_resources)

    self.config.env_config["agents"], self.config.env_config["agent_platforms"] = self.create_agents(
        args.platform_config, args.agent_config
    )

    self.config.env_config["horizon"] = rllib_config["horizon"]

    if args.output:
        self.config.env_config["output_path"] = args.output
        self.config.tune_config["local_dir"] = args.output

    if args.name:
        self.config.env_config["TrialName"] = args.name

    if args.other_platform:
        self.config.env_config["other_platforms"] = self.create_other_platforms(args.other_platform)

    if not self.config.ray_config['local_mode']:
        self.config.env_config['episode_parameter_provider'] = RemoteEpisodeParameterProvider.wrap_epp_factory(
            Factory(**self.config.env_config['episode_parameter_provider']),
            actor_name=ACT3MultiAgentEnv.episode_parameter_provider_name
        )

        for agent_name, agent_configs in self.config.env_config['agents'].items():
            agent_configs.class_config.config['episode_parameter_provider'] = RemoteEpisodeParameterProvider.wrap_epp_factory(
                Factory(**agent_configs.class_config.config['episode_parameter_provider']), agent_name
            )

    self.config.env_config['epp_registry'] = ACT3MultiAgentEnvValidator(**self.config.env_config).epp_registry

    tmp = ACT3MultiAgentEnv(self.config.env_config)
    tmp_as = tmp.action_space
    tmp_os = tmp.observation_space
    tmp_ac = self.config.env_config['agents']

    policies = {
        policy_name: (
            tmp_ac[policy_name].policy_config["policy_class"],
            policy_obs,
            tmp_as[policy_name],
            tmp_ac[policy_name].policy_config["config"]
        )
        for policy_name,
        policy_obs in tmp_os.spaces.items()
        if tmp_ac[policy_name]
    }

    train_policies = [policy_name for policy_name in policies.keys() if tmp_ac[policy_name].policy_config["train"]]

    self._update_rllib_config(rllib_config, train_policies, policies, args)

    self._enable_episode_parameter_provider_checkpointing()

    if args.profile:
        if "stop" not in self.config.tune_config:
            self.config.tune_config["stop"] = {}
        self.config.tune_config["stop"]["training_iteration"] = args.profile_iterations

    search_class = None
    if self.config.hparam_search_class is not None:
        if self.config.hparam_search_config is not None:
            search_class = self.config.hparam_search_class(**self.config.hparam_search_config)
        else:
            search_class = self.config.hparam_search_class()
        search_class.add_algorithm_hparams(rllib_config, self.config.tune_config)

    tune.run(
        config=rllib_config,
        **self.config.tune_config,
    )
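
run_experiment reads a fixed set of attributes from the argparse.Namespace it receives (compute_platform, debug, output, name, other_platform, platform_config, agent_config, profile, profile_iterations; all visible in the source above), so it can be driven without the full CLI. A minimal sketch, assuming the experiment instance has already been constructed from a validated config and that the two config paths are placeholders:

import argparse

args = argparse.Namespace(
    compute_platform="local",                # key into self.config.rllib_configs
    platform_config="config/platforms.yml",  # hypothetical path consumed by create_agents
    agent_config="config/agents.yml",        # hypothetical path consumed by create_agents
    other_platform=None,
    output="/tmp/corl_results",              # sets env_config["output_path"] and the tune local_dir
    name="example_trial",                    # sets env_config["TrialName"]
    debug=False,                             # True forces ray local_mode and num_workers=0
    profile=False,
    profile_iterations=10,
)

experiment.run_experiment(args)              # `experiment` built elsewhere as an RllibExperiment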

RllibExperimentValidator (BaseExperimentValidator) pydantic-model

ray_config: dictionary to be fed into ray init, validated by the ray init call
env_config: environment configuration, validated by the environment class
rllib_configs: a mapping of compute platforms to rllib configs; see apply_patches_rllib_configs for information on the typing
tune_config: kwarg arguments to be sent to tune for this experiment
extra_callbacks: extra rllib callbacks that will be added to the callback list
trial_creator_function: this function will overwrite the default trial string creator and allow more finely tuned trial name creators
Source code in corl/experiments/rllib_experiment.py
class RllibExperimentValidator(BaseExperimentValidator):
    """
    ray_config: dictionary to be fed into ray init, validated by ray init call
    env_config: environment configuration, validated by environment class
    rllib_configs: a mapping of compute platforms to rllib configs, see apply_patches_rllib_configs
                    for information on the typing
    tune_config: kwarg arguments to be sent to tune for this experiment
    extra_callbacks: extra rllib callbacks that will be added to the callback list
    trial_creator_function: this function will overwrite the default trial string creator
                            and allow more fine tune trial name creators
    """
    ray_config: typing.Dict[str, typing.Any]
    env_config: EnvContext
    rllib_configs: typing.Dict[str, typing.Dict[str, typing.Any]]
    tune_config: typing.Dict[str, typing.Any]
    trainable_config: typing.Optional[typing.Dict[str, typing.Any]]
    auto_rllib_config_setup: AutoRllibConfigSetup = AutoRllibConfigSetup()
    hparam_search_class: typing.Optional[PyObject]
    hparam_search_config: typing.Optional[typing.Dict[str, typing.Any]]
    extra_callbacks: typing.Optional[typing.List[PyObject]]
    trial_creator_function: typing.Optional[PyObject]

    @validator('rllib_configs', pre=True)
    def apply_patches_rllib_configs(cls, v):  # pylint: disable=no-self-argument, no-self-use
        """
        The dictionary of rllib configs may come in as a dictionary of
        lists of dictionaries, this function is responsible for collapsing
        the list down to a typing.Dict[str, typing.Dict[str, typing.Any]]
        instead of
        typing.Dict[str, typing.Union[typing.List[typing.Dict[str, typing.Any]], typing.Dict[str, typing.Any]]]

        Raises:
            RuntimeError: [description]

        Returns:
            [type] -- [description]
        """
        if not isinstance(v, dict):
            raise RuntimeError("rllib_configs are expected to be a dict of keys to different compute configs")
        rllib_configs = {}
        for key, value in v.items():
            if isinstance(value, list):
                rllib_configs[key] = apply_patches(value)
            elif isinstance(value, dict):
                rllib_configs[key] = value
        return rllib_configs

    @validator('ray_config', 'tune_config', 'trainable_config', 'env_config', pre=True)
    def apply_patches_configs(cls, v):  # pylint: disable=no-self-argument, no-self-use
        """
        reduces a field from
        typing.Union[typing.List[typing.Dict[str, typing.Any]], typing.Dict[str, typing.Any]]]
        to
        typing.Dict[str, typing.Any]

        by patching the first dictionary in the list with each patch afterwards

        Returns:
            [type] -- [description]
        """
        if isinstance(v, list):
            v = apply_patches(v)
        return v

    @validator('env_config')
    def no_horizon(cls, v):
        """Ensure that the horizon is not specified in the env_config."""
        if 'horizon' in v:
            raise ValueError('Cannot specify the horizon in the env_config')
        return v

apply_patches_configs(v) classmethod

Reduces a field from typing.Union[typing.List[typing.Dict[str, typing.Any]], typing.Dict[str, typing.Any]] to typing.Dict[str, typing.Any] by patching the first dictionary in the list with each subsequent patch.

Returns:

The patched configuration dictionary (or the value unchanged if it was already a dictionary).

Source code in corl/experiments/rllib_experiment.py
@validator('ray_config', 'tune_config', 'trainable_config', 'env_config', pre=True)
def apply_patches_configs(cls, v):  # pylint: disable=no-self-argument, no-self-use
    """
    reduces a field from
    typing.Union[typing.List[typing.Dict[str, typing.Any]], typing.Dict[str, typing.Any]]]
    to
    typing.Dict[str, typing.Any]

    by patching the first dictionary in the list with each patch afterwards

    Returns:
        [type] -- [description]
    """
    if isinstance(v, list):
        v = apply_patches(v)
    return v
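
As an illustration of the collapsing behaviour (assuming apply_patches merges each later dictionary over the first, with later values winning), a field such as ray_config supplied as a list of patches becomes one dictionary before the model is built:

ray_config_patches = [
    {"include_dashboard": False, "log_to_driver": True},  # base config
    {"log_to_driver": False},                             # patch applied on top
]

# after apply_patches_configs (under the merge assumption above):
# {"include_dashboard": False, "log_to_driver": False}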

apply_patches_rllib_configs(v) classmethod

The dictionary of rllib configs may come in as a dictionary of lists of dictionaries, this function is responsible for collapsing the list down to a typing.Dict[str, typing.Dict[str, typing.Any]] instead of typing.Dict[str, typing.Union[typing.List[typing.Dict[str, typing.Any]], typing.Dict[str, typing.Any]]]

Exceptions:

RuntimeError: raised if rllib_configs is not a dictionary mapping compute platforms to configs.

Returns:

The collapsed mapping of compute platform to a single rllib config dictionary.

Source code in corl/experiments/rllib_experiment.py
@validator('rllib_configs', pre=True)
def apply_patches_rllib_configs(cls, v):  # pylint: disable=no-self-argument, no-self-use
    """
    The dictionary of rllib configs may come in as a dictionary of
    lists of dictionaries, this function is responsible for collapsing
    the list down to a typing.Dict[str, typing.Dict[str, typing.Any]]
    instead of
    typing.Dict[str, typing.Union[typing.List[typing.Dict[str, typing.Any]], typing.Dict[str, typing.Any]]]

    Raises:
        RuntimeError: [description]

    Returns:
        [type] -- [description]
    """
    if not isinstance(v, dict):
        raise RuntimeError("rllib_configs are expected to be a dict of keys to different compute configs")
    rllib_configs = {}
    for key, value in v.items():
        if isinstance(value, list):
            rllib_configs[key] = apply_patches(value)
        elif isinstance(value, dict):
            rllib_configs[key] = value
    return rllib_configs
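
The same idea applied per compute platform: each key of rllib_configs may hold either a ready dictionary or a list of patches, and after validation _select_rllib_config can hand back a plain dictionary for the requested platform (values are illustrative, merge semantics assumed as above):

rllib_configs_in = {
    "local": [
        {"num_workers": 2, "horizon": 1000},
        {"num_workers": 0},                   # patch on top of the base entry
    ],
    "hpc": {"num_workers": 32, "horizon": 1000},
}

# after validation (assumed result):
# {"local": {"num_workers": 0, "horizon": 1000},
#  "hpc": {"num_workers": 32, "horizon": 1000}}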

no_horizon(v) classmethod

Ensure that the horizon is not specified in the env_config.

Source code in corl/experiments/rllib_experiment.py
@validator('env_config')
def no_horizon(cls, v):
    """Ensure that the horizon is not specified in the env_config."""
    if 'horizon' in v:
        raise ValueError('Cannot specify the horizon in the env_config')
    return v
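
The horizon therefore belongs in the per-platform rllib config rather than in env_config; run_experiment copies it across itself via self.config.env_config["horizon"] = rllib_config["horizon"]. For example:

rllib_configs = {"local": {"horizon": 1000}}   # correct place for the horizon

env_config = {"horizon": 1000}                 # rejected: no_horizon raises ValueError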

RllibPolicyValidator (BasePolicyValidator) pydantic-model

policy_class: callable policy class; None will use the default from the trainer
train: should this policy be trained

Source code in corl/experiments/rllib_experiment.py
class RllibPolicyValidator(BasePolicyValidator):
    """
    policy_class: callable policy class None will use default from trainer
    train: should this policy be trained
    Arguments:
        BaseModel {[type]} -- [description]

    Raises:
        RuntimeError: [description]

    Returns:
        [type] -- [description]
    """
    config: typing.Dict[str, typing.Any] = {}
    policy_class: typing.Union[PyObject, None] = None
    train: bool = True
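
A minimal sketch of a per-agent policy_config as run_experiment consumes it: policy_class, config and train map directly into the (policy_class, obs_space, act_space, config) tuples and the policies_to_train list built there (the PPOTorchPolicy import is an illustrative choice; None falls back to the trainer's default policy):

from ray.rllib.agents.ppo import PPOTorchPolicy  # illustrative policy choice

from corl.experiments.rllib_experiment import RllibPolicyValidator

policy_config = {
    "policy_class": PPOTorchPolicy,  # or None to use the trainer default
    "config": {"gamma": 0.99},       # per-policy overrides merged by RLlib
    "train": True,                    # include this policy in policies_to_train
}

validated = RllibPolicyValidator(**policy_config)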