Skip to content

Default env rllib callbacks


Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.

This is a US Government Work not subject to copyright protection in the US.

The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.


EnvironmentDefaultCallbacks

EnvironmentDefaultCallbacks (DefaultCallbacks) ¤

This is the default class for callbacks to be used in the Environment class. To make your own custom callbacks, set the EnvironmentCallbacks in your derived class to a class that subclasses this class (EnvironmentDefaultCallbacks).

Make sure you call the super function from every overridden method, or else there will be unexpected callback behavior.

Source code in corl/environment/default_env_rllib_callbacks.py
class EnvironmentDefaultCallbacks(DefaultCallbacks):
    """
    This is the default class for callbacks to be used in the Environment class.
    To make your own custom callbacks, set the EnvironmentCallbacks in your derived class
    to a class that subclasses this class (EnvironmentDefaultCallbacks).

    Make sure you call the super function from every overridden method, or else there will be
    unexpected callback behavior. Every callback defined here forwards to its parent as well.
    """

    # Name -> numpy reduction used to summarize metric series.
    # NOTE(review): not referenced inside this class; presumably consumed by subclasses or
    # other modules — confirm before changing.
    DEFAULT_METRIC_OPS = {
        "min": np.min,
        "max": np.max,
        "median": np.median,
        "mean": np.mean,
        "var": np.var,
        "std": np.std,
        "sum": np.sum,
        "nonzero": np.count_nonzero,
    }

    def on_episode_start(
        self,
        *,
        worker,
        base_env: BaseEnv,
        policies: typing.Dict[PolicyID, Policy],
        episode: Episode,
        **kwargs,
    ) -> None:
        """Callback run on the rollout worker before each episode starts.

        Args:
            worker: Reference to the current rollout worker.
            base_env: BaseEnv running the episode. The underlying
                sub environment objects can be retrieved by calling
                `base_env.get_sub_environments()`.
            policies: Mapping of policy id to policy objects. In single
                agent mode there will only be a single "default" policy.
            episode: Episode object which contains the episode's
                state. You can use the `episode.user_data` dict to store
                temporary data, and `episode.custom_metrics` to store custom
                metrics for the episode.
            kwargs: Forward compatibility placeholder.
        """
        super().on_episode_start(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)

        # Per-episode running reward totals, filled by on_episode_step; missing keys start at 0.0.
        episode.user_data["rewards_accumulator"] = defaultdict(float)

    def on_episode_step(
        self,
        *,
        worker,
        base_env: BaseEnv,
        policies: typing.Optional[typing.Dict[PolicyID, Policy]] = None,
        episode: Episode,
        **kwargs,
    ) -> None:
        """Runs on each episode step and accumulates the per-step reward info.

        Args:
            worker: Reference to the current rollout worker.
            base_env: BaseEnv running the episode. The underlying
                sub environment objects can be retrieved by calling
                `base_env.get_sub_environments()`.
            policies: Mapping of policy id to policy objects.
                In single agent mode there will only be a single
                "default_policy".
            episode: Episode object which contains episode
                state. You can use the `episode.user_data` dict to store
                temporary data, and `episode.custom_metrics` to store custom
                metrics for the episode.
            kwargs: Forward compatibility placeholder.
        """
        super().on_episode_step(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)

        env = base_env.get_sub_environments()[episode.env_id]

        if env.reward_info:
            rewards_accumulator = episode.user_data["rewards_accumulator"]
            # Nested reward info is flattened to "a/b/c" style names before accumulating.
            for reward_name, reward_val in flatten(env.reward_info, reducer="path").items():
                rewards_accumulator[f"rewards_cumulative/{reward_name}"] += reward_val

    def on_episode_end(  # pylint: disable=too-many-branches
        self,
        *,
        worker,
        base_env: BaseEnv,
        policies: typing.Dict[PolicyID, Policy],
        episode,
        **kwargs
    ) -> None:
        """
        on_episode_end stores the custom metrics in RLlib. Note this is on a per-glue basis.

        1. Warn if the episode is shorter than SHORT_EPISODE_THRESHOLD
        2. Record glue info, reward info, done results and done status codes
        3. Record environment and agent variables that can be converted to float
        4. Record episode parameter provider metrics under adr/...
        5. Record the cumulative rewards accumulated by on_episode_step

        Parameters
        ----------
        worker: RolloutWorker
            Reference to the current rollout worker.
        base_env: BaseEnv
            BaseEnv running the episode. The underlying
            env object can be retrieved by calling base_env.get_sub_environments().
        policies: dict
            Mapping of policy id to policy objects. In single
            agent mode there will only be a single "default" policy.
        episode: MultiAgentEpisode
            Episode object which contains episode
            state. You can use the `episode.user_data` dict to store
            temporary data, and `episode.custom_metrics` to store custom
            metrics for the episode.
        kwargs:
            Forward compatibility placeholder; forwarded to the parent callback.
        """
        super().on_episode_end(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)

        # Issue a warning if the episode is short.
        # This should be output to the log and may help identify setup issues.
        if episode.length < SHORT_EPISODE_THRESHOLD:
            msg = f"Episode {episode.episode_id!s} length {episode.length} is less than warn threshold {SHORT_EPISODE_THRESHOLD!s}"
            if "params" in episode.user_data:
                msg += "\nparams:\n" + "".join(f"  {key}: {value!s}\n" for key, value in episode.user_data["params"].items())
            else:
                msg += "\nParams not provided in episode user_data"
            warnings.warn(msg)

        env = base_env.get_sub_environments()[episode.env_id]
        if env.glue_info:  # pylint: disable=too-many-nested-blocks
            for glue_name, metric_val in flatten(env.glue_info, reducer="path").items():
                episode.custom_metrics[glue_name] = metric_val

        if env.reward_info:
            for reward_name, reward_val in flatten(env.reward_info, reducer="path").items():
                episode.custom_metrics[f"rewards/{reward_name}"] = reward_val

        log_done_info(env, episode)

        log_done_status(env, episode)

        # Variables. Values float() rejects are skipped on purpose. ValueError covers e.g.
        # non-numeric strings; TypeError covers e.g. None or containers.
        for key, value in flatten(env.local_variable_store, reducer="path").items():
            try:
                episode.custom_metrics[f'variable/env/{key}'] = float(value.value)
            except (TypeError, ValueError):
                pass

        for agent_name, agent_data in env.agent_dict.items():
            for key, value in flatten(agent_data.local_variable_store, reducer="path").items():
                try:
                    episode.custom_metrics[f'variable/{agent_name}/{key}'] = float(value.value)
                except (TypeError, ValueError):
                    pass

        # Episode Parameter Providers
        for k, v in env.config.epp.compute_metrics().items():
            episode.custom_metrics[f'adr/env/{k}'] = v

        for agent_name, agent_data in env.agent_dict.items():
            for k, v in agent_data.config.epp.compute_metrics().items():
                episode.custom_metrics[f'adr/{agent_name}/{k}'] = v

        # Cumulative rewards accumulated by on_episode_step. The .get() default avoids a
        # KeyError when a subclass skipped on_episode_start's super() call.
        for key, value in episode.user_data.get("rewards_accumulator", {}).items():
            episode.custom_metrics[key] = value

    def on_postprocess_trajectory(
        self,
        *,
        worker,
        episode,
        agent_id: AgentID,
        policy_id: PolicyID,
        policies: typing.Dict[PolicyID, Policy],
        postprocessed_batch: SampleBatch,
        original_batches: typing.Dict[AgentID, typing.Tuple[Policy, SampleBatch]],
        **kwargs
    ) -> None:
        """
        Called immediately after a policy's postprocess_fn is called.

        You can use this callback to do additional postprocessing for a policy,
        including looking at the trajectory data of other agents in multi-agent
        settings.

        Parameters
        ----------
        worker: RolloutWorker
            Reference to the current rollout worker.
        episode: MultiAgentEpisode
            Episode object.
        agent_id: str
            Id of the current agent.
        policy_id: str
            Id of the current policy for the agent.
        policies: dict
            Mapping of policy id to policy objects. In single
            agent mode there will only be a single "default" policy.
        postprocessed_batch: SampleBatch
            The postprocessed sample batch
            for this agent. You can mutate this object to apply your own
            trajectory postprocessing.
        original_batches: dict
            Mapping of agents to their unpostprocessed
            trajectory data. You should not mutate this object.
        kwargs:
            Forward compatibility placeholder; forwarded to the parent callback.
        """
        super().on_postprocess_trajectory(
            worker=worker,
            episode=episode,
            agent_id=agent_id,
            policy_id=policy_id,
            policies=policies,
            postprocessed_batch=postprocessed_batch,
            original_batches=original_batches,
            **kwargs
        )
        # Let the environment(s) on the episode's worker post-process this agent's batch.
        episode.worker.foreach_env(
            lambda env: env.post_process_trajectory(agent_id, postprocessed_batch, episode, policies[policy_id])
        )  # type: ignore

    def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
        """
        Called at the end of Trainable.train().

        Parameters
        ----------
        trainer: Trainer
            Current trainer instance.
        result: dict
            Dict of results returned from trainer.train() call.
            You can mutate this object to add additional metrics.
        kwargs:
            Forward compatibility placeholder; forwarded to the parent callback.
        """
        super().on_train_result(trainer=trainer, result=result, **kwargs)

        # The EPP random generator is seeded deterministically from the training iteration.
        rng, _ = seeding.np_random(seed=trainer.iteration)

        # Environment EPP: this callback only supports ACT3MultiAgentEnv configurations.
        env_name = result['config']['env']
        assert env_name == ACT3MultiAgentEnv.__name__, (
            f"on_train_result expects env '{ACT3MultiAgentEnv.__name__}', got '{env_name}'"
        )

        for epp in result['config']['env_config']['epp_registry'].values():
            epp.update(result, rng)

