Skip to content

Scripted action


Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.

This is a US Government Work not subject to copyright protection in the US.

The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.


Scripted Action Policy

ScriptedActionPolicy (CustomPolicy) ¤

Scripted action policy.

Source code in corl/policies/scripted_action.py
class ScriptedActionPolicy(CustomPolicy):  # pylint: disable=abstract-method
    """Policy that replays a scripted, time-indexed sequence of actions.

    The script is read from the validated config: ``control_times`` holds the
    trigger times and ``control_values`` the action issued at each of them.
    """

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)

        # Annotation-only declarations; the values are assigned in _reset().
        self._input_index: int
        self._last_action: dict

    @property
    def get_validator(self) -> typing.Type[BasePolicyValidator]:
        """
        Validator class for this policy.

        The kwargs sent to the policy are validated using this class.
        """
        return ScriptedActionPolicyValidator

    def _reset(self):
        """Rewind the script to its first entry and zero the remembered action."""
        super()._reset()
        # Index of the next scripted entry that has not been issued yet.
        self._input_index = 0
        # Action handed out by 'repeat_last_action' until a scripted entry fires.
        self._last_action = EnvSpaceUtil.get_zero_sample_from_space(self.validated_config.act_space)

    def custom_compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        sim_time=None,
        agent_id=None,
        info=None,
        episode=None,
        **kwargs
    ):
        """Return the scripted action for ``sim_time``, or the fallback action.

        Only the next pending script entry is examined on each call, so at most
        one entry is consumed per call. When it is not due yet (or the script is
        exhausted) the result depends on ``missing_action_policy``:
        'repeat_last_action' re-issues the last action, anything else issues
        ``default_action`` and remembers it as the last action.

        Returns:
            A 3-tuple ``([action], [], {})``.
        """
        cursor = self._input_index
        config = self.validated_config

        # Is there a pending entry, and has its trigger time been reached?
        if cursor < len(config.control_times) and sim_time >= config.control_times[cursor]:
            scripted_action = config.control_values[cursor]
            self._input_index = cursor + 1
            self._last_action = scripted_action
            return [scripted_action], [], {}

        if config.missing_action_policy != 'repeat_last_action':
            self._last_action = config.default_action

        return [self._last_action], [], {}

get_validator: Type[corl.policies.base_policy.BasePolicyValidator] property readonly ¤

Get the validator for this policy class. The kwargs sent to the policy are validated using this object, and the validated result is stored on the policy instance (the code reads it as self.validated_config).

custom_compute_actions(self, obs_batch, state_batches=None, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, explore=None, timestep=None, sim_time=None, agent_id=None, info=None, episode=None, **kwargs) ¤

Computes actions for the current policy.

Parameters:

Name Type Description Default
obs_batch

Batch of observations.

required
state_batches

List of RNN state input batches, if any.

None
prev_action_batch

Batch of previous action values.

None
prev_reward_batch

Batch of previous rewards.

None
info_batch

Batch of info objects.

None
episodes

List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms.

None
explore

Whether to pick an exploitation or exploration action. Set to None (default) for using the value of self.config["explore"].

None
timestep

The current (sampling) time step.

None

Keyword arguments:

Name Type Description
kwargs

Forward compatibility placeholder

Returns:

Type Description
actions (TensorType)

Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].
state_outs (List[TensorType]): List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
info (List[dict]): Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.

Source code in corl/policies/scripted_action.py
def custom_compute_actions(
    self,
    obs_batch,
    state_batches=None,
    prev_action_batch=None,
    prev_reward_batch=None,
    info_batch=None,
    episodes=None,
    explore=None,
    timestep=None,
    sim_time=None,
    agent_id=None,
    info=None,
    episode=None,
    **kwargs
):
    """Return the scripted action for ``sim_time``, or the fallback action.

    Only the next pending script entry is examined on each call, so at most
    one entry is consumed per call. When it is not due yet (or the script is
    exhausted) the result depends on ``missing_action_policy``:
    'repeat_last_action' re-issues the last action, anything else issues
    ``default_action`` and remembers it as the last action.

    Returns:
        A 3-tuple ``([action], [], {})``.
    """
    cursor = self._input_index
    config = self.validated_config

    # Is there a pending entry, and has its trigger time been reached?
    if cursor < len(config.control_times) and sim_time >= config.control_times[cursor]:
        scripted_action = config.control_values[cursor]
        self._input_index = cursor + 1
        self._last_action = scripted_action
        return [scripted_action], [], {}

    if config.missing_action_policy != 'repeat_last_action':
        self._last_action = config.default_action

    return [self._last_action], [], {}

