Default environment RLlib callbacks
Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.
This is a US Government Work not subject to copyright protection in the US.
The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.
EnvironmentDefaultCallbacks
EnvironmentDefaultCallbacks (DefaultCallbacks)
This is the default class for callbacks to be used in the Environment class. To make your own custom callbacks, set the EnvironmentCallbacks in your derived class to a class that subclasses this class (EnvironmentDefaultCallbacks).
Make sure you call the super function in every overridden method; otherwise there will be unexpected callback behavior.
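A minimal usage sketch (the subclass name and the extra metric key below are illustrative, not part of CoRL): a derived callbacks class overrides a hook and calls super() before adding its own behavior.

```python
from corl.environment.default_env_rllib_callbacks import EnvironmentDefaultCallbacks


class MyEnvironmentCallbacks(EnvironmentDefaultCallbacks):
    """Hypothetical derived callbacks class."""

    def on_episode_end(self, *, worker, base_env, policies, episode, **kwargs) -> None:
        # Call super() first so the default metric logging still runs
        super().on_episode_end(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)
        # Illustrative extra metric; the key name is arbitrary
        episode.custom_metrics["custom/episode_length"] = episode.length
```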
Source code in corl/environment/default_env_rllib_callbacks.py
class EnvironmentDefaultCallbacks(DefaultCallbacks):
"""
    This is the default class for callbacks to be used in the Environment class.
To make your own custom callbacks set the EnvironmentCallbacks in your derived class
to be a class that subclasses this class (EnvironmentDefaultCallbacks)
Make sure you call the super function for all derived functions or else there will be unexpected callback behavior
"""
DEFAULT_METRIC_OPS = {
"min": np.min,
"max": np.max,
"median": np.median,
"mean": np.mean,
"var": np.var,
"std": np.std,
"sum": np.sum,
"nonzero": np.count_nonzero,
}
def on_episode_start(
self,
*,
worker,
base_env: BaseEnv,
policies: typing.Dict[PolicyID, Policy],
episode: Episode,
**kwargs,
) -> None:
"""Callback run on the rollout worker before each episode starts.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode: Episode object which contains the episode's
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
super().on_episode_start(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)
episode.user_data["rewards_accumulator"] = defaultdict(float) # default dict with default value of 0.0
def on_episode_step(
self,
*,
worker,
base_env: BaseEnv,
policies: typing.Optional[typing.Dict[PolicyID, Policy]] = None,
episode: Episode,
**kwargs,
) -> None:
"""Runs on each episode step.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects.
In single agent mode there will only be a single
"default_policy".
episode: Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
super().on_episode_step(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)
env = base_env.get_sub_environments()[episode.env_id]
if env.reward_info:
rewards_accumulator = episode.user_data["rewards_accumulator"]
for reward_name, reward_val in flatten(env.reward_info, reducer="path").items():
key = f"rewards_cumulative/{reward_name}"
rewards_accumulator[key] += reward_val
def on_episode_end( # pylint: disable=too-many-branches
self,
*,
worker,
base_env: BaseEnv, # pylint: disable=too-many-branches
policies: typing.Dict[PolicyID, Policy],
episode,
**kwargs
) -> None:
"""
on_episode_end stores the custom metrics in RLLIB. Note this is on a per glue basis.
1. read the training information for the current episode
2. For each metric in each platform interface in each environment
update metric container
Parameters
----------
worker: RolloutWorker
Reference to the current rollout worker.
base_env: BaseEnv
BaseEnv running the episode. The underlying
env object can be gotten by calling base_env.get_sub_environments().
policies: dict
Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode: MultiAgentEpisode
Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
"""
# Issue a warning if the episode is short
        # This should get output in the log and may help identify any setup issues
if episode.length < SHORT_EPISODE_THRESHOLD:
msg = f"Episode {str(episode.episode_id)} length {episode.length} is less than warn threshold {str(SHORT_EPISODE_THRESHOLD)}"
if "params" in episode.user_data:
msg += "\nparams:\n"
for key in episode.user_data["params"].keys():
msg += f" {key}: {str(episode.user_data['params'][key])}"
msg += "\n"
else:
msg += "\nParams not provied in episode user_data"
warnings.warn(msg)
env = base_env.get_sub_environments()[episode.env_id]
if env.glue_info: # pylint: disable=too-many-nested-blocks
for glue_name, metric_val in flatten(env.glue_info, reducer="path").items():
episode.custom_metrics[glue_name] = metric_val
if env.reward_info:
for reward_name, reward_val in flatten(env.reward_info, reducer="path").items():
key = f"rewards/{reward_name}"
episode.custom_metrics[key] = reward_val
log_done_info(env, episode)
log_done_status(env, episode)
# Variables
for key, value in flatten(env.local_variable_store, reducer="path").items():
try:
episode.custom_metrics[f'variable/env/{key}'] = float(value.value)
except ValueError:
pass
for agent_name, agent_data in env.agent_dict.items():
for key, value in flatten(agent_data.local_variable_store, reducer="path").items():
try:
episode.custom_metrics[f'variable/{agent_name}/{key}'] = float(value.value)
except ValueError:
pass
# Episode Parameter Providers
metrics = env.config.epp.compute_metrics()
for k, v in metrics.items():
episode.custom_metrics[f'adr/env/{k}'] = v
for agent_name, agent_data in env.agent_dict.items():
metrics = agent_data.config.epp.compute_metrics()
for k, v in metrics.items():
episode.custom_metrics[f'adr/{agent_name}/{k}'] = v
# Cumulative Rewards
for key, value in episode.user_data["rewards_accumulator"].items():
episode.custom_metrics[key] = value
def on_postprocess_trajectory(
self,
*,
worker,
episode,
agent_id: AgentID,
policy_id: PolicyID,
policies: typing.Dict[PolicyID, Policy],
postprocessed_batch: SampleBatch,
original_batches: typing.Dict[AgentID, typing.Tuple[Policy, SampleBatch]],
**kwargs
) -> None:
"""
Called immediately after a policy's postprocess_fn is called.
You can use this callback to do additional postprocessing for a policy,
including looking at the trajectory data of other agents in multi-agent
settings.
Parameters
----------
worker: RolloutWorker
Reference to the current rollout worker.
episode: MultiAgentEpisode
Episode object.
agent_id: str
Id of the current agent.
policy_id: str
Id of the current policy for the agent.
policies: dict
Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
postprocessed_batch: SampleBatch
The postprocessed sample batch
for this agent. You can mutate this object to apply your own
trajectory postprocessing.
original_batches: dict
Mapping of agents to their unpostprocessed
trajectory data. You should not mutate this object.
"""
super().on_postprocess_trajectory(
worker=worker,
episode=episode,
agent_id=agent_id,
policy_id=policy_id,
policies=policies,
postprocessed_batch=postprocessed_batch,
original_batches=original_batches
)
episode.worker.foreach_env(
lambda env: env.post_process_trajectory(agent_id, postprocessed_batch, episode, policies[policy_id])
) # type: ignore
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
"""
Called at the end of Trainable.train().
Parameters
----------
trainer: Trainer
Current trainer instance.
result: dict
Dict of results returned from trainer.train() call.
You can mutate this object to add additional metrics.
"""
rng, _ = seeding.np_random(seed=trainer.iteration)
# Environment EPP
assert result['config']['env'] == ACT3MultiAgentEnv.__name__
for epp in result['config']['env_config']['epp_registry'].values():
epp.update(result, rng)
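To use these callbacks for training, the class is passed to RLlib through the trainer configuration. A minimal sketch assuming the legacy Ray/RLlib config-dict API (the import path varies with the Ray version, and CoRL normally assembles this configuration from its experiment files; the settings shown are illustrative):

```python
from ray.rllib.agents.ppo import PPOTrainer  # import path depends on the installed Ray version

from corl.environment.default_env_rllib_callbacks import EnvironmentDefaultCallbacks

config = {
    "env": "ACT3MultiAgentEnv",                # assumes the CoRL environment is registered under this name
    "env_config": {},                          # CoRL environment configuration goes here (elided)
    "callbacks": EnvironmentDefaultCallbacks,  # RLlib instantiates the callbacks class itself
}

trainer = PPOTrainer(config=config)
result = trainer.train()  # triggers on_train_result at the end of each training iteration
```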
on_episode_end(self, *, worker, base_env, policies, episode, **kwargs)
on_episode_end stores the custom metrics in RLlib. Note this is on a per-glue basis.
- Read the training information for the current episode.
- For each metric in each platform interface in each environment, update the metric container.
Parameters

worker : RolloutWorker
    Reference to the current rollout worker.
base_env : BaseEnv
    BaseEnv running the episode. The underlying env object can be gotten by calling base_env.get_sub_environments().
policies : dict
    Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy.
episode : MultiAgentEpisode
    Episode object which contains episode state. You can use the episode.user_data dict to store temporary data, and episode.custom_metrics to store custom metrics for the episode.
Source code in corl/environment/default_env_rllib_callbacks.py
def on_episode_end( # pylint: disable=too-many-branches
self,
*,
worker,
base_env: BaseEnv, # pylint: disable=too-many-branches
policies: typing.Dict[PolicyID, Policy],
episode,
**kwargs
) -> None:
"""
on_episode_end stores the custom metrics in RLLIB. Note this is on a per glue basis.
1. read the training information for the current episode
2. For each metric in each platform interface in each environment
update metric container
Parameters
----------
worker: RolloutWorker
Reference to the current rollout worker.
base_env: BaseEnv
BaseEnv running the episode. The underlying
env object can be gotten by calling base_env.get_sub_environments().
policies: dict
Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode: MultiAgentEpisode
Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
"""
# Issue a warning if the episode is short
        # This should get output in the log and may help identify any setup issues
if episode.length < SHORT_EPISODE_THRESHOLD:
msg = f"Episode {str(episode.episode_id)} length {episode.length} is less than warn threshold {str(SHORT_EPISODE_THRESHOLD)}"
if "params" in episode.user_data:
msg += "\nparams:\n"
for key in episode.user_data["params"].keys():
msg += f" {key}: {str(episode.user_data['params'][key])}"
msg += "\n"
else:
msg += "\nParams not provied in episode user_data"
warnings.warn(msg)
env = base_env.get_sub_environments()[episode.env_id]
if env.glue_info: # pylint: disable=too-many-nested-blocks
for glue_name, metric_val in flatten(env.glue_info, reducer="path").items():
episode.custom_metrics[glue_name] = metric_val
if env.reward_info:
for reward_name, reward_val in flatten(env.reward_info, reducer="path").items():
key = f"rewards/{reward_name}"
episode.custom_metrics[key] = reward_val
log_done_info(env, episode)
log_done_status(env, episode)
# Variables
for key, value in flatten(env.local_variable_store, reducer="path").items():
try:
episode.custom_metrics[f'variable/env/{key}'] = float(value.value)
except ValueError:
pass
for agent_name, agent_data in env.agent_dict.items():
for key, value in flatten(agent_data.local_variable_store, reducer="path").items():
try:
episode.custom_metrics[f'variable/{agent_name}/{key}'] = float(value.value)
except ValueError:
pass
# Episode Parameter Providers
metrics = env.config.epp.compute_metrics()
for k, v in metrics.items():
episode.custom_metrics[f'adr/env/{k}'] = v
for agent_name, agent_data in env.agent_dict.items():
metrics = agent_data.config.epp.compute_metrics()
for k, v in metrics.items():
episode.custom_metrics[f'adr/{agent_name}/{k}'] = v
# Cumulative Rewards
for key, value in episode.user_data["rewards_accumulator"].items():
episode.custom_metrics[key] = value
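With RLlib's usual summarization (version dependent, and separate from the DEFAULT_METRIC_OPS table on this class), each key written to episode.custom_metrics here shows up in the training result aggregated across episodes. A hedged sketch, continuing the trainer example above and using hypothetical agent, platform, and reward names:

```python
result = trainer.train()
custom_metrics = result["custom_metrics"]

# RLlib typically appends _mean/_min/_max to each per-episode custom metric key, e.g.:
custom_metrics.get("rewards/blue0_ctrl/TimeReward_mean")              # hypothetical reward key
custom_metrics.get("rewards_cumulative/blue0_ctrl/TimeReward_mean")   # cumulative variant from on_episode_step
custom_metrics.get("done_results/blue0/NoDone_mean")                  # hypothetical platform name
```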
on_episode_start(self, *, worker, base_env, policies, episode, **kwargs)
Callback run on the rollout worker before each episode starts.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
worker | | Reference to the current rollout worker. | required |
base_env | BaseEnv | BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. | required |
policies | Dict[str, ray.rllib.policy.policy.Policy] | Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy. | required |
episode | Episode | Episode object which contains the episode's state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in corl/environment/default_env_rllib_callbacks.py
def on_episode_start(
self,
*,
worker,
base_env: BaseEnv,
policies: typing.Dict[PolicyID, Policy],
episode: Episode,
**kwargs,
) -> None:
"""Callback run on the rollout worker before each episode starts.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
episode: Episode object which contains the episode's
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
super().on_episode_start(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)
episode.user_data["rewards_accumulator"] = defaultdict(float) # default dict with default value of 0.0
on_episode_step(self, *, worker, base_env, policies=None, episode, **kwargs)
Runs on each episode step.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
worker | | Reference to the current rollout worker. | required |
base_env | BaseEnv | BaseEnv running the episode. The underlying sub environment objects can be retrieved by calling `base_env.get_sub_environments()`. | required |
policies | Optional[Dict[str, ray.rllib.policy.policy.Policy]] | Mapping of policy id to policy objects. In single agent mode there will only be a single "default_policy". | None |
episode | Episode | Episode object which contains episode state. You can use the `episode.user_data` dict to store temporary data, and `episode.custom_metrics` to store custom metrics for the episode. | required |
kwargs | | Forward compatibility placeholder. | {} |
Source code in corl/environment/default_env_rllib_callbacks.py
def on_episode_step(
self,
*,
worker,
base_env: BaseEnv,
policies: typing.Optional[typing.Dict[PolicyID, Policy]] = None,
episode: Episode,
**kwargs,
) -> None:
"""Runs on each episode step.
Args:
worker: Reference to the current rollout worker.
base_env: BaseEnv running the episode. The underlying
sub environment objects can be retrieved by calling
`base_env.get_sub_environments()`.
policies: Mapping of policy id to policy objects.
In single agent mode there will only be a single
"default_policy".
episode: Episode object which contains episode
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
kwargs: Forward compatibility placeholder.
"""
super().on_episode_step(worker=worker, base_env=base_env, policies=policies, episode=episode, **kwargs)
env = base_env.get_sub_environments()[episode.env_id]
if env.reward_info:
rewards_accumulator = episode.user_data["rewards_accumulator"]
for reward_name, reward_val in flatten(env.reward_info, reducer="path").items():
key = f"rewards_cumulative/{reward_name}"
rewards_accumulator[key] += reward_val
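The slash-separated reward keys come from flatten(..., reducer="path"), which (assuming the flatten-dict package, whose import is not shown in this listing) collapses the nested reward_info mapping into path-style keys. A small illustration with a hypothetical nested structure:

```python
from flatten_dict import flatten

# Hypothetical nested reward_info: agent -> reward name -> value
reward_info = {"blue0_ctrl": {"TimeReward": -0.1, "DockingReward": 1.0}}

flat = flatten(reward_info, reducer="path")
# {'blue0_ctrl/TimeReward': -0.1, 'blue0_ctrl/DockingReward': 1.0}

for reward_name, reward_val in flat.items():
    print(f"rewards_cumulative/{reward_name}", reward_val)
```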
on_postprocess_trajectory(self, *, worker, episode, agent_id, policy_id, policies, postprocessed_batch, original_batches, **kwargs)
Called immediately after a policy's postprocess_fn is called.
You can use this callback to do additional postprocessing for a policy, including looking at the trajectory data of other agents in multi-agent settings.
Parameters

worker : RolloutWorker
    Reference to the current rollout worker.
episode : MultiAgentEpisode
    Episode object.
agent_id : str
    Id of the current agent.
policy_id : str
    Id of the current policy for the agent.
policies : dict
    Mapping of policy id to policy objects. In single agent mode there will only be a single "default" policy.
postprocessed_batch : SampleBatch
    The postprocessed sample batch for this agent. You can mutate this object to apply your own trajectory postprocessing.
original_batches : dict
    Mapping of agents to their unpostprocessed trajectory data. You should not mutate this object.
Source code in corl/environment/default_env_rllib_callbacks.py
def on_postprocess_trajectory(
self,
*,
worker,
episode,
agent_id: AgentID,
policy_id: PolicyID,
policies: typing.Dict[PolicyID, Policy],
postprocessed_batch: SampleBatch,
original_batches: typing.Dict[AgentID, typing.Tuple[Policy, SampleBatch]],
**kwargs
) -> None:
"""
Called immediately after a policy's postprocess_fn is called.
You can use this callback to do additional postprocessing for a policy,
including looking at the trajectory data of other agents in multi-agent
settings.
Parameters
----------
worker: RolloutWorker
Reference to the current rollout worker.
episode: MultiAgentEpisode
Episode object.
agent_id: str
Id of the current agent.
policy_id: str
Id of the current policy for the agent.
policies: dict
Mapping of policy id to policy objects. In single
agent mode there will only be a single "default" policy.
postprocessed_batch: SampleBatch
The postprocessed sample batch
for this agent. You can mutate this object to apply your own
trajectory postprocessing.
original_batches: dict
Mapping of agents to their unpostprocessed
trajectory data. You should not mutate this object.
"""
super().on_postprocess_trajectory(
worker=worker,
episode=episode,
agent_id=agent_id,
policy_id=policy_id,
policies=policies,
postprocessed_batch=postprocessed_batch,
original_batches=original_batches
)
episode.worker.foreach_env(
lambda env: env.post_process_trajectory(agent_id, postprocessed_batch, episode, policies[policy_id])
) # type: ignore
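A minimal sketch of a derived callbacks class that inspects the postprocessed batch after the default handling runs (the subclass name and the logged metric are illustrative, not part of CoRL):

```python
from ray.rllib.policy.sample_batch import SampleBatch

from corl.environment.default_env_rllib_callbacks import EnvironmentDefaultCallbacks


class BatchInspectingCallbacks(EnvironmentDefaultCallbacks):
    def on_postprocess_trajectory(
        self, *, worker, episode, agent_id, policy_id, policies,
        postprocessed_batch, original_batches, **kwargs
    ) -> None:
        # Let the default behavior (including env.post_process_trajectory) run first
        super().on_postprocess_trajectory(
            worker=worker, episode=episode, agent_id=agent_id, policy_id=policy_id,
            policies=policies, postprocessed_batch=postprocessed_batch,
            original_batches=original_batches, **kwargs,
        )
        # The postprocessed batch may be mutated here; this sketch only records a summary
        rewards = postprocessed_batch[SampleBatch.REWARDS]
        episode.custom_metrics[f"trajectory/{agent_id}/reward_sum"] = float(rewards.sum())
```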
on_train_result(self, *, trainer, result, **kwargs)
Called at the end of Trainable.train().
Parameters

trainer : Trainer
    Current trainer instance.
result : dict
    Dict of results returned from trainer.train() call. You can mutate this object to add additional metrics.
Source code in corl/environment/default_env_rllib_callbacks.py
def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
"""
Called at the end of Trainable.train().
Parameters
----------
trainer: Trainer
Current trainer instance.
result: dict
Dict of results returned from trainer.train() call.
You can mutate this object to add additional metrics.
"""
rng, _ = seeding.np_random(seed=trainer.iteration)
# Environment EPP
assert result['config']['env'] == ACT3MultiAgentEnv.__name__
for epp in result['config']['env_config']['epp_registry'].values():
epp.update(result, rng)
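The random number generator handed to each episode parameter provider is re-seeded from the trainer iteration, so EPP updates are reproducible for a given iteration. A quick illustration of that seeding call, assuming the seeding module is gym.utils.seeding (the import is not shown in this listing):

```python
from gym.utils import seeding

rng_a, _ = seeding.np_random(seed=3)  # e.g. trainer.iteration == 3
rng_b, _ = seeding.np_random(seed=3)

# The same iteration number yields the same random stream
assert rng_a.uniform() == rng_b.uniform()
```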
log_done_info(env, episode)
Log done info to done_results/{platform}/{done_name}
Source code in corl/environment/default_env_rllib_callbacks.py
def log_done_info(env, episode):
"""
Log done info to done_results/{platform}/{done_name}
"""
if env.done_info:
platform_done_info: typing.Dict[str, typing.Dict[str, bool]] = {}
for agent_id, agent_data in env.done_info.items():
plat_name = env.agent_dict[agent_id].platform_name
platform_done_info.setdefault(plat_name, {})
for done_name, done_data in agent_data.items():
if done_name == '__all__':
continue
platform_done_info[plat_name].setdefault(done_name, False)
if done_data[plat_name] is not None:
platform_done_info[plat_name][done_name] |= done_data[plat_name]
for plat_name, done_data in platform_done_info.items():
for done_name, done_value in done_data.items():
episode.custom_metrics[f'done_results/{plat_name}/{done_name}'] = int(done_value)
episode.custom_metrics[f'done_results/{plat_name}/NoDone'] = int(not any(platform_done_info[plat_name].values()))
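A small illustration of the done metrics this produces, using a hypothetical agent/platform and hypothetical done-condition names:

```python
# Hypothetical env.done_info structure: agent -> done name -> platform -> bool
done_info = {
    "blue0_ctrl": {
        "__all__": {"blue0": True},       # skipped by log_done_info
        "TimeoutDone": {"blue0": True},
        "DockingDone": {"blue0": False},
    }
}

# For platform "blue0", log_done_info would record:
#   done_results/blue0/TimeoutDone = 1
#   done_results/blue0/DockingDone = 0
#   done_results/blue0/NoDone      = 0   (at least one done condition triggered)
```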
log_done_status(env, episode)
Log done status codes to done_status/{platform}/{status}
Source code in corl/environment/default_env_rllib_callbacks.py
def log_done_status(env, episode):
"""
Log done status codes to done_status/{platform}/{status}
"""
if env.state.episode_state:
for platform, data in env.state.episode_state.items():
if not data:
continue
episode_codes = set(data.values())
for status in DoneStatusCodes:
metric_key = f"done_status/{platform}/{status}"
if status in episode_codes:
episode.custom_metrics[metric_key] = 1
else:
episode.custom_metrics[metric_key] = 0