on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) ¤

on_episode_end stores the custom metrics in RLlib. Note this is on a per-glue basis.

  1. read the training information for the current episode
  2. For each metric in each platform interface in each environment update metric container
Parameters

worker : RolloutWorker

Reference to the current rollout worker.

base_env : BaseEnv

BaseEnv running the episode. The underlying env object can be retrieved by calling base_env.get_sub_environments().

policies : dict

Mapping of policy id to policy objects. In single-agent mode there will only be a single "default" policy.

episode : MultiAgentEpisode

Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.

Source code in corl/environment/default_env_rllib_callbacks.py
def on_episode_end(  # pylint: disable=too-many-branches
    self,
    *,
    worker,
    base_env: BaseEnv,
    policies: typing.Dict[PolicyID, Policy],
    episode,
    **kwargs
) -> None:
    """
    on_episode_end stores the custom metrics in RLlib. Note this is on a per-glue basis.

    1. Warn if the episode is shorter than SHORT_EPISODE_THRESHOLD
    2. Record glue info, reward info, done results and done status codes
    3. Record environment and agent variables that can be converted to float
    4. Record episode parameter provider metrics under adr/...
    5. Record the cumulative rewards accumulated by on_episode_step

    Parameters
    ----------
    worker: RolloutWorker
        Reference to the current rollout worker.
    base_env: BaseEnv
        BaseEnv running the episode. The underlying
        env object can be retrieved by calling base_env.get_sub_environments().
    policies: dict
        Mapping of policy id to policy objects. In single
        agent mode there will only be a single "default" policy.
    episode: MultiAgentEpisode
        Episode object which contains episode
        state. You can use the `episode.user_data` dict to store
        temporary data, and `episode.custom_metrics` to store custom
        metrics for the episode.
    kwargs:
        Forward compatibility placeholder; forwarded to the parent callback.
    """
    super().on_episode_end(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)

    # Issue a warning if the episode is short.
    # This should be output to the log and may help identify setup issues.
    if episode.length < SHORT_EPISODE_THRESHOLD:
        msg = f"Episode {episode.episode_id!s} length {episode.length} is less than warn threshold {SHORT_EPISODE_THRESHOLD!s}"
        if "params" in episode.user_data:
            msg += "\nparams:\n" + "".join(f"  {key}: {value!s}\n" for key, value in episode.user_data["params"].items())
        else:
            msg += "\nParams not provided in episode user_data"
        warnings.warn(msg)

    env = base_env.get_sub_environments()[episode.env_id]
    if env.glue_info:  # pylint: disable=too-many-nested-blocks
        for glue_name, metric_val in flatten(env.glue_info, reducer="path").items():
            episode.custom_metrics[glue_name] = metric_val

    if env.reward_info:
        for reward_name, reward_val in flatten(env.reward_info, reducer="path").items():
            episode.custom_metrics[f"rewards/{reward_name}"] = reward_val

    log_done_info(env, episode)

    log_done_status(env, episode)

    # Variables. Values float() rejects are skipped on purpose. ValueError covers e.g.
    # non-numeric strings; TypeError covers e.g. None or containers.
    for key, value in flatten(env.local_variable_store, reducer="path").items():
        try:
            episode.custom_metrics[f'variable/env/{key}'] = float(value.value)
        except (TypeError, ValueError):
            pass

    for agent_name, agent_data in env.agent_dict.items():
        for key, value in flatten(agent_data.local_variable_store, reducer="path").items():
            try:
                episode.custom_metrics[f'variable/{agent_name}/{key}'] = float(value.value)
            except (TypeError, ValueError):
                pass

    # Episode Parameter Providers
    for k, v in env.config.epp.compute_metrics().items():
        episode.custom_metrics[f'adr/env/{k}'] = v

    for agent_name, agent_data in env.agent_dict.items():
        for k, v in agent_data.config.epp.compute_metrics().items():
            episode.custom_metrics[f'adr/{agent_name}/{k}'] = v

    # Cumulative rewards accumulated by on_episode_step. The .get() default avoids a
    # KeyError when a subclass skipped on_episode_start's super() call.
    for key, value in episode.user_data.get("rewards_accumulator", {}).items():
        episode.custom_metrics[key] = value

