Skip to content

Core


Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.

This is a US Government Work not subject to copyright protection in the US.

The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.


EpisodeParameterProvider (ABC) ¤

Interface definition for episode parameter providers.

Source code in corl/episode_parameter_providers/core.py
class EpisodeParameterProvider(abc.ABC):
    """Abstract base class defining the episode parameter provider interface.

    Concrete providers implement ``_do_get_params``; callers use ``get_params``,
    which checks the provider's output against the configured parameters.
    """

    def __init__(self, **kwargs) -> None:
        # Resolve the validator class through the property, then build the
        # configuration object from the supplied keyword arguments.
        validator_cls = self.get_validator
        self.config: EpisodeParameterProviderValidator = validator_cls(**kwargs)

    @property
    def get_validator(self) -> typing.Type[EpisodeParameterProviderValidator]:
        """Return the validator class used to build ``self.config``."""
        return EpisodeParameterProviderValidator

    def get_params(self, rng: Randomness) -> typing.Tuple[ParameterModel, typing.Union[int, None]]:
        """Return checked parameters for the next episode.

        Delegates to ``_do_get_params`` and then verifies the result against
        ``self.config.parameters``.

        Subclasses: Override _do_get_params.

        Parameters
        ----------
        rng : Union[Generator, RandomState]
            Random number generator from which to draw random values.

        Returns
        -------
        ParameterModel
            The parameters for this episode
        episode_id
            The episode index number for this set of parameters

        Raises
        ------
        KeyError
            If the produced keys are not exactly the configured parameter keys.
        TypeError
            If any produced value is not a ``Parameter`` instance.
        """
        params, episode_id = self._do_get_params(rng)

        produced = params.keys()
        expected = self.config.parameters.keys()

        # Keys the subclass returned that were never configured.
        unexpected = produced - expected
        if unexpected:
            raise KeyError(f'Extra keys provided: {unexpected}')

        # Configured keys the subclass failed to return.
        absent = expected - produced
        if absent:
            raise KeyError(f'Missing keys: {absent}')

        # Map the dotted name of every non-Parameter value to its type name.
        wrong_types: typing.Dict[str, str] = {
            '.'.join(name): type(param).__name__
            for name, param in params.items()
            if not isinstance(param, Parameter)
        }
        if wrong_types:
            raise TypeError(f'Unsupported types: {wrong_types}')

        return params, episode_id

    @abc.abstractmethod
    def _do_get_params(self, rng: Randomness) -> typing.Tuple[ParameterModel, typing.Union[int, None]]:
        """Produce the raw parameters for the next episode.

        Abstract hook that every subclass must override.

        DO NOT CALL DIRECTLY.  USE get_params.

        Parameters
        ----------
        rng : Union[Generator, RandomState]
            Random number generator from which to draw random values.

        Returns
        -------
        ParameterModel
            The parameters for this episode
        episode_id
            The episode index number for this set of parameters
        """
        raise NotImplementedError

    def compute_metrics(self) -> typing.Dict[str, typing.Any]:  # pylint: disable=no-self-use
        """Report metrics describing how this provider is operating.

        The base implementation has nothing to report and returns an empty
        dictionary.

        Often used in `on_episode_end` training callbacks.
        """
        metrics: typing.Dict[str, typing.Any] = {}
        return metrics

    def update(self, results: dict, rng: Randomness) -> None:  # pylint: disable=no-self-use, unused-argument
        """Adjust how this provider operates based on training results.

        Often used in `on_train_result` training callbacks.

        Parameters
        ----------
        results : dict
            As described by ray.rllib.agents.callbacks.DefaultCallbacks.on_train_result.
            See https://docs.ray.io/en/master/_modules/ray/rllib/agents/callbacks.html#DefaultCallbacks.on_train_result
        rng : Union[Generator, RandomState]
            Random number generator from which to draw random values.
        """
        # Intentionally a no-op in the base class; subclasses may override.

    def save_checkpoint(self, checkpoint_path: PathLike) -> None:  # pylint: disable=no-self-use, unused-argument
        """Persist the internal state of the parameter provider.

        Parameters
        ----------
        checkpoint_path : PathLike
            Filesystem path at which to save the checkpoint
        """
        # Intentionally a no-op in the base class; subclasses may override.

    def load_checkpoint(self, checkpoint_path: PathLike) -> None:  # pylint: disable=no-self-use, unused-argument
        """Restore the internal state from a checkpoint.

        Parameters
        ----------
        checkpoint_path : PathLike
            Filesystem path from which to restore the checkpoint
        """
        # Intentionally a no-op in the base class; subclasses may override.

get_validator: Type[corl.episode_parameter_providers.core.EpisodeParameterProviderValidator] property readonly ¤

Get the validator for this class.

compute_metrics(self) ¤

Get metrics on the operation of this provider.

Often used in on_episode_end training callbacks.

Source code in corl/episode_parameter_providers/core.py
def compute_metrics(self) -> typing.Dict[str, typing.Any]:  # pylint: disable=no-self-use
    """Report metrics describing how this provider is operating.

    The base implementation has nothing to report and returns an empty
    dictionary.

    Often used in `on_episode_end` training callbacks.
    """
    metrics: typing.Dict[str, typing.Any] = {}
    return metrics