ScriptedActionPolicyValidator (CustomPolicyValidator) pydantic-model ¤

Validator for the ScriptedActionPolicy

Source code in corl/policies/scripted_action.py
class ScriptedActionPolicyValidator(CustomPolicyValidator):
    """Validator for the ScriptedActionPolicy"""

    # Trigger times of the script; must be in non-decreasing order.
    control_times: typing.List[float]
    # One action per entry of control_times (converted to action dicts on validation).
    control_values: typing.List[typing.Dict]

    # What the policy emits when no scripted entry is due.
    missing_action_policy: typing.Literal['default_action', 'repeat_last_action']
    # Required with the 'default_action' policy, forbidden otherwise.
    default_action: typing.Optional[typing.Dict] = None

    class Config:
        """pydantic configuration options"""
        arbitrary_types_allowed = True

    @validator('control_times')
    def sort_control_times(cls, v):  # pylint: disable=no-self-argument, no-self-use
        """Ensures that control_times are in order"""
        assert v == sorted(v), "control_times must be in order"
        return v

    @staticmethod
    def convert_control_value(controls, controller_key_paths, sample_control):
        """Convert one sequence of control values into a nested action dict.

        Parameters:
            controls: control values, one per controller, in controller order.
            controller_key_paths: flattened key of each controller inside the action dict.
            sample_control: a sample of the action space, used as a type/dtype template.

        Returns:
            The nested (unflattened) action dict.

        Raises:
            AssertionError: when the number of controls, controllers and action
                space leaves do not match.
            RuntimeError: when a single control value cannot be converted.
        """

        flat_sample_control = flatten_dict.flatten(sample_control)

        assert len(controls) == len(controller_key_paths), 'mismatch between number of controllers and length of control values'
        assert len(controller_key_paths) == len(flat_sample_control), 'mismatch between number of controllers and the action_space'

        flat_control_dict = {}
        for i, ctrl_value in enumerate(controls):
            try:
                controller_key = controller_key_paths[i]
                sample_value = flat_sample_control[controller_key]
                if isinstance(sample_value, np.ndarray):
                    # Zero the sample and add the value, keeping the sample's dtype.
                    flat_control_dict[controller_key] = np.add(sample_value * 0, ctrl_value, dtype=sample_value.dtype)
                else:
                    # Cast the value to the python type of the sample.
                    flat_control_dict[controller_key] = type(sample_value)(ctrl_value)
            except Exception as e:
                raise RuntimeError(f'@idx: {i}, controller: {controller_key}, control_value: {ctrl_value}') from e
        control_dict = flatten_dict.unflatten(flat_control_dict)
        return control_dict

    @staticmethod
    def convert_control_values(action_space, controller_key_paths, control_times, controls_list):
        """Convert every scripted control entry into an action dict and check it against the action space.

        Raises:
            AssertionError: when control_times and controls_list differ in length.
        """
        assert len(controls_list) == len(control_times), 'mismatch between number of control_times and control_values'

        sample_control = action_space.sample()

        converted_control_list: typing.List[typing.Dict] = []
        for controls in controls_list:

            control_dict = ScriptedActionPolicyValidator.convert_control_value(controls, controller_key_paths, sample_control)

            EnvSpaceUtil.deep_sanity_check_space_sample(action_space, control_dict)

            converted_control_list.append(control_dict)

        return converted_control_list

    @validator('control_values', pre=True, always=True)
    def validate_control_values(cls, controls_list, values):  # pylint: disable=no-self-argument, no-self-use
        """validate that control_values match the controllers"""
        # Fields that failed their own validation are absent from `values`;
        # act_space is read below, so it has to be checked as well.
        if any(key not in values for key in ('act_space', 'controllers', 'control_times')):
            raise ValueError(f'Could not run "validate_control_values" because previous items failed: {values}')

        return ScriptedActionPolicyValidator.convert_control_values(
            values['act_space'], values['controllers'], values['control_times'], controls_list
        )

    @validator('missing_action_policy')
    def validate_missing_action_policy(cls, missing_action_policy, values):  # pylint: disable=no-self-argument, no-self-use
        """validates the missing action policy"""
        if missing_action_policy == 'repeat_last_action':
            if 'control_times' not in values:
                raise ValueError(f'Could not run "validate_missing_action_policy" because previous items failed: {values}')
            control_times = values['control_times']
            # An empty script has no t=0 entry either; report it as a validation
            # error instead of letting an IndexError escape.
            assert control_times and control_times[0] == 0, "missing control_time for t=0"
        return missing_action_policy

    # always=True: without it pydantic skips this validator when default_action
    # is omitted, so the "required" check below would never run.
    @validator('default_action', pre=True, always=True)
    def validate_default_action(cls, default_action, values):  # pylint: disable=no-self-argument, no-self-use
        """validates that the default_action is consistent with the missing_action_policy"""
        if 'missing_action_policy' not in values:
            raise ValueError(f'Could not run "validate_default_action" because previous items failed: {values}')

        if values["missing_action_policy"] == 'default_action':
            assert default_action is not None, 'default_action is requried when using the default_action missing_action_policy'

            if 'act_space' not in values or 'controllers' not in values:
                raise ValueError(f'Could not run "validate_default_action" because previous items failed: {values}')

            action_space = values['act_space']
            action = ScriptedActionPolicyValidator.convert_control_value(default_action, values['controllers'], action_space.sample())

            return action

        assert default_action is None, 'default_action is invalid except when using the default_action missing_action_policy'
        return default_action