on_episode_start(self, *, worker, base_env, policies, episode, **kwargs) ¤

Callback run on the rollout worker before each episode starts.

Parameters:

Name Type Description Default
worker

Reference to the current rollout worker.

required
base_env BaseEnv

BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling base_env.get_sub_environments().

required
policies Dict[str, ray.rllib.policy.policy.Policy]

Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy.

required
episode Episode

Episode object which contains the episode's state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.

required
kwargs

Forward compatibility placeholder.

{}
Source code in corl/environment/default_env_rllib_callbacks.py
def on_episode_start(
    self,
    *,
    worker,
    base_env: BaseEnv,
    policies: typing.Dict[PolicyID, Policy],
    episode: Episode,
    **kwargs,
) -> None:
    """Prepare per-episode bookkeeping before the episode's first step runs.

    Args:
        worker: The rollout worker executing this episode.
        base_env: BaseEnv running the episode; individual sub environments
            are available through `base_env.get_sub_environments()`.
        policies: Policy id -> policy object. Single-agent setups contain
            only the "default" policy.
        episode: The episode being started. `episode.user_data` holds
            scratch data and `episode.custom_metrics` holds custom metrics.
        kwargs: Forward compatibility placeholder.
    """
    parent = super()
    parent.on_episode_start(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)

    # Running reward totals for this episode; any key that has not been seen yet reads as 0.0.
    fresh_totals = defaultdict(float)
    episode.user_data["rewards_accumulator"] = fresh_totals