get_params(self, rng) ¤

Get the next instance of episode parameters from this provider

Subclasses: Override _do_get_params.

Parameters¤

rng : Union[Generator, RandomState] Random number generator from which to draw random values.

Returns¤

ParameterModel — The parameters for this episode.

episode_id — The episode index number for this set of parameters.

Source code in corl/episode_parameter_providers/core.py
def get_params(self, rng: Randomness) -> typing.Tuple[ParameterModel, typing.Union[int, None]]:
    """Return checked parameters for the next episode.

    Delegates to ``_do_get_params`` and then verifies the result against
    ``self.config.parameters``.

    Subclasses: Override _do_get_params.

    Parameters
    ----------
    rng : Union[Generator, RandomState]
        Random number generator from which to draw random values.

    Returns
    -------
    ParameterModel
        The parameters for this episode
    episode_id
        The episode index number for this set of parameters

    Raises
    ------
    KeyError
        If the produced keys are not exactly the configured parameter keys.
    TypeError
        If any produced value is not a ``Parameter`` instance.
    """
    params, episode_id = self._do_get_params(rng)

    produced = params.keys()
    expected = self.config.parameters.keys()

    # Keys the subclass returned that were never configured.
    unexpected = produced - expected
    if unexpected:
        raise KeyError(f'Extra keys provided: {unexpected}')

    # Configured keys the subclass failed to return.
    absent = expected - produced
    if absent:
        raise KeyError(f'Missing keys: {absent}')

    # Map the dotted name of every non-Parameter value to its type name.
    wrong_types: typing.Dict[str, str] = {
        '.'.join(name): type(param).__name__
        for name, param in params.items()
        if not isinstance(param, Parameter)
    }
    if wrong_types:
        raise TypeError(f'Unsupported types: {wrong_types}')

    return params, episode_id

load_checkpoint(self, checkpoint_path) ¤

Load the internal state from a checkpoint.

Parameters¤

checkpoint_path : PathLike Filesystem path from which to restore the checkpoint

Source code in corl/episode_parameter_providers/core.py
def load_checkpoint(self, checkpoint_path: PathLike) -> None:  # pylint: disable=no-self-use, unused-argument
    """Restore the internal state from a checkpoint.

    Parameters
    ----------
    checkpoint_path : PathLike
        Filesystem path from which to restore the checkpoint
    """
    # Intentionally a no-op in the base class; subclasses may override.

save_checkpoint(self, checkpoint_path) ¤

Save the internal state of the parameter provider.

Parameters¤

checkpoint_path : PathLike Filesystem path at which to save the checkpoint

Source code in corl/episode_parameter_providers/core.py
def save_checkpoint(self, checkpoint_path: PathLike) -> None:  # pylint: disable=no-self-use, unused-argument
    """Persist the internal state of the parameter provider.

    Parameters
    ----------
    checkpoint_path : PathLike
        Filesystem path at which to save the checkpoint
    """
    # Intentionally a no-op in the base class; subclasses may override.

update(self, results, rng) ¤

Update the operation of this provider.

Often used in on_train_result training callbacks.

Parameters¤

results : dict — As described by ray.rllib.agents.callbacks.DefaultCallbacks.on_train_result. See https://docs.ray.io/en/master/_modules/ray/rllib/agents/callbacks.html#DefaultCallbacks.on_train_result

rng : Union[Generator, RandomState] — Random number generator from which to draw random values.

Source code in corl/episode_parameter_providers/core.py
def update(self, results: dict, rng: Randomness) -> None:  # pylint: disable=no-self-use, unused-argument
    """Adjust how this provider operates based on training results.

    Often used in `on_train_result` training callbacks.

    Parameters
    ----------
    results : dict
        As described by ray.rllib.agents.callbacks.DefaultCallbacks.on_train_result.
        See https://docs.ray.io/en/master/_modules/ray/rllib/agents/callbacks.html#DefaultCallbacks.on_train_result
    rng : Union[Generator, RandomState]
        Random number generator from which to draw random values.
    """
    # Intentionally a no-op in the base class; subclasses may override.

EpisodeParameterProviderValidator (BaseModel) pydantic-model ¤

Validation model for the inputs of EpisodeParameterProvider

Source code in corl/episode_parameter_providers/core.py
class EpisodeParameterProviderValidator(BaseModel):
    """Pydantic model validating the keyword arguments of EpisodeParameterProvider."""

    # Parameters the provider is configured with; empty when none are supplied.
    parameters: ParameterModel = {}

    class Config:
        """Model settings: arbitrary field types are accepted and mutation is disabled."""
        # Disables attribute assignment on constructed instances.
        allow_mutation = False
        # Needed so that fields may hold Parameter objects.
        arbitrary_types_allowed = True

Config ¤

Allow arbitrary types for Parameter

Source code in corl/episode_parameter_providers/core.py
class Config:
    """Model settings: arbitrary field types are accepted and mutation is disabled."""
    # Disables attribute assignment on constructed instances.
    allow_mutation = False
    # Needed so that fields may hold Parameter objects.
    arbitrary_types_allowed = True