Config ¤

pydantic configuration options

Source code in corl/policies/scripted_action.py
class Config:
    """Pydantic configuration options for the enclosing model."""

    # Accept field types that pydantic has no dedicated validator for.
    arbitrary_types_allowed = True

convert_control_value(controls, controller_key_paths, sample_control) staticmethod ¤

converts the controls into a dict

Source code in corl/policies/scripted_action.py
@staticmethod
def convert_control_value(controls, controller_key_paths, sample_control):
    """Convert one sequence of control values into a nested action dict.

    ``sample_control`` is a sample of the action space; its flattened leaves
    serve as type/dtype templates for the converted values. Raises
    AssertionError on a length mismatch and RuntimeError when a single value
    cannot be converted.
    """

    templates = flatten_dict.flatten(sample_control)

    assert len(controls) == len(controller_key_paths), 'mismatch between number of controllers and length of control values'
    assert len(controller_key_paths) == len(templates), 'mismatch between number of controllers and the action_space'

    converted = {}
    for i, ctrl_value in enumerate(controls):
        try:
            controller_key = controller_key_paths[i]
            template = templates[controller_key]
            if not isinstance(template, np.ndarray):
                # Cast the value to the python type of the sample.
                converted[controller_key] = type(template)(ctrl_value)
            else:
                # Zero the sample and add the value, keeping the sample's dtype.
                converted[controller_key] = np.add(template * 0, ctrl_value, dtype=template.dtype)
        except Exception as e:
            raise RuntimeError(f'@idx: {i}, controller: {controller_key}, control_value: {ctrl_value}') from e

    return flatten_dict.unflatten(converted)

convert_control_values(action_space, controller_key_paths, control_times, controls_list) staticmethod ¤

Converts the input control values into dictionaries and validates them against the action space.