on_episode_step(self, *, worker, base_env, policies=None, episode, **kwargs) ¤

Runs on each episode step.

Parameters:

Name Type Description Default
worker

Reference to the current rollout worker.

required
base_env BaseEnv

BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling base_env.get_sub_environments().

required
policies Optional[Dict[str, ray.rllib.policy.policy.Policy]]

Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy".

None
episode Episode

Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.

required
kwargs

Forward compatibility placeholder.

{}
Source code in corl/environment/default_env_rllib_callbacks.py
def on_episode_step(
    self,
    *,
    worker,
    base_env: BaseEnv,
    policies: typing.Optional[typing.Dict[PolicyID, Policy]] = None,
    episode: Episode,
    **kwargs,
) -> None:
    """Add this step's reward info to the episode's running reward totals.

    Args:
        worker: The rollout worker executing this episode.
        base_env: BaseEnv running the episode; individual sub environments
            are available through `base_env.get_sub_environments()`.
        policies: Policy id -> policy object. Single-agent setups contain
            only the "default_policy".
        episode: The episode in progress. `episode.user_data` holds scratch
            data and `episode.custom_metrics` holds custom metrics.
        kwargs: Forward compatibility placeholder.
    """
    super().on_episode_step(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)

    sub_env = base_env.get_sub_environments()[episode.env_id]

    # Nothing to accumulate on steps without reward info.
    if not sub_env.reward_info:
        return

    totals = episode.user_data["rewards_accumulator"]
    flat_rewards = flatten(sub_env.reward_info, reducer="path")
    for name, amount in flat_rewards.items():
        totals[f"rewards_cumulative/{name}"] += amount

