Scripted action
Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.
This is a US Government Work not subject to copyright protection in the US.
The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.
Scripted Action Policy
ScriptedActionPolicy (CustomPolicy)
¤
Scripted action policy.
Source code in corl/policies/scripted_action.py
class ScriptedActionPolicy(CustomPolicy):  # pylint: disable=abstract-method
    """Policy that replays a pre-configured script of actions keyed on simulation time.

    The script lives in ``self.validated_config`` as two parallel lists,
    ``control_times`` and ``control_values``. When no scripted entry is due, the
    configured ``missing_action_policy`` decides whether the previous action is
    repeated or ``default_action`` is returned.
    """

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        # Bare annotations only: they declare the attribute types but assign nothing.
        # The values themselves are set in _reset().
        self._input_index: int
        self._last_action: dict

    @property
    def get_validator(self) -> typing.Type[BasePolicyValidator]:
        """Return the validator class used to validate this policy's configuration kwargs."""
        return ScriptedActionPolicyValidator

    def _reset(self):
        """Rewind the script to its first entry and clear the remembered action.

        The remembered action is initialised from
        ``EnvSpaceUtil.get_zero_sample_from_space`` on the configured action space.
        """
        super()._reset()
        self._input_index = 0
        self._last_action = EnvSpaceUtil.get_zero_sample_from_space(self.validated_config.act_space)

    def custom_compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        sim_time=None,
        agent_id=None,
        info=None,
        episode=None,
        **kwargs
    ):
        """Return the scripted action that is due at ``sim_time``, or the fallback action.

        Only ``sim_time`` is read; the remaining parameters are accepted for interface
        compatibility and ignored by this implementation.

        Returns:
            A 3-tuple ``([action], [], {})``: a one-element action batch, an empty
            state-output list and an empty info dict.

        Raises:
            TypeError: if a script entry is still pending and ``sim_time`` cannot be
                compared with its control time (e.g. ``sim_time`` is None).
        """
        config = self.validated_config
        pending = self._input_index

        # Only the next unconsumed entry is considered, and at most one entry is
        # consumed per call. NOTE(review): if sim_time jumps past several
        # control_times, the skipped entries are still emitted one per call.
        if pending < len(config.control_times) and sim_time >= config.control_times[pending]:
            scripted_action = config.control_values[pending]
            self._input_index = pending + 1
            self._last_action = scripted_action
            return [scripted_action], [], {}

        if config.missing_action_policy == 'repeat_last_action':
            return [self._last_action], [], {}

        # 'default_action' policy: the default also becomes the remembered action.
        self._last_action = config.default_action
        return [config.default_action], [], {}
get_validator: Type[corl.policies.base_policy.BasePolicyValidator]
property
readonly
¤
Get the validator for this class. The kwargs sent to the class will be validated using this object, which also adds a self.config attribute to the class.
custom_compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, sim_time=None, agent_id=None, info=None, episode=None, **kwargs)
¤
Computes actions for the current policy.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obs_batch |
Batch of observations. |
required | |
state_batches |
List of RNN state input batches, if any. |
None |
|
prev_action_batch |
Batch of previous action values. |
None |
|
prev_reward_batch |
Batch of previous rewards. |
None |
|
info_batch |
Batch of info objects. |
None |
|
episodes |
List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms. |
None |
|
explore |
Whether to pick an exploitation or exploration action.
Set to None (default) to use the value of `self.config["explore"]`.
|
None |
|
timestep |
The current (sampling) time step. |
None |
Keyword arguments:
Name | Type | Description |
---|---|---|
kwargs |
Forward compatibility placeholder |
Returns:
Type | Description |
---|---|
actions (TensorType) |
Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE]. state_outs (List[TensorType]): List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE]. info (List[dict]): Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. |
Source code in corl/policies/scripted_action.py
def custom_compute_actions(
    self,
    obs_batch,
    state_batches=None,
    prev_action_batch=None,
    prev_reward_batch=None,
    info_batch=None,
    episodes=None,
    explore=None,
    timestep=None,
    sim_time=None,
    agent_id=None,
    info=None,
    episode=None,
    **kwargs
):
    """Return the scripted action that is due at ``sim_time``, or the fallback action.

    Only ``sim_time`` is read; the remaining parameters are accepted for interface
    compatibility and ignored by this implementation.

    Returns:
        A 3-tuple ``([action], [], {})``: a one-element action batch, an empty
        state-output list and an empty info dict.

    Raises:
        TypeError: if a script entry is still pending and ``sim_time`` cannot be
            compared with its control time (e.g. ``sim_time`` is None).
    """
    config = self.validated_config
    pending = self._input_index

    # Only the next unconsumed entry is considered, and at most one entry is
    # consumed per call. NOTE(review): if sim_time jumps past several
    # control_times, the skipped entries are still emitted one per call.
    if pending < len(config.control_times) and sim_time >= config.control_times[pending]:
        scripted_action = config.control_values[pending]
        self._input_index = pending + 1
        self._last_action = scripted_action
        return [scripted_action], [], {}

    if config.missing_action_policy == 'repeat_last_action':
        return [self._last_action], [], {}

    # 'default_action' policy: the default also becomes the remembered action.
    self._last_action = config.default_action
    return [config.default_action], [], {}