Source code in corl/policies/scripted_action.py
@staticmethod
def convert_control_values(action_space, controller_key_paths, control_times, controls_list):
    """Convert every scripted control entry into an action dict and check it against the action space.

    Raises AssertionError when control_times and controls_list differ in length.
    """
    assert len(controls_list) == len(control_times), 'mismatch between number of control_times and control_values'

    # One sample is drawn up front and reused as the template for every entry.
    reference_sample = action_space.sample()

    def _to_checked_action(raw_controls):
        # Convert a single entry, then run the space sanity check on the result.
        action = ScriptedActionPolicyValidator.convert_control_value(raw_controls, controller_key_paths, reference_sample)
        EnvSpaceUtil.deep_sanity_check_space_sample(action_space, action)
        return action

    converted: typing.List[typing.Dict] = [_to_checked_action(raw_controls) for raw_controls in controls_list]
    return converted

sort_control_times(v) classmethod ¤

Ensures that control_times are in order

Source code in corl/policies/scripted_action.py
@validator('control_times')
def sort_control_times(cls, v):  # pylint: disable=no-self-argument, no-self-use
    """Reject control_times that are not already in sorted order."""
    expected_order = sorted(v)
    assert expected_order == v, "control_times must be in order"
    return v

validate_control_values(controls_list, values) classmethod ¤

validate that control_values match the controllers

Source code in corl/policies/scripted_action.py
@validator('control_values', pre=True, always=True)
def validate_control_values(cls, controls_list, values):  # pylint: disable=no-self-argument, no-self-use
    """Validate that control_values match the controllers and convert them to action dicts.

    Raises:
        ValueError: when a field this validator depends on failed its own validation.
    """
    # Fields that failed their own validation are absent from `values`;
    # act_space is read below, so it has to be checked as well.
    if any(key not in values for key in ('act_space', 'controllers', 'control_times')):
        raise ValueError(f'Could not run "validate_control_values" because previous items failed: {values}')

    return ScriptedActionPolicyValidator.convert_control_values(
        values['act_space'], values['controllers'], values['control_times'], controls_list
    )

validate_default_action(default_action, values) classmethod ¤

validates that the default_action is consistent with the missing_action_policy

Source code in corl/policies/scripted_action.py
# always=True: without it pydantic skips this validator when default_action
# is omitted, so the "required" check below would never run.
@validator('default_action', pre=True, always=True)
def validate_default_action(cls, default_action, values):  # pylint: disable=no-self-argument, no-self-use
    """Validate that default_action is consistent with the missing_action_policy.

    With the 'default_action' policy the value is required and is converted
    into an action dict; with any other policy it must be None.

    Raises:
        ValueError: when a field this validator depends on failed its own validation.
        AssertionError: when default_action is missing or given for the wrong policy.
    """
    if 'missing_action_policy' not in values:
        raise ValueError(f'Could not run "validate_default_action" because previous items failed: {values}')

    if values["missing_action_policy"] == 'default_action':
        assert default_action is not None, 'default_action is requried when using the default_action missing_action_policy'

        if 'act_space' not in values or 'controllers' not in values:
            raise ValueError(f'Could not run "validate_default_action" because previous items failed: {values}')

        action_space = values['act_space']
        action = ScriptedActionPolicyValidator.convert_control_value(default_action, values['controllers'], action_space.sample())

        return action

    assert default_action is None, 'default_action is invalid except when using the default_action missing_action_policy'
    return default_action

validate_missing_action_policy(missing_action_policy, values) classmethod ¤

validates the missing action policy

Source code in corl/policies/scripted_action.py
@validator('missing_action_policy')
def validate_missing_action_policy(cls, missing_action_policy, values):  # pylint: disable=no-self-argument, no-self-use
    """Validate the missing action policy.

    'repeat_last_action' requires the script to start with an entry at t=0.

    Raises:
        ValueError: when control_times failed its own validation.
        AssertionError: when the script is empty or does not start at t=0.
    """
    if missing_action_policy == 'repeat_last_action':
        if 'control_times' not in values:
            raise ValueError(f'Could not run "validate_missing_action_policy" because previous items failed: {values}')
        control_times = values['control_times']
        # An empty script has no t=0 entry either; report it as a validation
        # error instead of letting an IndexError escape.
        assert control_times and control_times[0] == 0, "missing control_time for t=0"
    return missing_action_policy