on_postprocess_trajectory(self, *, worker, episode, agent_id, policy_id, policies, postprocessed_batch, original_batches, **kwargs) ¤

Called immediately after a policy's postprocess_fn is called.

You can use this callback to do additional postprocessing for a policy, including looking at the trajectory data of other agents in multi-agent settings.

Parameters

worker : RolloutWorker

Reference to the current rollout worker.

episode : MultiAgentEpisode

Episode object.

agent_id : str

Id of the current agent.

policy_id : str

Id of the current policy for the agent.

policies : dict

Mapping of policy id to policy objects. In single-agent mode there will only be a single "default" policy.

postprocessed_batch : SampleBatch

The postprocessed sample batch for this agent. You can mutate this object to apply your own trajectory postprocessing.

original_batches : dict

Mapping of agents to their unpostprocessed trajectory data. You should not mutate this object.

Source code in corl/environment/default_env_rllib_callbacks.py
def on_postprocess_trajectory(
    self,
    *,
    worker,
    episode,
    agent_id: AgentID,
    policy_id: PolicyID,
    policies: typing.Dict[PolicyID, Policy],
    postprocessed_batch: SampleBatch,
    original_batches: typing.Dict[AgentID, typing.Tuple[Policy, SampleBatch]],
    **kwargs
) -> None:
    """
    Called immediately after a policy's postprocess_fn is called.

    You can use this callback to do additional postprocessing for a policy,
    including looking at the trajectory data of other agents in multi-agent
    settings.

    Parameters
    ----------
    worker: RolloutWorker
        Reference to the current rollout worker.
    episode: MultiAgentEpisode
        Episode object.
    agent_id: str
        Id of the current agent.
    policy_id: str
        Id of the current policy for the agent.
    policies: dict
        Mapping of policy id to policy objects. In single
        agent mode there will only be a single "default" policy.
    postprocessed_batch: SampleBatch
        The postprocessed sample batch
        for this agent. You can mutate this object to apply your own
        trajectory postprocessing.
    original_batches: dict
        Mapping of agents to their unpostprocessed
        trajectory data. You should not mutate this object.
    kwargs:
        Forward compatibility placeholder; forwarded to the parent callback.
    """
    super().on_postprocess_trajectory(
        worker=worker,
        episode=episode,
        agent_id=agent_id,
        policy_id=policy_id,
        policies=policies,
        postprocessed_batch=postprocessed_batch,
        original_batches=original_batches,
        **kwargs
    )
    # Let the environment(s) on the episode's worker post-process this agent's batch.
    episode.worker.foreach_env(
        lambda env: env.post_process_trajectory(agent_id, postprocessed_batch, episode, policies[policy_id])
    )  # type: ignore