ScriptedActionPolicyValidator (CustomPolicyValidator)
pydantic-model
¤
Validator for the ScriptedActionPolicy
Source code in corl/policies/scripted_action.py
class ScriptedActionPolicyValidator(CustomPolicyValidator):
    """Validator for the ScriptedActionPolicy.

    The validators below read earlier fields through ``values``. A field that failed
    its own validation is absent from ``values``, so each validator checks for the
    keys it needs and reports a validation error instead of crashing with KeyError.
    """
    # Times at which the matching entry of control_values becomes due; must be ascending.
    control_times: typing.List[float]
    # One action dict per control time (converted from the raw per-controller values).
    control_values: typing.List[typing.Dict]
    # What to return when no scripted entry is due.
    missing_action_policy: typing.Literal['default_action', 'repeat_last_action']
    # Required for 'default_action', forbidden for 'repeat_last_action'.
    default_action: typing.Optional[typing.Dict] = None

    class Config:
        """pydantic configuration options"""
        arbitrary_types_allowed = True

    @validator('control_times')
    def sort_control_times(cls, v):  # pylint: disable=no-self-argument, no-self-use
        """Ensure control_times are already in ascending order (they are not re-sorted)."""
        # Raise explicitly: an assert would be stripped under `python -O`.
        if v != sorted(v):
            raise ValueError("control_times must be in order")
        return v

    @staticmethod
    def convert_control_value(controls, controller_key_paths, sample_control):
        """Convert a flat sequence of control values into a nested action dict.

        ``controls[i]`` is stored under ``controller_key_paths[i]`` and coerced to the
        type (or, for numpy arrays, the dtype and shape) of the corresponding entry of
        ``sample_control``.

        Raises:
            AssertionError: if the number of controls, key paths and flattened sample
                entries do not match.
            RuntimeError: if a value cannot be looked up or converted; the original
                exception is chained as the cause.
        """
        flat_sample_control = flatten_dict.flatten(sample_control)
        assert len(controls) == len(controller_key_paths), 'mismatch between number of controllers and length of control values'
        assert len(controller_key_paths) == len(flat_sample_control), 'mismatch between number of controllers and the action_space'
        flat_control_dict = {}
        for i, ctrl_value in enumerate(controls):
            # Reset per iteration so the error message never reports an unbound name
            # or the key left over from the previous iteration.
            controller_key = None
            try:
                controller_key = controller_key_paths[i]
                sample_value = flat_sample_control[controller_key]
                if isinstance(sample_value, np.ndarray):
                    # Zero the sample and add the value: keeps the sample's shape and dtype.
                    flat_control_dict[controller_key] = np.add(sample_value * 0, ctrl_value, dtype=sample_value.dtype)
                else:
                    flat_control_dict[controller_key] = type(sample_value)(ctrl_value)
            except Exception as e:
                raise RuntimeError(f'@idx: {i}, controller: {controller_key}, control_value: {ctrl_value}') from e
        control_dict = flatten_dict.unflatten(flat_control_dict)
        return control_dict

    @staticmethod
    def convert_control_values(action_space, controller_key_paths, control_times, controls_list):
        """Convert every entry of ``controls_list`` to an action dict and check it against the action space.

        Raises:
            AssertionError: if ``controls_list`` and ``control_times`` differ in length.
        """
        assert len(controls_list) == len(control_times), 'mismatch between number of control_times and control_values'
        # A single sample is drawn and reused as the type/shape template for every entry.
        sample_control = action_space.sample()
        converted_control_list: typing.List[typing.Dict] = []
        for controls in controls_list:
            control_dict = ScriptedActionPolicyValidator.convert_control_value(controls, controller_key_paths, sample_control)
            EnvSpaceUtil.deep_sanity_check_space_sample(action_space, control_dict)
            converted_control_list.append(control_dict)
        return converted_control_list

    @validator('control_values', pre=True, always=True)
    def validate_control_values(cls, controls_list, values):  # pylint: disable=no-self-argument, no-self-use
        """Validate that control_values match the controllers and convert them to action dicts."""
        # act_space is read below as well, so it is guarded like the other two keys.
        if 'controllers' not in values or 'control_times' not in values or 'act_space' not in values:
            raise ValueError(f'Could not run "validate_control_values" because previous items failed: {values}')
        return ScriptedActionPolicyValidator.convert_control_values(
            values['act_space'], values['controllers'], values['control_times'], controls_list
        )

    @validator('missing_action_policy')
    def validate_missing_action_policy(cls, missing_action_policy, values):  # pylint: disable=no-self-argument, no-self-use
        """Validate the missing action policy.

        'repeat_last_action' requires a scripted entry at t=0.
        """
        if missing_action_policy == 'repeat_last_action':
            if 'control_times' not in values:
                raise ValueError('Could not run "validate_missing_action_policy" because control_times failed validation')
            control_times = values['control_times']
            # An empty script has no t=0 entry either (previously an IndexError).
            if not control_times or control_times[0] != 0:
                raise ValueError("missing control_time for t=0")
        return missing_action_policy

    # always=True: without it pydantic skips this validator when default_action is
    # omitted, so the "required" check below could never fire.
    @validator('default_action', pre=True, always=True)
    def validate_default_action(cls, default_action, values):  # pylint: disable=no-self-argument, no-self-use
        """Validate that default_action is consistent with the missing_action_policy."""
        if 'missing_action_policy' not in values:
            raise ValueError('Could not run "validate_default_action" because missing_action_policy failed validation')
        if values["missing_action_policy"] == 'default_action':
            if default_action is None:
                raise ValueError('default_action is requried when using the default_action missing_action_policy')
            if 'act_space' not in values or 'controllers' not in values:
                raise ValueError(f'Could not run "validate_default_action" because previous items failed: {values}')
            action_space = values['act_space']
            action = ScriptedActionPolicyValidator.convert_control_value(default_action, values['controllers'], action_space.sample())
            return action
        if default_action is not None:
            raise ValueError('default_action is invalid except when using the default_action missing_action_policy')
        return default_action
