Base experiment
Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.
This is a US Government Work not subject to copyright protection in the US.
The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.
BaseAutoDetect
¤
Base class interface for setting the RLlib config when in auto mode.
Source code in corl/experiments/base_experiment.py
class BaseAutoDetect:
    """Base interface for choosing the RLlib system configuration in auto mode.

    This base implementation always selects the ``"local"`` system; custom
    detection logic is presumably provided by subclasses overriding
    ``autodetect_system``.
    """

    def autodetect_system(self) -> str:
        """Return the name of the system configuration to use.

        Returns
        -------
        str
            The base system to use; this implementation always returns ``"local"``.
        """
        default_system = "local"
        return default_system
autodetect_system(self)
¤
Gets the default system based on a user-defined function.
Returns:
str — the base system to use.
Source code in corl/experiments/base_experiment.py
def autodetect_system(self) -> str:
    """Return the name of the system configuration to use.

    Returns
    -------
    str
        The base system to use; this implementation always returns ``"local"``.
    """
    default_system = "local"
    return default_system
BaseExperiment (ABC)
¤
Experiment provides an abstract class to run specific types of experiments. This allows users to do specific setup steps or to run some sort of custom training loop.
Source code in corl/experiments/base_experiment.py
class BaseExperiment(abc.ABC):
    """
    Experiment provides an abstract class to run specific types of experiments.
    This allows users to do specific setup steps or to run some sort of custom
    training loop.
    """

    def __init__(self, **kwargs) -> None:
        """Validate ``kwargs`` with this experiment's validator.

        The validated model is stored as ``self.config``.
        """
        # get_validator is a property that returns the validator *class*,
        # which is then instantiated with the raw keyword arguments.
        self.config: BaseExperimentValidator = self.get_validator(**kwargs)

    @property
    def get_validator(self) -> typing.Type[BaseExperimentValidator]:
        """
        Get the validator for this experiment class.

        The kwargs sent to the experiment class are validated using this
        object, and the result is stored as ``self.config`` on the experiment.
        """
        return BaseExperimentValidator

    @property
    def get_policy_validator(self) -> typing.Type[BasePolicyValidator]:
        """
        Get the policy validator for this experiment class.

        Each policy configuration file loaded by ``create_agents`` is validated
        using this object, and the result is stored as that agent's
        ``policy_config``.
        """
        return BasePolicyValidator

    @abc.abstractmethod
    def run_experiment(self, args: argparse.Namespace):
        """
        Runs the experiment associated with this experiment class

        Arguments:
            args {argparse.Namespace} -- The args provided by the argparse
                in corl.train_rl
        """
        ...

    def create_agents(
        self, platform_configs: typing.Sequence[typing.Tuple[str, str]], agent_configs: typing.Sequence[typing.Tuple[str, str, str, str]]
    ) -> typing.Tuple[dict, dict]:
        """Create the requested agents and platforms for the environment configuration.

        Parameters
        ----------
        platform_configs : typing.Sequence[typing.Tuple[str, str]]
            A sequence of platforms. Each platform consists of a platform name
            and a platform configuration filename. Platform names must be unique.
        agent_configs : typing.Sequence[typing.Tuple[str, str, str, str]]
            A sequence of agents. Each agent consists of a policy name, the name
            of a platform from ``platform_configs``, an agent configuration
            filename and a policy configuration filename.

        Returns
        -------
        typing.Tuple[dict, dict]
            ``(agents, platforms)``: ``agents`` maps each policy name to an
            ``AgentParseInfo``; ``platforms`` maps each platform name to its
            loaded platform configuration.

        Raises
        ------
        AssertionError
            If a platform name is repeated, or an agent references a platform
            name not present in ``platform_configs``.
        """
        platforms = {}
        for platform_name, platform_file in platform_configs:
            assert platform_name not in platforms, f"duplicate platforms not allowed: '{platform_name}'"
            platforms[platform_name] = load_file(platform_file)

        agents = {}
        for policy_name, platform_name, agent_file, policy_file in agent_configs:
            # List only the known platform names; the loaded configs themselves
            # can be large and would bury the actual error.
            assert platform_name in platforms, f"invalid platform '{platform_name}' not in {list(platforms)}"
            parsed_agent = AgentParseBase(**load_file(agent_file))
            parsed_policy = self.get_policy_validator(**load_file(policy_file))
            # NOTE(review): a repeated policy name silently replaces the earlier agent.
            agents[policy_name] = AgentParseInfo(class_config=parsed_agent, platform_name=platform_name, policy_config=parsed_policy)

        return agents, platforms

    @staticmethod
    def create_other_platforms(other_platforms_config: typing.Sequence[typing.Tuple[str, str]]) -> dict:
        """Create the requested other platforms for the environment configuration.

        Parameters
        ----------
        other_platforms_config : typing.Sequence[typing.Tuple[str, str]]
            A sequence of platforms. Each platform consists of a name and platform
            filename. An empty or ``None`` value yields an empty dict.

        Returns
        -------
        dict
            Maps each platform name to its loaded platform configuration.
        """
        if not other_platforms_config:
            return {}
        return {platform_name: load_file(platform_file) for platform_name, platform_file in other_platforms_config}