on_train_result(self, *, trainer, result, **kwargs) ¤

Called at the end of Trainable.train().

Parameters

trainer : Trainer

Current trainer instance.

result : dict

Dict of results returned from the trainer.train() call. You can mutate this object to add additional metrics.

Source code in corl/environment/default_env_rllib_callbacks.py
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
    """
    Called at the end of Trainable.train().

    Parameters
    ----------
    trainer: Trainer
        Current trainer instance.
    result: dict
        Dict of results returned from trainer.train() call.
        You can mutate this object to add additional metrics.
    kwargs:
        Forward compatibility placeholder; forwarded to the parent callback.
    """
    super().on_train_result(trainer=trainer, result=result, **kwargs)

    # The EPP random generator is seeded deterministically from the training iteration.
    rng, _ = seeding.np_random(seed=trainer.iteration)

    # Environment EPP: this callback only supports ACT3MultiAgentEnv configurations.
    env_name = result['config']['env']
    assert env_name == ACT3MultiAgentEnv.__name__, (
        f"on_train_result expects env '{ACT3MultiAgentEnv.__name__}', got '{env_name}'"
    )

    for epp in result['config']['env_config']['epp_registry'].values():
        epp.update(result, rng)

log_done_info(env, episode) ¤

Log done info to done_results/{platform}/{done_name}

Source code in corl/environment/default_env_rllib_callbacks.py
def log_done_info(env, episode):
    """
    Log done info to done_results/{platform}/{done_name}.

    For every platform, each done condition is recorded as 1 if it triggered for any
    agent on that platform and 0 otherwise. ``done_results/{platform}/NoDone`` is 1 when
    none of that platform's done conditions triggered.

    Parameters
    ----------
    env:
        Environment exposing ``done_info`` (agent id -> done name -> platform name -> flag
        or None) and ``agent_dict`` (agent id -> object with ``platform_name``).
    episode:
        Episode whose ``custom_metrics`` dict receives the metrics.
    """
    if not env.done_info:
        return

    platform_done_info: typing.Dict[str, typing.Dict[str, bool]] = {}
    for agent_id, agent_dones in env.done_info.items():
        plat_name = env.agent_dict[agent_id].platform_name
        plat_dones = platform_done_info.setdefault(plat_name, {})
        for done_name, done_data in agent_dones.items():
            # "__all__" is not a per-platform done condition (presumably the RLlib aggregate flag).
            if done_name == '__all__':
                continue
            plat_dones.setdefault(done_name, False)
            # A platform missing from this done's results, or a None result, counts as "not triggered".
            done_value = done_data.get(plat_name)
            if done_value is not None:
                plat_dones[done_name] |= done_value

    for plat_name, plat_dones in platform_done_info.items():
        for done_name, done_value in plat_dones.items():
            episode.custom_metrics[f'done_results/{plat_name}/{done_name}'] = int(done_value)
        episode.custom_metrics[f'done_results/{plat_name}/NoDone'] = int(not any(plat_dones.values()))

log_done_status(env, episode) ¤

Log done status codes to done_status/{platform}/{status}

Source code in corl/environment/default_env_rllib_callbacks.py
def log_done_status(env, episode):
    """
    Log done status codes to done_status/{platform}/{status}.

    For each platform with a non-empty entry in ``env.state.episode_state``, every member
    of ``DoneStatusCodes`` is written as a 1/0 indicator of whether that code appears
    among the platform's recorded status values.
    """
    episode_state = env.state.episode_state
    if not episode_state:
        return

    for platform, platform_statuses in episode_state.items():
        # Platforms with nothing recorded produce no metrics at all.
        if not platform_statuses:
            continue

        seen_codes = set(platform_statuses.values())
        for code in DoneStatusCodes:
            episode.custom_metrics[f"done_status/{platform}/{code}"] = int(code in seen_codes)