Config
¤
pydantic configuration options
Source code in corl/policies/scripted_action.py
class Config:
    """Pydantic model settings for the enclosing validator."""

    # Allow fields whose types pydantic has no built-in validator for.
    arbitrary_types_allowed = True
convert_control_value(controls, controller_key_paths, sample_control)
staticmethod
¤
converts the controls into a dict
Source code in corl/policies/scripted_action.py
@staticmethod
def convert_control_value(controls, controller_key_paths, sample_control):
    """Convert a flat sequence of control values into a nested action dict.

    ``controls[i]`` is stored under ``controller_key_paths[i]`` and coerced to the
    type (or, for numpy arrays, the dtype and shape) of the corresponding entry of
    ``sample_control``.

    Raises:
        AssertionError: if the number of controls, key paths and flattened sample
            entries do not match.
        RuntimeError: if a value cannot be looked up or converted; the original
            exception is chained as the cause.
    """
    flat_sample_control = flatten_dict.flatten(sample_control)
    assert len(controls) == len(controller_key_paths), 'mismatch between number of controllers and length of control values'
    assert len(controller_key_paths) == len(flat_sample_control), 'mismatch between number of controllers and the action_space'
    flat_control_dict = {}
    for i, ctrl_value in enumerate(controls):
        # Reset per iteration so the error message never reports an unbound name
        # or the key left over from the previous iteration.
        controller_key = None
        try:
            controller_key = controller_key_paths[i]
            sample_value = flat_sample_control[controller_key]
            if isinstance(sample_value, np.ndarray):
                # Zero the sample and add the value: keeps the sample's shape and dtype.
                flat_control_dict[controller_key] = np.add(sample_value * 0, ctrl_value, dtype=sample_value.dtype)
            else:
                flat_control_dict[controller_key] = type(sample_value)(ctrl_value)
        except Exception as e:
            raise RuntimeError(f'@idx: {i}, controller: {controller_key}, control_value: {ctrl_value}') from e
    control_dict = flatten_dict.unflatten(flat_control_dict)
    return control_dict