get_policy_validator: Type[corl.policies.base_policy.BasePolicyValidator]
property
readonly
¤
Get the policy validator for this experiment class. Each policy configuration file loaded by create_agents is validated using this object, and the result is stored as that agent's policy config.
get_validator: Type[corl.experiments.base_experiment.BaseExperimentValidator]
property
readonly
¤
Get the validator for this experiment class. The kwargs sent to the experiment class are validated using this object, and the result is stored as the self.config attribute of the experiment.
create_agents(self, platform_configs, agent_configs)
¤
Create the requested agents and add them to the environment configuration.
Parameters¤
agent_configs : typing.Sequence[typing.Tuple[str, str, str, str]] A sequence of agents. Each agent consists of a policy name, a platform name, an agent configuration filename and a policy configuration filename.
Source code in corl/experiments/base_experiment.py
def create_agents(
    self, platform_configs: typing.Sequence[typing.Tuple[str, str]], agent_configs: typing.Sequence[typing.Tuple[str, str, str, str]]
) -> typing.Tuple[dict, dict]:
    """Create the requested agents and platforms for the environment configuration.

    Parameters
    ----------
    platform_configs : typing.Sequence[typing.Tuple[str, str]]
        A sequence of platforms. Each platform consists of a platform name
        and a platform configuration filename. Platform names must be unique.
    agent_configs : typing.Sequence[typing.Tuple[str, str, str, str]]
        A sequence of agents. Each agent consists of a policy name, the name
        of a platform from ``platform_configs``, an agent configuration
        filename and a policy configuration filename.

    Returns
    -------
    typing.Tuple[dict, dict]
        ``(agents, platforms)``: ``agents`` maps each policy name to an
        ``AgentParseInfo``; ``platforms`` maps each platform name to its
        loaded platform configuration.

    Raises
    ------
    AssertionError
        If a platform name is repeated, or an agent references a platform
        name not present in ``platform_configs``.
    """
    platforms = {}
    for platform_name, platform_file in platform_configs:
        assert platform_name not in platforms, f"duplicate platforms not allowed: '{platform_name}'"
        platforms[platform_name] = load_file(platform_file)

    agents = {}
    for policy_name, platform_name, agent_file, policy_file in agent_configs:
        # List only the known platform names; the loaded configs themselves
        # can be large and would bury the actual error.
        assert platform_name in platforms, f"invalid platform '{platform_name}' not in {list(platforms)}"
        parsed_agent = AgentParseBase(**load_file(agent_file))
        parsed_policy = self.get_policy_validator(**load_file(policy_file))
        # NOTE(review): a repeated policy name silently replaces the earlier agent.
        agents[policy_name] = AgentParseInfo(class_config=parsed_agent, platform_name=platform_name, policy_config=parsed_policy)

    return agents, platforms
create_other_platforms(other_platforms_config)
staticmethod
¤
Create the requested other platforms and add them to the environment configuration.
Parameters¤
other_platforms_config : typing.Sequence[typing.Tuple[str, str]] A sequence of platforms. Each platform consists of a name and platform filename.
Source code in corl/experiments/base_experiment.py
@staticmethod
def create_other_platforms(other_platforms_config: typing.Sequence[typing.Tuple[str, str]]) -> dict:
    """Create the requested other platforms for the environment configuration.

    Parameters
    ----------
    other_platforms_config : typing.Sequence[typing.Tuple[str, str]]
        A sequence of platforms. Each platform consists of a name and platform
        filename. An empty or ``None`` value yields an empty dict.

    Returns
    -------
    dict
        Maps each platform name to its loaded platform configuration.
    """
    if not other_platforms_config:
        return {}
    return {platform_name: load_file(platform_file) for platform_name, platform_file in other_platforms_config}
