Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.
This is a US Government Work not subject to copyright protection in the US.
The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.
Custom Policy
CustomPolicy (Policy)
Custom base policy.
Source code in corl/policies/custom_policy.py
```python
class CustomPolicy(Policy):  # pylint: disable=abstract-method
    """Custom base policy."""

    def __init__(self, observation_space, action_space, config):
        self.validated_config: CustomPolicyValidator = self.get_validator(
            act_space=action_space, obs_space=observation_space, **config
        )
        Policy.__init__(self, observation_space, action_space, config)
        self.time_extractor = self.validated_config.time_extractor.construct_extractors()
        self._reset()

    @property
    def get_validator(self) -> typing.Type[BasePolicyValidator]:
        """
        Get the validator for this policy class; the kwargs sent to the policy
        are validated using this object and stored on the policy as the
        validated_config attribute.
        """
        return CustomPolicyValidator

    def _reset(self):
        """This must be overridden in order to reset the state between runs."""
        ...

    def learn_on_batch(self, samples):
        return {}

    def get_weights(self):
        return {}

    def set_weights(self, weights):
        return

    def compute_actions_from_input_dict(
        self,
        input_dict: typing.Union[SampleBatch, typing.Dict[str, TensorStructType]],
        explore: bool = None,
        timestep: typing.Optional[int] = None,
        episodes: typing.Optional[typing.List[Episode]] = None,
        **kwargs
    ) -> typing.Tuple[TensorType, typing.List[TensorType], typing.Dict[str, TensorType]]:
        """Computes actions from collected samples (across multiple agents).

        Takes an input dict (usually a SampleBatch) as its main data input.
        This allows using this method when a more complex input pattern
        (view requirements) is needed, for example when the Model requires the
        last n observations, the last m actions/rewards, or a combination
        of any of these.

        Args:
            input_dict: A SampleBatch or input dict containing the Tensors
                to compute actions. `input_dict` already abides by the
                Policy's as well as the Model's view requirements and can
                thus be passed to the Model as-is.
            explore: Whether to pick an exploitation or exploration
                action (default: None -> use self.config["explore"]).
            timestep: The current (sampling) time step.
            episodes: This provides access to all of the internal episodes'
                state, which may be useful for model-based or multi-agent
                algorithms.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            actions: Batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs: List of RNN state output
                batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
            info: Dictionary of extra feature batches, if any, with shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        # Default implementation just passes obs, prev-a/r, and states on to
        # `self.compute_actions()`.
        agent_index = input_dict[SampleBatch.AGENT_INDEX][0]
        episode_id = input_dict[SampleBatch.EPS_ID][0]
        episode = [eps for eps in episodes if eps.episode_id == episode_id][0]  # type: ignore
        agent_id: str = [
            aid for aid in episode._agent_to_index if episode._agent_to_index[aid] == agent_index  # pylint: disable=protected-access
        ][0]
        obs_batch = input_dict[SampleBatch.OBS]
        info = episode.last_info_for(agent_id)
        if info is None:
            info = {}
        if 'platform_obs' in info:
            sim_time = self.time_extractor.value(info['platform_obs'][agent_id], full_extraction=True)
        else:
            self._reset()
            sim_time = -1
        state_batches = [s for k, s in input_dict.items() if k[:9] == "state_in_"]
        return self.compute_actions(
            obs_batch,
            state_batches,
            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
            info_batch=input_dict.get(SampleBatch.INFOS),  # type: ignore
            explore=explore,
            timestep=timestep,
            episodes=episodes,
            sim_time=sim_time,
            agent_id=agent_id,
            info=info,
            episode=episode,
            **kwargs,
        )

    def compute_actions(
        self,
        obs_batch: typing.Union[typing.List[TensorStructType], TensorStructType],
        state_batches: typing.Optional[typing.List[TensorType]] = None,
        prev_action_batch: typing.Union[typing.List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: typing.Union[typing.List[TensorStructType], TensorStructType] = None,
        info_batch: typing.Optional[typing.Dict[str, list]] = None,
        episodes: typing.Optional[typing.List[Episode]] = None,
        explore: typing.Optional[bool] = None,
        timestep: typing.Optional[int] = None,
        **kwargs
    ) -> typing.Tuple[TensorType, typing.List[TensorType], typing.Dict[str, TensorType]]:
        actions, state_outs, info = self.custom_compute_actions(
            obs_batch,
            state_batches=state_batches,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes,
            explore=explore,
            timestep=timestep,
            **kwargs
        )
        if self.validated_config.normalize_controls:
            for i, action in enumerate(actions):
                actions[i] = EnvSpaceUtil.scale_sample_from_space(self.validated_config.act_space, action)
        return actions, state_outs, info

    @abstractmethod
    def custom_compute_actions(
        self,
        obs_batch: typing.Union[typing.List[TensorStructType], TensorStructType],
        state_batches: typing.Optional[typing.List[TensorType]] = None,
        prev_action_batch: typing.Union[typing.List[TensorStructType], TensorStructType] = None,
        prev_reward_batch: typing.Union[typing.List[TensorStructType], TensorStructType] = None,
        info_batch: typing.Optional[typing.Dict[str, list]] = None,
        episodes: typing.Optional[typing.List[Episode]] = None,
        explore: typing.Optional[bool] = None,
        timestep: typing.Optional[int] = None,
        sim_time: typing.Optional[float] = None,
        agent_id: typing.Optional[str] = None,
        info: typing.Optional[dict] = None,
        episode: typing.Optional[Episode] = None,
        **kwargs
    ) -> typing.Tuple[TensorType, typing.List[TensorType], typing.Dict[str, TensorType]]:
        """Computes actions for the current policy.

        Args:
            obs_batch: Batch of observations.
            state_batches: List of RNN state input batches, if any.
            prev_action_batch: Batch of previous action values.
            prev_reward_batch: Batch of previous rewards.
            info_batch: Batch of info objects.
            episodes: List of Episode objects, one for each obs in
                obs_batch. This provides access to all of the internal
                episode state, which may be useful for model-based or
                multi-agent algorithms.
            explore: Whether to pick an exploitation or exploration action.
                Set to None (default) for using the value of
                `self.config["explore"]`.
            timestep: The current (sampling) time step.

        Keyword Args:
            kwargs: Forward compatibility placeholder.

        Returns:
            actions (TensorType): Batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs (List[TensorType]): List of RNN state output
                batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
            info (List[dict]): Dictionary of extra feature batches, if any,
                with shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError
```
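Before the member documentation below, here is a minimal sketch of a concrete subclass. The class name and behaviour are hypothetical; only `_reset` and `custom_compute_actions` are required overrides.

```python
# Hypothetical subclass sketch; class name and behaviour are illustrative only.
from corl.policies.custom_policy import CustomPolicy


class SampleActionPolicy(CustomPolicy):
    """Scripted policy that emits a freshly sampled action for every observation."""

    def _reset(self):
        # Called on construction and whenever a new episode is detected
        # (no 'platform_obs' entry in the agent's last info dict).
        self.last_sim_time = None

    def custom_compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        sim_time=None,
        agent_id=None,
        info=None,
        episode=None,
        **kwargs,
    ):
        self.last_sim_time = sim_time
        # One action per observation, drawn in the unnormalized action space;
        # compute_actions() rescales them when normalize_controls is True.
        actions = [self.validated_config.act_space.sample() for _ in obs_batch]
        return actions, [], {}
```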
get_validator: Type[corl.policies.base_policy.BasePolicyValidator] (read-only property)
Get the validator for this policy class; the kwargs sent to the policy are validated using this object and stored on the policy as the validated_config attribute.
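A subclass that needs extra configuration keys can pair an extended validator with an overridden get_validator. The sketch below is hedged: `TunedPolicy`, `TunedPolicyValidator`, and the `gain` field are invented names, not part of CoRL.

```python
# Hypothetical sketch of extending the validator; `gain` is an invented field.
from corl.policies.custom_policy import CustomPolicy, CustomPolicyValidator


class TunedPolicyValidator(CustomPolicyValidator):
    """Adds a tunable gain to the base policy configuration."""
    gain: float = 1.0


class TunedPolicy(CustomPolicy):  # pylint: disable=abstract-method
    # custom_compute_actions still has to be implemented; omitted here.

    @property
    def get_validator(self):
        return TunedPolicyValidator
```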
compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, **kwargs)
Computes actions for the current policy.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obs_batch | Union[List[Union[Any, dict, tuple]], Any, dict, tuple] | Batch of observations. | required |
| state_batches | Optional[List[Any]] | List of RNN state input batches, if any. | None |
| prev_action_batch | Union[List[Union[Any, dict, tuple]], Any, dict, tuple] | Batch of previous action values. | None |
| prev_reward_batch | Union[List[Union[Any, dict, tuple]], Any, dict, tuple] | Batch of previous rewards. | None |
| info_batch | Optional[Dict[str, list]] | Batch of info objects. | None |
| episodes | Optional[List[ray.rllib.evaluation.episode.Episode]] | List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms. | None |
| explore | Optional[bool] | Whether to pick an exploitation or exploration action. Set to None (default) to use the value of `self.config["explore"]`. | None |
| timestep | Optional[int] | The current (sampling) time step. | None |

Keyword arguments:

| Name | Type | Description |
|---|---|---|
| kwargs | | Forward compatibility placeholder. |

Returns:

| Type | Description |
|---|---|
| actions (TensorType) | Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE]. |
| state_outs (List[TensorType]) | List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE]. |
| info (List[dict]) | Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. |
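As the class source above shows, compute_actions applies EnvSpaceUtil.scale_sample_from_space to each action only when normalize_controls is true. If a subclass already produces actions in RLlib's normalized range, the step can be switched off through the policy config. A hedged sketch, showing only the relevant key (the remaining required CustomPolicyValidator fields come from your experiment configuration and are omitted here):

```python
# Hedged sketch: only the normalize_controls key shown here is taken from
# CustomPolicyValidator; the other required fields are omitted.
policy_config = {
    "normalize_controls": False,
    # "time_extractor": ...,   # ObservationExtractorValidator settings
    # "controllers": ...,      # list of tuple keys into the action space
}
```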
compute_actions_from_input_dict(self, input_dict, explore=None, timestep=None, episodes=None, **kwargs)
Computes actions from collected samples (across multiple agents).
Takes an input dict (usually a SampleBatch) as its main data input. This allows using this method when a more complex input pattern (view requirements) is needed, for example when the Model requires the last n observations, the last m actions/rewards, or a combination of any of these.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_dict | Union[ray.rllib.policy.sample_batch.SampleBatch, Dict[str, Union[Any, dict, tuple]]] | A SampleBatch or input dict containing the Tensors to compute actions. `input_dict` already abides by the Policy's as well as the Model's view requirements and can thus be passed to the Model as-is. | required |
| explore | bool | Whether to pick an exploitation or exploration action (default: None -> use `self.config["explore"]`). | None |
| timestep | Optional[int] | The current (sampling) time step. | None |
| episodes | Optional[List[ray.rllib.evaluation.episode.Episode]] | This provides access to all of the internal episodes' state, which may be useful for model-based or multi-agent algorithms. | None |

Keyword arguments:

| Name | Type | Description |
|---|---|---|
| kwargs | | Forward compatibility placeholder. |

Returns:

| Type | Description |
|---|---|
| actions | Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE]. |
| state_outs | List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE]. |
| info | Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. |
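compute_actions_from_input_dict derives sim_time from the agent's last info dict (see the class source above): it expects a 'platform_obs' entry keyed by agent ID. A hedged sketch of that structure, with key names inferred from the code and purely illustrative values:

```python
# Inferred from the code above; the agent ID and observation contents are illustrative.
last_info = {
    "platform_obs": {
        "blue0": {...},  # raw platform observations; time_extractor reads the sim time here
    },
}
# If "platform_obs" is missing (for example on the first step of a new episode),
# the policy calls self._reset() and forwards sim_time=-1 to compute_actions().
```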
custom_compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, sim_time=None, agent_id=None, info=None, episode=None, **kwargs)
Computes actions for the current policy.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obs_batch | Union[List[Union[Any, dict, tuple]], Any, dict, tuple] | Batch of observations. | required |
| state_batches | Optional[List[Any]] | List of RNN state input batches, if any. | None |
| prev_action_batch | Union[List[Union[Any, dict, tuple]], Any, dict, tuple] | Batch of previous action values. | None |
| prev_reward_batch | Union[List[Union[Any, dict, tuple]], Any, dict, tuple] | Batch of previous rewards. | None |
| info_batch | Optional[Dict[str, list]] | Batch of info objects. | None |
| episodes | Optional[List[ray.rllib.evaluation.episode.Episode]] | List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms. | None |
| explore | Optional[bool] | Whether to pick an exploitation or exploration action. Set to None (default) to use the value of `self.config["explore"]`. | None |
| timestep | Optional[int] | The current (sampling) time step. | None |
| sim_time | Optional[float] | Simulation time extracted from the agent's platform observations via the configured time_extractor, or -1 if it could not be extracted. | None |
| agent_id | Optional[str] | ID of the agent this batch belongs to. | None |
| info | Optional[dict] | Last info dict reported for the agent (may be empty). | None |
| episode | Optional[ray.rllib.evaluation.episode.Episode] | The Episode this batch belongs to. | None |

Keyword arguments:

| Name | Type | Description |
|---|---|---|
| kwargs | | Forward compatibility placeholder. |

Returns:

| Type | Description |
|---|---|
| actions (TensorType) | Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE]. |
| state_outs (List[TensorType]) | List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE]. |
| info (List[dict]) | Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. |
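For implementers, the return contract boils down to one action per observation plus (usually empty) state and info containers. A self-contained illustration with a made-up action space:

```python
# Illustrative only: the action space and batch below are made up to show the
# shape of the (actions, state_outs, info) tuple expected from implementations.
import gym

act_space = gym.spaces.Dict({"throttle": gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))})
obs_batch = [None, None]  # two observations in the batch; contents irrelevant here

actions = [act_space.sample() for _ in obs_batch]  # one action per observation
state_outs: list = []                              # no RNN state for a scripted policy
extra_info: dict = {}                              # no extra feature batches
print(actions, state_outs, extra_info)
```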
get_weights(self)
Returns model weights. For CustomPolicy this returns an empty dict, since the scripted base policy has no trainable model.
Note: The return value of this method will reside under the "weights" key in the return value of Policy.get_state(). Model weights are only one part of a Policy's state. Other state information includes: optimizer variables, exploration state, and global state vars such as the sampling timestep.
Returns:

| Type | Description |
|---|---|
| | Serializable copy or view of model weights. |
learn_on_batch(self, samples)
Perform one learning update, given `samples`. Either this method or the combination of `compute_gradients` and `apply_gradients` must be implemented by subclasses. CustomPolicy performs no learning, so this implementation is a no-op that returns an empty dict.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| samples | | The SampleBatch object to learn from. | required |

Returns:

| Type | Description |
|---|---|
| | Dictionary of extra metadata from `compute_gradients()`. |
Examples:
>>> policy, sample_batch = ...
>>> policy.learn_on_batch(sample_batch)
set_weights(self, weights)
Sets this Policy's model's weights. For CustomPolicy this is a no-op, since there are no model weights to set.
Note: Model weights are only one part of a Policy's state. Other state information includes: optimizer variables, exploration state, and global state vars such as the sampling timestep.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| weights | | Serializable copy or view of model weights. | required |
CustomPolicyValidator (BasePolicyValidator)
pydantic model
Base validator for the CustomPolicy
Source code in corl/policies/custom_policy.py
```python
class CustomPolicyValidator(BasePolicyValidator):
    """Base validator for the CustomPolicy"""
    act_space: gym.Space
    obs_space: gym.Space
    time_extractor: ObservationExtractorValidator
    # rllib assumes that actions have been normalized and calls 'unsquash_action' prior to sending it to the environment:
    # https://github.com/ray-project/ray/blob/c78bd809ce4b2ec0e48c77aa461684ee1e6f259b/rllib/evaluation/sampler.py#L1241
    normalize_controls: bool = True
    controllers: typing.List[typing.Tuple]

    class Config:
        """pydantic configuration options"""
        arbitrary_types_allowed = True

    @validator('act_space')
    def validate_act_space(cls, v):  # pylint: disable=no-self-argument, no-self-use
        """Validate that the value is an instance of a gym.Space"""
        assert isinstance(v, gym.Space)
        # TODO Issue warning if the action space is normalized
        return v

    @validator('time_extractor', always=True)
    def validate_extractor(cls, v, values):  # pylint: disable=no-self-argument, no-self-use
        """Ensures the time_extractor can actually extract data from the space and that it's not normalized"""
        try:
            time_space = v.construct_extractors().space(values['obs_space'].original_space)
            if isinstance(time_space, gym.spaces.Box):
                assert np.isinf(time_space.high[v.indices[0]]), "time_space must not be normalized"
        except Exception as e:
            raise RuntimeError(f"Failed to extract time using {v} from {values['obs_space'].original_space}") from e
        return v

    @validator('controllers', pre=True, always=True)
    def validate_controllers(cls, v, values):  # pylint: disable=no-self-argument, no-self-use
        """Validate that the controllers match the action_space"""
        tuple_list = []
        for iterable in v:
            tuple_list.append(tuple(iterable))
        assert len(tuple_list) == len(set(tuple_list)), 'controller definitions must be unique'
        sample_control = flatten_dict.flatten(values['act_space'].sample())
        for tuple_key in tuple_list:
            assert tuple_key in sample_control, f'controller {tuple_key} not found in action_space: {list(sample_control.keys())}'
        return tuple_list
```
Config
pydantic configuration options
validate_act_space(v) (classmethod)
Validate that the value is an instance of a gym.Space.
validate_controllers(v, values) (classmethod)
Validate that the controllers match the action_space.
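validate_controllers checks the controller tuples against a flattened sample of the action space (see the class source above), so controller definitions are key paths into the (possibly nested) action dict. A small illustration with an invented action space:

```python
# Illustrative action space; the real one comes from the agent configuration.
import flatten_dict
import gym

act_space = gym.spaces.Dict({
    "throttle": gym.spaces.Box(low=-1.0, high=1.0, shape=(1,)),
    "stick": gym.spaces.Dict({"pitch": gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))}),
})

# flatten_dict.flatten turns the nested sample into tuple-keyed entries, which is
# what the controller definitions must match.
print(sorted(flatten_dict.flatten(act_space.sample()).keys()))
# -> [('stick', 'pitch'), ('throttle',)]
controllers = [("throttle",), ("stick", "pitch")]
```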
validate_extractor(v, values) (classmethod)
Ensures the time_extractor can actually extract data from the space and that it is not normalized.