convert_control_values(action_space, controller_key_paths, control_times, controls_list)
staticmethod
¤
Converts the input control values into dictionaries and validates them against the action space.
Source code in corl/policies/scripted_action.py
@staticmethod
def convert_control_values(action_space, controller_key_paths, control_times, controls_list):
    """Turn each raw control entry into an action dict and sanity-check it against the action space.

    Raises:
        AssertionError: if ``controls_list`` and ``control_times`` differ in length.
    """
    assert len(controls_list) == len(control_times), 'mismatch between number of control_times and control_values'

    # One sample is drawn up front and reused as the template for every entry.
    template = action_space.sample()

    def _convert_and_check(raw_controls):
        """Convert one entry, then verify it with EnvSpaceUtil before accepting it."""
        as_dict = ScriptedActionPolicyValidator.convert_control_value(raw_controls, controller_key_paths, template)
        EnvSpaceUtil.deep_sanity_check_space_sample(action_space, as_dict)
        return as_dict

    result: typing.List[typing.Dict] = []
    for raw_controls in controls_list:
        result.append(_convert_and_check(raw_controls))
    return result
sort_control_times(v)
classmethod
¤
Ensures that control_times are in order
Source code in corl/policies/scripted_action.py
@validator('control_times')
def sort_control_times(cls, v):  # pylint: disable=no-self-argument, no-self-use
    """Ensure control_times are already in ascending order (they are not re-sorted).

    Raises:
        ValueError: if ``v`` differs from ``sorted(v)``.
    """
    # Raise explicitly: an assert would be stripped under `python -O`, silently
    # disabling this check.
    if v != sorted(v):
        raise ValueError("control_times must be in order")
    return v
validate_control_values(controls_list, values)
classmethod
¤
validate that control_values match the controllers
Source code in corl/policies/scripted_action.py
@validator('control_values', pre=True, always=True)
def validate_control_values(cls, controls_list, values):  # pylint: disable=no-self-argument, no-self-use
    """Validate that control_values match the controllers and convert them to action dicts.

    Raises:
        ValueError: if ``controllers``, ``control_times`` or ``act_space`` is missing
            from ``values``.
    """
    # act_space is read below as well; guarding it here avoids a raw KeyError.
    required_keys = ('controllers', 'control_times', 'act_space')
    if any(key not in values for key in required_keys):
        raise ValueError(f'Could not run "validate_control_values" because previous items failed: {values}')
    return ScriptedActionPolicyValidator.convert_control_values(
        values['act_space'], values['controllers'], values['control_times'], controls_list
    )
validate_default_action(default_action, values)
classmethod
¤
validates that the default_action is consistent with the missing_action_policy
Source code in corl/policies/scripted_action.py
# always=True: without it pydantic skips this validator when default_action is
# omitted, so the "required" check below could never fire.
@validator('default_action', pre=True, always=True)
def validate_default_action(cls, default_action, values):  # pylint: disable=no-self-argument, no-self-use
    """Validate that default_action is consistent with the missing_action_policy.

    With the 'default_action' policy the raw value is required and is converted to an
    action dict; with any other policy it must be None.

    Raises:
        ValueError: if a field this check depends on is missing from ``values``, if
            default_action is missing when required, or present when not allowed.
    """
    if 'missing_action_policy' not in values:
        raise ValueError('Could not run "validate_default_action" because missing_action_policy failed validation')
    if values["missing_action_policy"] == 'default_action':
        if default_action is None:
            raise ValueError('default_action is requried when using the default_action missing_action_policy')
        if 'act_space' not in values or 'controllers' not in values:
            raise ValueError(f'Could not run "validate_default_action" because previous items failed: {values}')
        action_space = values['act_space']
        # The sample only serves as a type/shape template for the conversion.
        action = ScriptedActionPolicyValidator.convert_control_value(default_action, values['controllers'], action_space.sample())
        return action
    if default_action is not None:
        raise ValueError('default_action is invalid except when using the default_action missing_action_policy')
    return default_action
validate_missing_action_policy(missing_action_policy, values)
classmethod
¤
validates the missing action policy
Source code in corl/policies/scripted_action.py
@validator('missing_action_policy')
def validate_missing_action_policy(cls, missing_action_policy, values):  # pylint: disable=no-self-argument, no-self-use
    """Validate the missing action policy.

    'repeat_last_action' requires a scripted entry at t=0.

    Raises:
        ValueError: if control_times is missing from ``values``, is empty, or does
            not start at 0 while 'repeat_last_action' is selected.
    """
    if missing_action_policy == 'repeat_last_action':
        if 'control_times' not in values:
            raise ValueError('Could not run "validate_missing_action_policy" because control_times failed validation')
        control_times = values['control_times']
        # An empty script has no t=0 entry either (previously an IndexError).
        if not control_times or control_times[0] != 0:
            raise ValueError("missing control_time for t=0")
    return missing_action_policy