run_experiment(self, args)
¤
Runs the experiment associated with this experiment class
Source code in corl/experiments/base_experiment.py
@abc.abstractmethod
def run_experiment(self, args: argparse.Namespace):
    """Run the experiment implemented by this experiment class.

    Concrete experiment classes must override this method.

    Arguments:
        args {argparse.Namespace} -- The args provided by the argparse
            in corl.train_rl
    """
BaseExperimentValidator (BaseModel)
pydantic-model
¤
Base Validator to subclass for Experiments subclassing BaseExperiment
Source code in corl/experiments/base_experiment.py
class BaseExperimentValidator(BaseModel):
    """Base validator that configuration validators for BaseExperiment subclasses should subclass.

    Experiments are expected to return a subclass of this model from their
    ``get_validator`` property.
    """
ExperimentParse (BaseModel)
pydantic-model
¤
Parses an experiment configuration. experiment_class: the experiment class to run; config: the configuration to pass to that experiment.
Source code in corl/experiments/base_experiment.py
class ExperimentParse(BaseModel):
    """Parses an experiment configuration.

    experiment_class: The experiment class to run; must subclass BaseExperiment
    auto_system_detect_class: Optional class used for system auto-detection;
        must subclass BaseAutoDetect when given
    config: the configuration to pass to that experiment
    """
    experiment_class: PyObject
    auto_system_detect_class: typing.Optional[PyObject] = None
    config: typing.Dict[str, typing.Any]

    @validator('experiment_class')
    def check_experiment_class(cls, v):
        """
        Validates the experiment class actually subclasses BaseExperiment Class

        Raises ValueError if ``v`` is not a class or is not a BaseExperiment subclass.
        """
        # issubclass() raises an unhelpful TypeError when v is not a class
        # (e.g. a function resolved via PyObject), so check for a class first.
        if not (isinstance(v, type) and issubclass(v, BaseExperiment)):
            raise ValueError(f"Experiment functors must subclass BaseExperiment, but experiment {v} does not")
        return v

    @validator('auto_system_detect_class')
    def check_auto_system_detect_class(cls, v):
        """
        Validates the auto system detect class actually subclasses BaseAutoDetect Class

        ``None`` is accepted because the field is optional. Raises ValueError if
        ``v`` is not a class or is not a BaseAutoDetect subclass.
        """
        if v is None:
            return v
        # Same class guard as check_experiment_class: avoid issubclass() TypeError.
        if not (isinstance(v, type) and issubclass(v, BaseAutoDetect)):
            raise ValueError(f"Auto system detect classes must subclass BaseAutoDetect, but {v} does not")
        return v
check_auto_system_detect_class(v)
classmethod
¤
Validates the auto system detect class actually subclasses BaseAutoDetect Class
Source code in corl/experiments/base_experiment.py
@validator('auto_system_detect_class')
def check_auto_system_detect_class(cls, v):
    """
    Validates the auto system detect class actually subclasses BaseAutoDetect Class

    ``None`` is accepted because the field is optional. Raises ValueError if
    ``v`` is not a class or is not a BaseAutoDetect subclass.
    """
    if v is None:
        return v
    # issubclass() raises an unhelpful TypeError when v is not a class, so
    # check for a class first.
    if not (isinstance(v, type) and issubclass(v, BaseAutoDetect)):
        raise ValueError(f"Auto system detect classes must subclass BaseAutoDetect, but {v} does not")
    return v
check_experiment_class(v)
classmethod
¤
Validates the experiment class actually subclasses BaseExperiment Class
Source code in corl/experiments/base_experiment.py
@validator('experiment_class')
def check_experiment_class(cls, v):
    """
    Validates the experiment class actually subclasses BaseExperiment Class

    Raises ValueError if ``v`` is not a class or is not a BaseExperiment subclass.
    """
    # issubclass() raises an unhelpful TypeError when v is not a class
    # (e.g. a function resolved via PyObject), so check for a class first.
    if not (isinstance(v, type) and issubclass(v, BaseExperiment)):
        raise ValueError(f"Experiment functors must subclass BaseExperiment, but experiment {v} does not")
    return v