Skip to content

Multi agent env


Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.

This is a US Government Work not subject to copyright protection in the US.

The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.


ACT3MultiAgentEnv (MultiAgentEnv) ¤

ACT3MultiAgentEnv creates an RLlib MultiAgentEnv environment. This class is intended to wrap the interactions between RLlib and the backend simulator environment. Everything here is intended to be common to running the RLlib environment, with ${simulator} providing the simulator-specific interaction parts.

  1. Includes wrapping the creation of the simulator specific to run
  2. Includes interactions with the dones, rewards, and glues
  3. etc...
Source code in corl/environment/multi_agent_env.py
class ACT3MultiAgentEnv(MultiAgentEnv):
    """
    ACT3MultiAgentEnv create a RLLIB MultiAgentEnv environment. The following class is intended to wrap
    the interactions with RLLIB and the backend simulator environment. All items here are intended to be
    common parts for running the RLLIB environment with ${simulator} being the unique interaction parts.

    1. Includes wrapping the creation of the simulator specific to run
    2. Includes interactions with the dones, rewards, and glues
    3. etc...
    """
    episode_parameter_provider_name: str = 'environment'

    def __init__(self, config: EnvContext) -> None:  # pylint: disable=too-many-statements, super-init-not-called
        """
        __init__ initializes the rllib multi agent environment

        In order, this validates the configuration, seeds the environment, creates the agents and
        their simulator configs, derives the simulator period from the agent frame rates, samples
        the default episode parameters, builds (and resets) the simulator, creates the glues and
        finally derives the observation/action spaces from those glues.

        Parameters
        ----------
        config : ray.rllib.env.env_context.EnvContext
            Passed in configuration for setting items up.
            Must have a 'simulator' key whose value is a BaseIntegrator type
        """
        # Pick up attributes stored on the config object itself in addition to its mapping entries.
        # vars() raises TypeError for objects without a __dict__; in that case only the mapping is used.
        try:
            config_vars = vars(config)
        except TypeError:
            config_vars = {}
        # get_validator is a property that returns the validator class, so this call instantiates it
        self.config: ACT3MultiAgentEnvValidator = self.get_validator(**config, **config_vars)

        # Random numbers
        self.seed(self.config.seed)

        # setup default instance variables
        self._actions: list = []
        self._obs_buffer = ObsBuffer()
        self._reward: RewardDict = RewardDict()
        self._done: DoneDict = DoneDict()
        self._info: OrderedDict = OrderedDict()
        self._episode_length: int = 0
        self._episode: int = 0
        # annotation only - no value is assigned here; reset() assigns it
        self._episode_id: typing.Union[int, None]

        # agent glue dict is a mapping from agent id to a dict with keys for the glue names
        # and values of the actual glue object
        self._agent_glue_dict: OrderedDict = OrderedDict()
        self._agent_glue_obs_export_behavior: OrderedDict = OrderedDict()

        # Create the logger
        self._logger = logging.getLogger(ACT3MultiAgentEnv.__name__)

        # Extra simulation init args
        # assign the new output_path with the worker index back to the config for the sim/integration output_path
        extra_sim_init_args: typing.Dict[str, typing.Any] = {
            "output_path": str(self.config.output_path),
            "worker_index": self.config.worker_index,
            "vector_index": self.config.vector_index if self.config.vector_index else 0,
        }

        # Build the agent objects and the per-agent simulator configuration in one call
        self.agent_dict, extra_sim_init_args["agent_configs"] = env_creation.create_agent_sim_configs(
            self.config.agents, self.config.agent_platforms, self.config.simulator.type, self.config.platforms, self.config.epp_registry,
            multiple_workers=(self.config.num_workers > 0)
        )

        def compute_lcm(values: typing.List[fractions.Fraction]) -> float:
            """Return 1 / LCM of the denominators of `values` (the numerators are not used)."""
            assert len(values) > 0
            lcm = values[0].denominator
            for v in values:
                lcm = lcm // math.gcd(lcm, v.denominator) * v.denominator
            return 1.0 / lcm

        # Each agent period (1 / frame_rate) is approximated by a fraction whose denominator is
        # at most max_agent_rate
        max_rate = self.config.max_agent_rate
        self._agent_periods = {
            agent_id: fractions.Fraction(1.0 / agent.frame_rate).limit_denominator(max_rate)
            for agent_id, agent in self.agent_dict.items()
        }
        # Sim time at which each agent was last processed (written in step());
        # agents without an entry default to sys.float_info.min
        self._agent_process_time: typing.Dict[str, float] = defaultdict(lambda: sys.float_info.min)
        self.sim_period = compute_lcm(list(self._agent_periods.values()))
        extra_sim_init_args['frame_rate'] = 1.0 / self.sim_period

        # Platforms listed in other_platforms get a simulator config with an empty parts list
        for agent_name, platform in self.config.other_platforms.items():
            extra_sim_init_args["agent_configs"][agent_name] = {
                "platform_config": platform,
                "parts_list": [],
            }

        # Debug logging
        self._logger.debug(f"output_path : {self.config.output_path}")

        # Sample parameter provider
        default_parameters = self.config.epp.config.parameters
        self.local_variable_store = flatten_dict.unflatten({k: v.get_value(self.rng) for k, v in default_parameters.items()})
        for agent in self.agent_dict.values():
            agent.fill_parameters(rng=self.rng, default_parameters=True)

        # Create the simulator for this gym environment
        # ----  oddity from other simulator bases HLP
        # Only build a simulator when one has not already been set on this instance
        if not hasattr(self, "_simulator"):

            class SimulatorWrapper(self.config.simulator.type):  # type: ignore
                """Wrapper that injects platforms/time into state dict"""

                def _clear_data(self) -> None:
                    """Removes the 'sim_time' and 'sim_platforms' entries from the stored state, if present"""
                    if 'sim_time' in self._state:
                        del self._state['sim_time']
                    if 'sim_platforms' in self._state:
                        del self._state['sim_platforms']

                def _inject_data(self, state: StateDict) -> StateDict:
                    """Ensures that time/platforms exists in state"""
                    if 'sim_time' not in state:
                        state['sim_time'] = self.sim_time

                    if 'sim_platforms' not in state:
                        state['sim_platforms'] = self.platforms

                    return state

                def step(self) -> StateDict:
                    """Steps the simulation - injects data into StateDict"""
                    self._clear_data()
                    return self._inject_data(super().step())

                def reset(self, *args, **kwargs) -> StateDict:
                    """Resets the simulation - injects data into StateDict"""
                    self._clear_data()
                    return self._inject_data(super().reset(*args, **kwargs))

            # Copy the configured factory so swapping in the wrapper type does not modify the config
            simulator_factory = copy.deepcopy(self.config.simulator)
            simulator_factory.type = SimulatorWrapper  # type: ignore

            self._simulator: BaseSimulator = simulator_factory.build(**extra_sim_init_args)

        # Initial reset; the agent configs are forwarded as "agent_configs_reset" by _reset_simulator
        self._state, self._sim_reset_args = self._reset_simulator(extra_sim_init_args["agent_configs"])

        # Make the glue objects from the glue mapping now that we have a simulator created
        self._make_glues()

        # create dictionary to hold done history
        self.__setup_state_history()

        # Create the observation and action space now that we have the glue
        self._observation_space: gym.spaces.Dict = self.__create_space(space_getter=lambda glue_obj: glue_obj.observation_space())
        self._action_space: gym.spaces.Dict = self.__create_space(space_getter=lambda glue_obj: glue_obj.action_space())
        gym_space_sort(self._action_space)
        # Only glues whose training export behavior is INCLUDE contribute to the normalized observation space
        self._normalized_observation_space: gym.spaces.Dict = self.__create_space(
            space_getter=lambda glue_obj: glue_obj.normalized_observation_space()
            if glue_obj.config.training_export_behavior == TrainingExportBehavior.INCLUDE else None
        )
        self._normalized_action_space: gym.spaces.Dict = self.__create_space(
            space_getter=lambda glue_obj: glue_obj.normalized_action_space()
        )
        gym_space_sort(self._normalized_action_space)
        # Glues without an observation_units attribute contribute None
        self._observation_units = self.__create_space(
            space_getter=lambda glue_obj: glue_obj.observation_units() if hasattr(glue_obj, "observation_units") else None
        )
        self._shared_done: DoneDict = DoneDict()
        self._done_info: OrderedDict = OrderedDict()
        self._reward_info: OrderedDict = OrderedDict()

        # annotation only - no value is assigned here
        self._episode_init_params: dict

        self.done_string = ""
        self._agent_ids = set(self._action_space.spaces.keys())

        # When True, step() does not apply the incoming actions to the agents
        self._skip_action = False

    @property
    def get_validator(self) -> typing.Type[ACT3MultiAgentEnvValidator]:
        """
        Validator class used to parse this environment's configuration.

        Because this is a property, ``self.get_validator(**kwargs)`` calls the
        returned class and therefore builds a validator instance.

        Returns
        -------
        typing.Type[ACT3MultiAgentEnvValidator]
            The validator class itself (not an instance).
        """
        validator_type = ACT3MultiAgentEnvValidator
        return validator_type

    def reset(self):
        """
        Reset the environment for a new episode and return the initial training observations.

        Samples the episode parameters, rebuilds the rewards/dones, resets the simulator,
        recreates the glues, optionally steps the simulator ``sim_warmup_steps`` times,
        clears the per-platform state history and sanity checks the first observation.

        Returns
        -------
        OrderedDict
            The training observations produced by create_training_observations for every
            agent in ``agent_dict``.

        Raises
        ------
        ValueError
            If ``deep_sanity_check`` is disabled and the initial observation is not
            contained in the observation space.
        """

        # Sample parameter provider
        current_parameters, self._episode_id = self.config.epp.get_params(self.rng)
        self.local_variable_store = flatten_dict.unflatten({k: v.get_value(self.rng) for k, v in current_parameters.items()})
        for agent in self.agent_dict.values():
            agent.fill_parameters(self.rng)

        # 3. Reset the Done and Reward dictionaries for the next iteration
        self._make_rewards()
        self._make_dones()
        self._shared_done: DoneDict = self._make_shared_dones()

        self._reward: RewardDict = RewardDict()
        self._done: DoneDict = DoneDict()
        self._done_info.clear()

        self.set_default_done_reward()
        # 4. Reset the simulation/integration
        self._state, self._sim_reset_args = self._reset_simulator()
        self._episode_length = 0
        self._actions.clear()
        self._episode += 1

        # Forget when each agent was last processed
        self._agent_process_time.clear()

        #####################################################################
        # Make glue sections - Given the state of the simulation we need to
        # update the platform interfaces.
        #####################################################################
        self._make_glues()

        #####################################################################
        # get observations
        # For each configured agent read the observations/measurements
        #####################################################################
        agent_list = list(self.agent_dict.keys())
        self._obs_buffer.next_observation = self.__get_observations_from_glues(agent_list)
        self._obs_buffer.update_obs_pointer()
        # The following loop guarantees that during training the glue
        # states start with valid values for rates. The number of recommended
        # steps for sim is at least equal to the depth of the rate observation
        # tree (ex: speed - 2, acceleration - 3, jerk - 4) - recommend defaulting
        # to 4 as we do not go higher than jerk
        # 1 step is always added for the initial obs in reset
        warmup = self.config.sim_warmup_steps
        for _ in range(warmup):
            self._state = self._simulator.step()
            self._obs_buffer.next_observation = self.__get_observations_from_glues(agent_list)
            self._obs_buffer.update_obs_pointer()

        self.__setup_state_history()

        # Clear the per-platform step/episode bookkeeping held in the state
        for platform in self._state.sim_platforms:
            self._state.step_state[platform.name] = None
            self._state.episode_history[platform.name].clear()
            self._state.episode_state[platform.name] = OrderedDict()

        # Sanity Checks and Scale
        # The current deep sanity check will not raise an error if values from the sample differ from the space during reset
        if self.config.deep_sanity_check:
            try:
                self.__sanity_check(self._observation_space, self._obs_buffer.observation)
            except ValueError as err:
                self._save_state_pickle(err)
        else:
            if not self._observation_space.contains(self._obs_buffer.observation):
                raise ValueError('obs not contained in obs space')

        self._create_actions(self.agent_dict, self._obs_buffer.observation)

        #####################################################################
        # return results to RLLIB - Note that RLLIB does not do a recursive
        # isinstance call and as such need to make sure items are
        # OrderedDicts
        #####################################################################
        trainable_observations, _ = self.create_training_observations(agent_list, self._obs_buffer)
        return trainable_observations

    def _reset_simulator(self, agent_configs=None) -> typing.Tuple[StateDict, typing.Dict[str, typing.Any]]:
        """
        Assemble the simulator reset arguments and reset the simulator with them.

        The arguments start as a deep copy of ``config.simulator_reset_parameters``. The
        ``simulator_reset`` entry of the local variable store (if any) is merged in, position
        references are resolved, and each agent's reset parameters are merged under
        ``platforms -> <platform name>``.

        Parameters
        ----------
        agent_configs : optional
            When not None, forwarded to the simulator under the ``agent_configs_reset`` key.

        Returns
        -------
        typing.Tuple[StateDict, typing.Dict[str, typing.Any]]
            The state returned by the simulator reset and the arguments used for that reset.
        """
        # Deep copy so the merges below never touch the configured defaults
        reset_args = copy.deepcopy(self.config.simulator_reset_parameters)
        variable_store = self.local_variable_store
        deepmerge.always_merger.merge(reset_args, variable_store.get('simulator_reset', {}))

        self._process_references(reset_args, variable_store)

        for agent in self.agent_dict.values():
            platform_overrides = {'platforms': {agent.platform_name: agent.get_simulator_reset_parameters()}}
            deepmerge.always_merger.merge(reset_args, platform_overrides)

        if agent_configs is not None:
            reset_args["agent_configs_reset"] = agent_configs

        new_state = self._simulator.reset(reset_args)
        return new_state, reset_args

    def _process_references(self, sim_reset_args: dict, v_store: dict) -> None:
        """Process the reference store look ups for the position data

        Parameters
        ----------
        sim_reset_args : dict
            The simulator reset parameters
        v_store : dict
            The variable store
        """
        plat_str = "platforms"
        pos_str = "position"
        ref_str = "reference"
        if plat_str in sim_reset_args:
            for plat_k, plat_v in sim_reset_args[plat_str].items():
                if pos_str in plat_v:
                    for position_k, position_v in plat_v[pos_str].items():
                        if isinstance(position_v, dict) and ref_str in position_v:
                            sim_reset_args[plat_str][plat_k][pos_str][position_k] = v_store["reference_store"].get(
                                position_v[ref_str], self.config.reference_store[position_v[ref_str]]
                            )

    def _get_operable_agents(self):
        """Determines which agents are operable in the sim, this becomes stale after the simulation is stepped"""
        operable_platform_names = [
            item.name for item in self.state.sim_platforms if item.operable and not self.state.episode_state.get(item.name, {})
        ]

        operable_agents = {}
        for agent_name, agent in self.agent_dict.items():
            if agent.platform_name in operable_platform_names:
                operable_agents[agent_name] = agent
            else:
                agent.set_removed(True)

        return operable_agents

    def step(self, action_dict: dict) -> typing.Tuple[OrderedDict, OrderedDict, OrderedDict, OrderedDict]:
        # pylint: disable=R0912, R0914, R0915
        """Applies the actions, steps the simulator once and returns results for the agents processed this timestep.

        The returns are dicts mapping from agent_id strings to values. The
        number of agents in the env can vary over time. Only agents that are
        processed on this timestep (see ``agents_to_process_this_timestep``
        below) appear in the returned dicts.

        Parameters
        ----------
        action_dict : dict
            Actions keyed by agent id. Operable agents without an entry have no action applied.

        Returns
        -------
        obs : OrderedDict
            Training observations for each processed agent.
        rewards : OrderedDict
            Reward values for each processed agent.
        dones : OrderedDict
            Done values for each processed agent. The special key "__all__"
            (required) is used to indicate env termination.
        infos : OrderedDict
            Info values for each processed agent; always contains the 'env'
            and 'platform_obs' entries added at the end of this method.

        Raises
        ------
        RuntimeError
            If the local or shared done keys do not match the operable agents plus "__all__",
            or if a local and a shared done share a name.
        ValueError
            If ``deep_sanity_check`` is disabled and the observation is not contained in the observation space.
        """
        self._episode_length += 1

        # Agents operable *before* the simulator is stepped
        operable_agents = self._get_operable_agents()

        # look to add this bugsplat, but this check won't work for multi fps
        # if set(operable_agents.keys()) != set(action_dict.keys()):
        #     raise RuntimeError("Operable_agents and action_dict keys differ!"
        #                        f"operable={set(operable_agents.keys())} != act={set(action_dict.keys())} "
        #                        "If this happens that means either your platform is not setting non operable correctly"
        #                        " (if extra keys are in operable set) or you do not have a done condition covering "
        #                        "a condition where your platform is going non operable. (if extra keys in act)")

        if self._skip_action:
            raw_action_dict = {}
        else:
            raw_action_dict = self.__apply_action(operable_agents, action_dict)

        # Save current action for future debugging
        self._actions.append(action_dict)

        # NOTE(review): if _save_state_pickle does not re-raise, execution continues with the
        # state from before this step - confirm against its implementation
        try:
            self._state = self._simulator.step()
        except ValueError as err:
            self._save_state_pickle(err)

        # MTB - Changing to not replace operable_agents variable
        #       We calculate observations on agents operable after sim step
        #       - This is done because otherwise observations would be invalid
        #       Calculate Dones/Rewards on agents operable before sim step
        #       - This is done because if an agent "dies" it needs to have a final done calculated
        operable_agents_after_step = self._get_operable_agents()

        #####################################################################
        # get next observations - For each configured platform read the
        # observations/measurements
        #####################################################################
        self._obs_buffer.next_observation = self.__get_observations_from_glues(operable_agents_after_step.keys())

        self._info.clear()
        self.__get_info_from_glue(operable_agents_after_step.keys())

        #####################################################################
        # Process the done conditions
        # 1. Compute the local (per agent) dones
        # 2. Combine them with the shared dones
        # 3. Determine whether the whole episode is done
        #####################################################################

        agents_done = self.__get_done_from_agents(operable_agents.keys(), raw_action_dict=raw_action_dict)

        expected_done_keys = set(operable_agents.keys())
        expected_done_keys.add('__all__')
        if set(agents_done.keys()) != expected_done_keys:
            raise RuntimeError(
                f'Local dones do not match expected keys.  Received "{agents_done.keys()}".  Expected "{expected_done_keys}".'
            )

        # compute if done all (from the local dones only; repeated below once shared dones are merged)
        if not agents_done['__all__']:
            agent_dones = [v for k, v in agents_done.items() if k != '__all__']
            if self.config.end_episode_on_first_agent_done:
                agents_done['__all__'] = any(agent_dones)
            else:
                agents_done['__all__'] = all(agent_dones)

        # The shared dones receive copies of the local results so they cannot modify them
        shared_dones, shared_done_info = self._shared_done(
            observation=self._obs_buffer.observation,
            action=raw_action_dict,
            next_observation=self._obs_buffer.next_observation,
            next_state=self._state,
            observation_space=self._observation_space,
            observation_units=self._observation_units,
            local_dones=copy.deepcopy(agents_done),
            local_done_info=copy.deepcopy(self._done_info)
        )

        # An empty shared done result is allowed; otherwise it must cover exactly the expected keys
        if shared_dones.keys():
            if set(shared_dones.keys()) != expected_done_keys:
                raise RuntimeError(
                    f'Shared dones do not match expected keys.  Received "{shared_dones.keys()}".  Expected "{expected_done_keys}".'
                )
            for key in expected_done_keys:
                agents_done[key] |= shared_dones[key]

        assert shared_done_info is not None

        # NOTE(review): the keys of _done_info are written per agent id by __get_done_from_agents,
        # while shared_done_info is iterated below as done name -> {agent name: status};
        # confirm this comparison checks what is intended
        local_done_info_keys = set(self._done_info.keys())
        shared_done_info_keys = set(shared_done_info)
        common_keys = local_done_info_keys & shared_done_info_keys
        if common_keys:
            raise RuntimeError(f'Dones have common names: "{common_keys}"')

        # Record the shared done results in the per agent done info, keyed by the agent's platform name
        for done_name, done_keys in shared_done_info.items():
            for agent_name, done_status in done_keys.items():
                self._done_info[agent_name][done_name] = OrderedDict([(self.agent_dict[agent_name].platform_name, done_status)])

        # compute if done all (again, now that the shared dones have been merged in)
        if not agents_done['__all__']:
            agent_dones = [v for k, v in agents_done.items() if k != '__all__']
            if self.config.end_episode_on_first_agent_done:
                agents_done['__all__'] = any(agent_dones)
            else:
                agents_done['__all__'] = all(agent_dones)

        # Tell the simulator to mark the episode complete
        if agents_done['__all__']:
            self._simulator.mark_episode_done(self._done_info, self._state.episode_state)

        self._reward.reset()

        # When the episode is over every agent operable before the step is processed;
        # otherwise only the agents whose period has elapsed since they were last processed
        if agents_done['__all__']:
            agents_to_process_this_timestep = list(operable_agents.keys())
        else:

            def do_process_agent(self, agent_id) -> bool:
                # despite its name, `frame_rate` holds the agent period (numerator / denominator of _agent_periods)
                frame_rate = self._agent_periods[agent_id].numerator / self._agent_periods[agent_id].denominator
                return self._state.sim_time >= self._agent_process_time[agent_id] + frame_rate - self.config.timestep_epsilon

            agents_to_process_this_timestep = list(filter(partial(do_process_agent, self), operable_agents.keys()))

        for agent_id in agents_to_process_this_timestep:
            self._agent_process_time[agent_id] = self._state.sim_time

        reward = self.__get_reward_from_agents(agents_to_process_this_timestep, raw_action_dict=raw_action_dict)

        self._simulator.save_episode_information(self.done_info, self.reward_info, self._obs_buffer.observation)
        # copy over observation from next to previous - There is no real reason to deep
        # copy here. The process of getting a new observation from the glue copies. All
        # we need to do is maintain the order of two buffers!!!.
        # Tested with: They are different and decreasing as expected
        #   print(f"C: {self._obs_buffer.observation['blue0']['ObserveSensor_Sensor_Fuel']}")
        #   print(f"N: {self._obs_buffer.next_observation['blue0']['ObserveSensor_Sensor_Fuel']}")
        self._obs_buffer.update_obs_pointer()
        # Sanity checks and Scale - ensure run first time and run only every N times...
        # Same as RLLIB - This can add a bit of time as we are exploring complex dictionaries
        # default to every time if not specified... Once the limits are good it is
        # recommended to increase this for training

        if self.config.deep_sanity_check:
            if self._episode_length % self.config.sanity_check_obs == 0:
                try:
                    self.__sanity_check(self._observation_space, self._obs_buffer.observation)
                except ValueError as err:
                    self._save_state_pickle(err)
        else:
            if not self._observation_space.contains(self._obs_buffer.observation):
                raise ValueError('obs not contained in obs space')

        # Observations are built for every agent operable before the step, then narrowed
        # to the agents processed this timestep
        complete_trainable_observations, complete_unnormalized_observations = self.create_training_observations(
            operable_agents, self._obs_buffer
        )
        trainable_observations = OrderedDict()
        for agent_id in agents_to_process_this_timestep:
            trainable_observations[agent_id] = complete_trainable_observations[agent_id]

        trainable_rewards = get_dictionary_subset(reward, agents_to_process_this_timestep)
        trainable_dones = get_dictionary_subset(agents_done, ["__all__"] + agents_to_process_this_timestep)
        trainable_info = get_dictionary_subset(self._info, agents_to_process_this_timestep)

        # add platform obs and env data to trainable_info (for use by custom policies)
        for agent_id in agents_to_process_this_timestep:
            if agent_id not in trainable_info:
                trainable_info[agent_id] = {}

            trainable_info[agent_id]['env'] = {'sim_period': self.sim_period}
            trainable_info[agent_id]['platform_obs'] = {}

            # 'platform_obs' holds the unnormalized observations of every agent sharing this agent's platform
            plat_name = self.agent_dict[agent_id].platform_name
            for platform_agent in self.agent_dict:
                if self.agent_dict[platform_agent].platform_name == plat_name:
                    trainable_info[agent_id]['platform_obs'][platform_agent] = complete_unnormalized_observations[platform_agent]

        # if not done all, delete any platforms from simulation that are done, so they don't interfere
        platforms_deleted = set()
        if not agents_done['__all__']:
            for agent_key, value in agents_done.items():
                if agent_key != '__all__' and value:
                    plat_name = self.agent_dict[agent_key].platform_name
                    if plat_name not in platforms_deleted:
                        self.simulator.delete_platform(plat_name)
                        platforms_deleted.add(plat_name)

        # if a platform has been deleted, we need to make sure that all agents on that platform
        # are also done
        for agent_key in trainable_dones.keys():
            if agent_key == '__all__':
                continue
            plat_name = self.agent_dict[agent_key].platform_name
            if plat_name in platforms_deleted:
                trainable_dones[agent_key] = True

        #####################################################################
        # return results to RLLIB - Note that RLLIB does not do a recursive
        # isinstance call and as such need to make sure items are
        # OrderedDicts
        #####################################################################
        return trainable_observations, trainable_rewards, trainable_dones, trainable_info

    def __get_done_from_agents(self, alive_agents: typing.Iterable[str], raw_action_dict):
        """
        Evaluate the done conditions of the given agents.

        Each agent's done info is merged into ``self._done_info[agent_id]`` using a merger that
        ORs boolean leaves, so a done flagged True earlier is not overwritten by a later False.

        Parameters
        ----------
        alive_agents : typing.Iterable[str]
            Ids of the agents to evaluate
        raw_action_dict :
            The raw actions passed to each agent's ``get_dones`` as ``action``

        Returns
        -------
        OrderedDict
            Done flag per agent id, plus the "__all__" entry, which is taken from the first
            truthy "__all__" value reported by an agent (False if none).
        """

        def keep_true(config, path, base, nxt):  # pylint: disable=unused-argument
            return base or nxt

        strategies = copy.deepcopy(deepmerge.DEFAULT_TYPE_SPECIFIC_MERGE_STRATEGIES)
        strategies.append((bool, keep_true))
        bool_or_merger = deepmerge.Merger(strategies, [], [])

        agent_dones = OrderedDict([("__all__", False)])
        for agent_id in alive_agents:
            agent = self.agent_dict[agent_id]
            platform_done, agent_done_info = agent.get_dones(
                observation=self._obs_buffer.observation,
                action=raw_action_dict,
                next_observation=self._obs_buffer.next_observation,
                next_state=self._state,
                observation_space=self._observation_space,
                observation_units=self._observation_units
            )
            agent_dones[agent_id] = platform_done[agent.platform_name]
            # Once "__all__" is set it stays set; otherwise take this agent's value
            if not agent_dones["__all__"]:
                agent_dones["__all__"] = platform_done.get("__all__", False)
            accumulated_info = self._done_info.setdefault(agent_id, {})
            bool_or_merger.merge(accumulated_info, agent_done_info)
        return agent_dones

    def __get_reward_from_agents(self, alive_agents: typing.Iterable[str], raw_action_dict):
        reward = OrderedDict()
        for agent_id in alive_agents:
            agent_class = self.agent_dict[agent_id]
            agent_reward, reward_info = agent_class.get_rewards(
                observation=self._obs_buffer.observation,
                action=raw_action_dict,
                next_observation=self._obs_buffer.next_observation,
                state=self._state,
                next_state=self._state,
                observation_space=self._observation_space,
                observation_units=self._observation_units
            )
            # it is possible to have a HL policy that does not compute an reward
            # in this case just return a zero for reward value
            if agent_id in agent_reward:
                reward[agent_id] = agent_reward[agent_id]
            else:
                reward[agent_id] = 0
            self._reward_info[agent_id] = reward_info
        return reward

    def set_default_done_reward(self):
        """
        Initialize the done, shared done and reward containers with their defaults.

        Every agent in ``agent_dict`` gets a False done, a False shared done and a reward of 0;
        the ``DoneFuncBase._ALL`` entry of both done containers is set to False.
        """
        for agent_id in self.agent_dict:
            self._done[agent_id] = False
            self._reward[agent_id] = 0
            self._shared_done[agent_id] = False

        all_key = DoneFuncBase._ALL  # pylint: disable=protected-access
        self._done[all_key] = False
        self._shared_done[all_key] = False

    def create_training_observations(self, alive_agents: typing.Iterable[str],
                                     observations: ObsBuffer) -> typing.Tuple[OrderedDict, OrderedDict]:
        """
        Filters and normalizes observations (the sample of the space) using the glue normalize functions.

        For each alive agent the observation is taken from ``observations.observation`` and,
        when the agent is missing there, from ``observations.next_observation``.

        Parameters
        ----------
        alive_agents:
            The agents that are still alive. Any iterable of agent ids is accepted
            (including a dict keyed by agent id or a one-shot iterator).
        observations:
            The observations

        Returns
        -------
        OrderedDict:
            the filtered and normalized observation samples
        OrderedDict:
            the filtered (unnormalized) observation samples

        Raises
        ------
        RuntimeError
            If an alive agent has no entry in either observation buffer
        """
        # Materialize the ids once: they are iterated here and used for membership tests in
        # do_export, which would otherwise exhaust a one-shot iterable
        alive_agent_ids = list(alive_agents)
        alive_agent_set = frozenset(alive_agent_ids)

        this_steps_obs = OrderedDict()
        for agent_id in alive_agent_ids:
            if agent_id in observations.observation:
                this_steps_obs[agent_id] = observations.observation[agent_id]
            elif agent_id in observations.next_observation:
                this_steps_obs[agent_id] = observations.next_observation[agent_id]
            else:
                raise RuntimeError(
                    "ERROR: create_training_observations tried to retrieve obs for this training step"
                    f" but {agent_id=} was not able to be found in either the current obs data or the "
                    " obs from the previous timestep as a fallback"
                )

        def do_export(agent_id, obs_name, _obs):
            """Keep an observation only when its glue exists and is exported for training"""
            if agent_id not in alive_agent_set:
                return False
            glue_obj = self.agent_dict[agent_id].get_glue(obs_name)
            if glue_obj is None:
                return False
            return glue_obj.config.training_export_behavior == TrainingExportBehavior.INCLUDE

        def normalize(agent_id, obs_name, obs):
            """Normalize a single observation through the owning agent"""
            return self.agent_dict[agent_id].normalize_observation(obs_name, obs)

        filtered_observations = filter_observations(this_steps_obs, do_export)
        normalized_observations = mutate_observations(filtered_observations, normalize)

        return normalized_observations, filtered_observations

    def __get_observations_from_glues(self, alive_agents: typing.Iterable[str]) -> OrderedDict:  # pylint: disable=protected-access
        """
        Gets the observation dict from all the glue objects for each agent

        Returns
        -------
        OrderedDict:
            The observation dict from all the glues
        """
        return_observation: OrderedDict = OrderedDict()
        p_names = [item.name for item in self._state.sim_platforms]
        for agent_id in alive_agents:
            agent_class = self.agent_dict[agent_id]
            # TODO: Why is this check required here?
            # Why does 'alive_agents' contain agents with platforms that don't exist (or have been removed)?
            # This should only happpen if the user has accidentally passed in 'dead' agents
            if agent_class.platform_name not in p_names:
                self._logger.warning(
                    f"{agent_id} on {agent_class.platform_name} is not in the list of (alive) sim_platforms: {self._state.sim_platforms}"
                )
                agent_class.set_removed(True)
            else:
                glue_obj_obs = agent_class.get_observations()
                if len(glue_obj_obs) > 0:
                    return_observation[agent_id] = glue_obj_obs
        return return_observation

    def __apply_action(self, operable_agents, action_dict):
        raw_action_dict = OrderedDict()
        for agent_id, agent_class in operable_agents.items():
            if agent_id in action_dict:
                raw_action_dict[agent_id] = agent_class.apply_action(action_dict[agent_id])
        return raw_action_dict

    def __get_info_from_glue(self, alive_agents: typing.Iterable[str]):
        for agent_id in alive_agents:
            agent_class = self.agent_dict[agent_id]
            glue_obj_info = agent_class.get_info_dict()
            if len(glue_obj_info) > 0:
                self._info[agent_id] = glue_obj_info

    def _get_observation_units_from_glues(self) -> OrderedDict:  # pylint: disable=protected-access
        """
        Gets the observation dict from all the glue objects for each agent

        Returns
        -------
        OrderedDict:
            The observation dict from all the glues
        """
        return_observation: OrderedDict = OrderedDict()
        p_names = [item.name for item in self._state.sim_platforms]
        for agent_id, glue_name_obj_pair in self._agent_glue_dict.items():
            for glue_name, glue_object in glue_name_obj_pair.items():
                if glue_object._agent_id not in p_names:  # pylint: disable=protected-access
                    glue_object.set_agent_removed(True)
                try:
                    glue_obj_obs = glue_object.observation_units()
                except AttributeError:
                    glue_obj_obs = None

                return_observation.setdefault(agent_id, OrderedDict())[glue_name] = glue_obj_obs
        return return_observation

    def _make_glues(self) -> None:
        """
        Build the glue objects of every agent and validate controller exclusivity

        Each agent in ``self.agent_dict`` is handed the platform matching its
        ``platform_name`` and asked to create its glues. Unless the simulator
        config sets ``disable_exclusivity_check``, the controller glues of all
        agents sharing a platform are then checked so that no two controllers
        share an ``exclusiveness`` entry.

        Raises
        ------
        AssertionError
            If a controller glue is not a ``ControllerGlue`` or two controllers
            on the same platform are not mutually exclusive
        """
        reference_stores = [self.local_variable_store.get('reference_store', {}), self.config.reference_store]

        agents_by_platform: typing.Dict[str, typing.List[str]] = defaultdict(list)
        for agent_id, agent in self.agent_dict.items():
            # get the platform for this agent
            platform = self._get_platform_by_name(agent.platform_name)
            agents_by_platform[platform.name].append(agent_id)

            agent.make_glues(platform, agent_id, env_ref_stores=reference_stores)

        skip_validation = self.config.simulator.config.get("disable_exclusivity_check", False)
        if skip_validation:
            return

        def itr_controller_glues(glue):
            """Yield every controller glue reachable from ``glue``, descending into wrapper glues"""
            # Independent checks (not elif): a glue matching several base types is visited for each
            if isinstance(glue, BaseAgentControllerGlue):
                yield glue
            if isinstance(glue, BaseWrapperGlue):
                yield from itr_controller_glues(glue.glue())
            if isinstance(glue, BaseDictWrapperGlue):
                for wrapped in glue.glues().values():
                    yield from itr_controller_glues(wrapped)
            if isinstance(glue, BaseMultiWrapperGlue):
                for wrapped in glue.glues():
                    yield from itr_controller_glues(wrapped)

        # validate
        for platform_name, agent_ids in agents_by_platform.items():
            # exclusiveness entries already claimed by a controller on this platform
            claimed: typing.Set[str] = set()

            for agent_id in agent_ids:
                for glue in self.agent_dict[agent_id].agent_glue_dict.values():
                    for controller_glue in itr_controller_glues(glue):
                        assert isinstance(controller_glue, ControllerGlue
                                          ), (f"Unknown controller glue type {type(controller_glue)} on platform {platform_name}")
                        requested = controller_glue.controller.exclusiveness
                        assert len(requested.intersection(claimed)) == 0, (f"Controllers not mutually exclusive on platform {platform_name}")
                        claimed.update(requested)

    def _make_rewards(self) -> None:
        """
        """
        env_ref_stores = [self.local_variable_store.get('reference_store', {}), self.config.reference_store]

        for agent, agent_class in self.agent_dict.items():
            agent_class.make_rewards(agent, env_ref_stores=env_ref_stores)

    def _make_dones(self) -> None:
        """
        Ask every agent to build its done functors

        Every agent receives the world dones, the task dones configured for
        its platform, and a shared ``EpisodeLengthDone`` functor whose horizon
        is ``(horizon + sim_warmup_steps) * sim_period`` seconds.
        """
        reference_stores = [self.local_variable_store.get('reference_store', {}), self.config.reference_store]

        # the horizon is given in steps; the done works in seconds and must also cover the warmup steps
        horizon_seconds = (self.config.horizon + self.config.sim_warmup_steps) * self.sim_period
        episode_length_done = Functor(
            functor=EpisodeLengthDone, config={'horizon': {
                'value': horizon_seconds, 'units': 'second'
            }}
        )

        for agent_id, agent in self.agent_dict.items():
            platform_name = agent.platform_name
            done_functors = chain(self.config.dones.world, self.config.dones.task[platform_name], [episode_length_done])
            parameter_sources = [
                self.local_variable_store.get('world', {}),
                self.local_variable_store.get('task', {}).get(platform_name, {}),
            ]
            agent.make_dones(agent_id, platform_name, dones=done_functors, env_params=parameter_sources, env_ref_stores=reference_stores)

    def _make_shared_dones(self) -> DoneDict:  # pylint: disable=no-self-use
        """
        _make_shared_dones creates and initializes the shared done
        dict used for this iteration

        The shared dones are built from ``self.config.dones.shared`` and are
        not tied to an individual agent.

        This is called from reset, after the parameters for the new episode
        have been sampled.

        Returns
        -------
        DoneDict
            The DoneDict with functors used for this iteration
        """
        reference_sources = [self.local_variable_store.get('reference_store', {}), self.config.reference_store]
        parameter_sources = [self.local_variable_store.get('shared', {})]
        shared_done_functors = [
            functor.create_functor_object(param_sources=parameter_sources, ref_sources=reference_sources)
            for functor in self.config.dones.shared
        ]
        return DoneDict(processing_funcs=shared_done_functors)

    @staticmethod
    def _create_actions(agents, observations={}, rewards={}, dones={}, info={}):  # pylint: disable=dangerous-default-value
        for agent_name, agent_class in agents.items():
            agent_class.create_next_action(
                observations.get(agent_name), rewards.get(agent_name), dones.get(agent_name), info.get(agent_name)
            )

    def __create_space(self, space_getter) -> gym.spaces.Dict:
        """
        __create_space Create a space covering all agents by delegating to each agent

        Parameters
        ----------
        space_getter
            A function that takes a glue object and returns the space that glue contributes
            For example    space_getter=lambda glue_obj: glue_obj.observation_space()
        Returns
        -------
            A Dict space keyed by agent id; agents whose ``create_space`` result is falsy are left out
        """
        agent_spaces = gym.spaces.dict.Dict()
        for agent_id, agent in self.agent_dict.items():
            space = agent.create_space(space_getter)
            if not space:
                # this agent provided nothing for the requested space
                continue
            agent_spaces.spaces[agent_id] = space
        return agent_spaces

    @property
    def action_space(self) -> gym.spaces.Space:
        """
        action_space The normalized action space of the environment

        Returns
        -------
        gym.spaces.Space
            ``self._normalized_action_space``, which ``__init__`` builds from the
            glues' ``normalized_action_space()``
        """
        return self._normalized_action_space

    ############
    # Properties
    ############
    @property
    def glue_info(self) -> OrderedDict:
        """
        glue_info The info collected from the agents

        Returns
        -------
        OrderedDict
            ``self._info``; entries are keyed by agent id and hold the result of
            the agent's ``get_info_dict()``
        """
        return self._info

    @property
    def done_info(self) -> OrderedDict:
        """
        done_info The done information currently held by the environment

        Returns
        -------
        OrderedDict
            ``self._done_info``; cleared at the start of every ``reset``
        """
        return self._done_info

    @property
    def reward_info(self) -> OrderedDict:
        """
        reward_info The reward information currently held by the environment

        Returns
        -------
        OrderedDict
            ``self._reward_info``
        """
        return self._reward_info

    @property
    def observation_space(self) -> gym.spaces.Space:
        """
        observation_space The normalized observation space of the environment

        Returns
        -------
        gym.spaces.Space
            ``self._normalized_observation_space``, which ``__init__`` builds from the
            ``normalized_observation_space()`` of glues whose training export
            behavior is INCLUDE
        """
        return self._normalized_observation_space

    @property
    def state(self) -> StateDict:
        """
        state The current state of the environment

        Returns
        -------
        StateDict
            ``self._state``, the dict storing the current state of the environment
        """
        return self._state

    @property
    def simulator(self) -> BaseSimulator:
        """
        simulator The simulator instance

        Returns
        -------
        BaseSimulator
            ``self._simulator``, the simulator instance used by this environment
        """
        return self._simulator

    @property
    def observation(self) -> OrderedDict:
        """
        observation get the observation for the agents in this environment

        Returns
        -------
        OrderedDict
            ``self._obs_buffer.observation``, the dict holding the observations for the agents
        """
        return self._obs_buffer.observation

    @property
    def episode_id(self) -> typing.Union[int, None]:
        """
        get the current episode parameter provider episode id

        NOTE(review): ``__init__`` only annotates ``_episode_id`` without assigning
        it and ``reset`` is what sets it, so reading this before the first reset
        may raise AttributeError - confirm.

        Returns
        -------
        int or None
            the episode id
        """
        return self._episode_id

    def _get_platform_by_name(self, platform_id: str) -> BasePlatform:
        """
        Look up a simulator platform by name

        Parameters
        ----------
        platform_id : str
            Name of the platform to find in ``self._state.sim_platforms``

        Returns
        -------
        BasePlatform
            The platform with that name; when several match, the last one wins

        Raises
        ------
        ValueError
            If no platform has that name or the match is not a BasePlatform
        """
        matches = [candidate for candidate in self._state.sim_platforms if candidate.name == platform_id]
        platform: BasePlatform = matches[-1] if matches else None  # type: ignore

        if platform is not None and issubclass(platform.__class__, BasePlatform):
            return platform

        # dump the state and the known platform names to help diagnose the failed lookup
        self._logger.error("-" * 100)
        self._logger.error(f"{self._state}")
        for known_platform in self._state.sim_platforms:
            self._logger.error(f"{known_platform.name}")
        raise ValueError(f"{self.__class__.__name__} glue could not find a platform named {platform_id} of class BasePlatform")

    def __setup_state_history(self):
        self._state.episode_history = defaultdict(partial(deque, maxlen=self.config.horizon))
        self._state.episode_state = OrderedDict()
        self._state.step_state = OrderedDict()

    def post_process_trajectory(self, agent_id, batch, episode, policy):
        """Delegate trajectory post processing to the agent with the given id

        The state and reward info handed to the agent are read from
        ``episode.worker.env``, not from ``self``.

        Arguments:
            agent_id {[type]} -- agent id
            batch {[type]} -- post processed Batch - be careful modifying
            episode {[type]} -- episode object; its ``worker.env`` supplies the state and reward info
            policy {[type]} -- passed through to the agent unchanged
        """
        agent = self.agent_dict[agent_id]
        worker_env = episode.worker.env
        state = worker_env._state  # pylint: disable=protected-access
        reward_info = worker_env._reward_info  # pylint: disable=protected-access
        agent.post_process_trajectory(agent_id, state, batch, episode, policy, reward_info)

    @staticmethod
    def __sanity_check(space: gym.spaces.Space, space_sample: EnvSpaceUtil.sample_type) -> None:
        """
        Sanity checks a space_sample against a space

        Checks that the sample falls within the expected range of values of
        the space. When it does not, the detailed check
        ``EnvSpaceUtil.deep_sanity_check_space_sample`` is run to report the
        problem (callers in this class catch ValueError around this call).

        Note: space_sample and space expected to match up on
        Key level entries

        Parameters
        ----------
        space: gym.spaces.Space
            the space to check the sample against
        space_sample: EnvSpaceUtil.sample_type
            the sample to check if it is actually in the bounds of the space

        Returns
        -------
        None
        """
        if space.contains(space_sample):
            return
        # The cheap containment test failed; the deep check pinpoints the offending entry
        EnvSpaceUtil.deep_sanity_check_space_sample(space, space_sample)

    def _save_state_pickle(self, err: ValueError):
        """saves state for later debug

        Arguments:
            err {ValueError} -- Traceback for the error

        Raises:
            err: Customized error message to raise for the exception

        """
        out_pickle = str(self.config.output_path / f"sanity_check_failure_{self._episode}.pkl")
        p_dict: typing.Dict[str, typing.Any] = {}
        p_dict["err"] = str(err)

        class NumpyArrayEncoder(JSONEncoder):
            """Encode the numpy types for json
            """

            def default(self, obj):  # pylint: disable=arguments-differ
                val = None
                if isinstance(
                    obj,
                    (
                        np.int_,
                        np.intc,
                        np.intp,
                        np.int8,
                        np.int16,
                        np.int32,
                        np.int64,
                        np.uint8,
                        np.uint16,
                        np.uint32,
                        np.uint64,
                    ),
                ):
                    val = int(obj)
                elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
                    val = float(obj)
                elif isinstance(obj, (np.ndarray, )):  # This is the fix
                    val = obj.tolist()
                elif isinstance(obj, np.bool_):
                    val = 'True' if obj is True else 'False'
                else:
                    val = json.JSONEncoder.default(self, obj)

                return val

        def to_dict(input_ordered_dict):
            return loads(dumps(input_ordered_dict, cls=NumpyArrayEncoder))

        p_dict['action'] = self._actions  # type: ignore
        p_dict["observation"] = self._obs_buffer.observation  # type: ignore
        p_dict["dones"] = to_dict(self._done_info)  # type: ignore
        # p_dict["env_config"] = copy.deepcopy(self.env_config)  # type: ignore
        p_dict["step"] = str(self._episode_length)

        pickle.dump(p_dict, open(out_pickle, "wb"))

        raise ValueError(f"Error occurred: {err} \n Saving sanity check failure output pickle to file: {out_pickle}")

    def seed(self, seed=None):
        """seeds the environment random number generator through gym

        The generator is only created once: when ``self.rng`` already exists
        the ``seed`` argument is ignored and the stored seed is returned.

        Keyword Arguments:
            seed {[int]} -- seed to set environment with (default: {None})

        Returns:
            [int] -- single element list holding ``self.config.seed``
        """
        already_seeded = hasattr(self, "rng")
        if not already_seeded:
            generator, used_seed = gym.utils.seeding.np_random(seed)
            self.rng = generator
            self.config.seed = used_seed
        return [self.config.seed]

action_space: Space property readonly ¤

action_space The action space

Returns¤

typing.Dict[str,gym.spaces.tuple.Tuple] The action space

done_info: OrderedDict property readonly ¤

[summary]

Returns¤

Union[OrderedDict, None] [description]

episode_id: Optional[int] property readonly ¤

get the current episode parameter provider episode id

Returns¤

int or None the episode id

get_validator: Type[corl.environment.multi_agent_env.ACT3MultiAgentEnvValidator] property readonly ¤

Get the validator for this class.

glue_info: OrderedDict property readonly ¤

[summary]

Returns:

Type Description
OrderedDict

Union[OrderedDict, None] -- [description]

observation: OrderedDict property readonly ¤

observation get the observation for the agents in this environment

Returns¤

OrderedDict the dict holding the observations for the agents

observation_space: Space property readonly ¤

observation_space The observation space setup by the user

Returns¤

gym.spaces.dict.Dict The observation space

reward_info: OrderedDict property readonly ¤

[summary]

Returns¤

Union[OrderedDict, None] [description]

simulator: BaseSimulator property readonly ¤

simulator simulator instance

Returns¤

BaseSimulator The simulator instance in the base

state: StateDict property readonly ¤

state of platform object. Current state.

Returns¤

StateDict the dict storing the current state of the environment.

__init__(self, config) special ¤

`__init__` initializes the RLlib multi-agent environment.

Parameters¤

config : ray.rllib.env.env_context.EnvContext Passed in configuration for setting items up. Must have a 'simulator' key whose value is a BaseIntegrator type

Source code in corl/environment/multi_agent_env.py
def __init__(self, config: EnvContext) -> None:  # pylint: disable=too-many-statements, super-init-not-called
    """
    __init__ initializes the rllib multi agent environment

    Validates the configuration, seeds the random number generator, creates
    the agents and the simulator, performs an initial simulator reset, builds
    the glues and derives the observation/action spaces from them.

    Parameters
    ----------
    config : ray.rllib.env.env_context.EnvContext
        Passed in configuration for setting items up.
        Must have a 'simulator' key whose value is a BaseIntegrator type
    """
    # vars() raises TypeError for objects without a __dict__; fall back to no extra fields
    try:
        config_vars = vars(config)
    except TypeError:
        config_vars = {}
    self.config: ACT3MultiAgentEnvValidator = self.get_validator(**config, **config_vars)

    # Random numbers
    self.seed(self.config.seed)

    # setup default instance variables
    self._actions: list = []
    self._obs_buffer = ObsBuffer()
    self._reward: RewardDict = RewardDict()
    self._done: DoneDict = DoneDict()
    self._info: OrderedDict = OrderedDict()
    self._episode_length: int = 0
    self._episode: int = 0
    # annotation only - no value is assigned here; reset() assigns the id
    self._episode_id: typing.Union[int, None]

    # agent glue dict is a mapping from agent id to a dict with keys for the glue names
    # and values of the actual glue object
    self._agent_glue_dict: OrderedDict = OrderedDict()
    self._agent_glue_obs_export_behavior: OrderedDict = OrderedDict()

    # Create the logger
    self._logger = logging.getLogger(ACT3MultiAgentEnv.__name__)

    # Extra simulation init args
    # assign the new output_path with the worker index back to the config for the sim/integration output_path
    extra_sim_init_args: typing.Dict[str, typing.Any] = {
        "output_path": str(self.config.output_path),
        "worker_index": self.config.worker_index,
        "vector_index": self.config.vector_index if self.config.vector_index else 0,
    }

    self.agent_dict, extra_sim_init_args["agent_configs"] = env_creation.create_agent_sim_configs(
        self.config.agents, self.config.agent_platforms, self.config.simulator.type, self.config.platforms, self.config.epp_registry,
        multiple_workers=(self.config.num_workers > 0)
    )

    def compute_lcm(values: typing.List[fractions.Fraction]) -> float:
        """Return the reciprocal of the least common multiple of the fractions' denominators"""
        assert len(values) > 0
        lcm = values[0].denominator
        for v in values:
            lcm = lcm // math.gcd(lcm, v.denominator) * v.denominator
        return 1.0 / lcm

    # Each agent period is 1/frame_rate as a fraction whose denominator is capped at max_agent_rate
    max_rate = self.config.max_agent_rate
    self._agent_periods = {
        agent_id: fractions.Fraction(1.0 / agent.frame_rate).limit_denominator(max_rate)
        for agent_id, agent in self.agent_dict.items()
    }
    self._agent_process_time: typing.Dict[str, float] = defaultdict(lambda: sys.float_info.min)
    # The simulator period is derived from all agent periods and sets the simulator frame rate
    self.sim_period = compute_lcm(list(self._agent_periods.values()))
    extra_sim_init_args['frame_rate'] = 1.0 / self.sim_period

    # Platforms without an agent are still passed to the simulator, with an empty parts list
    for agent_name, platform in self.config.other_platforms.items():
        extra_sim_init_args["agent_configs"][agent_name] = {
            "platform_config": platform,
            "parts_list": [],
        }

    # Debug logging
    self._logger.debug(f"output_path : {self.config.output_path}")

    # Sample parameter provider
    default_parameters = self.config.epp.config.parameters
    self.local_variable_store = flatten_dict.unflatten({k: v.get_value(self.rng) for k, v in default_parameters.items()})
    for agent in self.agent_dict.values():
        agent.fill_parameters(rng=self.rng, default_parameters=True)

    # Create the simulator for this gym environment
    # ----  oddity from other simulator bases HLP
    if not hasattr(self, "_simulator"):

        class SimulatorWrapper(self.config.simulator.type):  # type: ignore
            """Wrapper that injects platforms/time into state dict"""

            def _clear_data(self) -> None:
                """Removes the injected time/platforms entries from the stored state"""
                if 'sim_time' in self._state:
                    del self._state['sim_time']
                if 'sim_platforms' in self._state:
                    del self._state['sim_platforms']

            def _inject_data(self, state: StateDict) -> StateDict:
                """Ensures that time/platforms exists in state"""
                if 'sim_time' not in state:
                    state['sim_time'] = self.sim_time

                if 'sim_platforms' not in state:
                    state['sim_platforms'] = self.platforms

                return state

            def step(self) -> StateDict:
                """Steps the simulation - injects data into StateDict"""
                self._clear_data()
                return self._inject_data(super().step())

            def reset(self, *args, **kwargs) -> StateDict:
                """Resets the simulation - injects data into StateDict"""
                self._clear_data()
                return self._inject_data(super().reset(*args, **kwargs))

        # Build from a copy of the configured factory so the config keeps the original simulator type
        simulator_factory = copy.deepcopy(self.config.simulator)
        simulator_factory.type = SimulatorWrapper  # type: ignore

        self._simulator: BaseSimulator = simulator_factory.build(**extra_sim_init_args)

    self._state, self._sim_reset_args = self._reset_simulator(extra_sim_init_args["agent_configs"])

    # Make the glue objects from the glue mapping now that we have a simulator created
    self._make_glues()

    # create dictionary to hold done history
    self.__setup_state_history()

    # Create the observation and action space now that we have the glue
    self._observation_space: gym.spaces.Dict = self.__create_space(space_getter=lambda glue_obj: glue_obj.observation_space())
    self._action_space: gym.spaces.Dict = self.__create_space(space_getter=lambda glue_obj: glue_obj.action_space())
    gym_space_sort(self._action_space)
    # Only glues marked INCLUDE contribute to the normalized (training) observation space
    self._normalized_observation_space: gym.spaces.Dict = self.__create_space(
        space_getter=lambda glue_obj: glue_obj.normalized_observation_space()
        if glue_obj.config.training_export_behavior == TrainingExportBehavior.INCLUDE else None
    )
    self._normalized_action_space: gym.spaces.Dict = self.__create_space(
        space_getter=lambda glue_obj: glue_obj.normalized_action_space()
    )
    gym_space_sort(self._normalized_action_space)
    self._observation_units = self.__create_space(
        space_getter=lambda glue_obj: glue_obj.observation_units() if hasattr(glue_obj, "observation_units") else None
    )
    self._shared_done: DoneDict = DoneDict()
    self._done_info: OrderedDict = OrderedDict()
    self._reward_info: OrderedDict = OrderedDict()

    # annotation only - no value is assigned here
    self._episode_init_params: dict

    self.done_string = ""
    self._agent_ids = set(self._action_space.spaces.keys())

    # When True, step() does not apply the incoming actions to the agents
    self._skip_action = False

create_training_observations(self, alive_agents, observations) ¤

Filters and normalizes observations (the sample of the space) using the glue normalize functions.

Parameters¤

Alive_agents

The agents that are still alive

Observations

The observations

Returns¤

OrderedDict

the filtered/normalized observation samples

Source code in corl/environment/multi_agent_env.py
def create_training_observations(self, alive_agents: typing.Iterable[str],
                                 observations: ObsBuffer) -> typing.Tuple[OrderedDict, OrderedDict]:
    """
    Filters and normalizes observations (the sample of the space) using the glue normalize functions.

    Parameters
    ----------
    alive_agents:
        The agents that are still alive
    observations:
        The observation buffer; ``observation`` is consulted first and
        ``next_observation`` is used as the fallback

    Returns
    -------
    typing.Tuple[OrderedDict, OrderedDict]:
        the normalized observation samples followed by the filtered
        (not normalized) observation samples

    Raises
    ------
    RuntimeError
        If an alive agent has no entry in either buffer slot
    """

    def pick_observation(agent_id):
        """Return the buffered observation of one agent"""
        if agent_id in observations.observation:
            return observations.observation[agent_id]
        if agent_id in observations.next_observation:
            return observations.next_observation[agent_id]
        raise RuntimeError(
            "ERROR: create_training_observations tried to retrieve obs for this training step"
            f" but {agent_id=} was not able to be found in either the current obs data or the "
            " obs from the previous timestep as a fallback"
        )

    def do_export(agent_id, obs_name, _obs):
        """Keep only observations of alive agents whose glue is exported for training"""
        if agent_id not in alive_agents:
            return False
        glue_obj = self.agent_dict[agent_id].get_glue(obs_name)
        if glue_obj is None:
            return False
        return glue_obj.config.training_export_behavior == TrainingExportBehavior.INCLUDE

    def normalize(agent_id, obs_name, obs):
        """Normalize one observation through its owning agent"""
        return self.agent_dict[agent_id].normalize_observation(obs_name, obs)

    this_steps_obs = OrderedDict()
    for agent_id in alive_agents:
        this_steps_obs[agent_id] = pick_observation(agent_id)

    filtered_observations = filter_observations(this_steps_obs, do_export)
    normalized_observations = mutate_observations(filtered_observations, normalize)  # type: ignore

    return normalized_observations, filtered_observations

post_process_trajectory(self, agent_id, batch, episode, policy) ¤

easy accessor for calling post process trajectory correctly

Source code in corl/environment/multi_agent_env.py
def post_process_trajectory(self, agent_id, batch, episode, policy):
    """Delegate trajectory post processing to the agent with the given id

    The state and reward info handed to the agent are read from
    ``episode.worker.env``, not from ``self``.

    Arguments:
        agent_id {[type]} -- agent id
        batch {[type]} -- post processed Batch - be careful modifying
        episode {[type]} -- episode object; its ``worker.env`` supplies the state and reward info
        policy {[type]} -- passed through to the agent unchanged
    """
    agent = self.agent_dict[agent_id]
    worker_env = episode.worker.env
    state = worker_env._state  # pylint: disable=protected-access
    reward_info = worker_env._reward_info  # pylint: disable=protected-access
    agent.post_process_trajectory(agent_id, state, batch, episode, policy, reward_info)

reset(self) ¤

Resets the env and returns observations from ready agents.

Returns:

Type Description

New observations for each ready agent.

Examples:

>>> from ray.rllib.env.multi_agent_env import MultiAgentEnv
>>> class MyMultiAgentEnv(MultiAgentEnv):
...     # Define your env here.
...     ...
>>> env = MyMultiAgentEnv()
>>> obs = env.reset()
>>> print(obs)
{
    "car_0": [2.4, 1.6],
    "car_1": [3.4, -3.2],
    "traffic_light_1": [0, 3, 5, 1],
}
Source code in corl/environment/multi_agent_env.py
def reset(self):
    """
    Resets the environment for a new episode

    Samples new episode parameters, rebuilds the rewards/dones, resets the
    simulator, rebuilds the glues, runs the configured warmup steps and
    checks the resulting observation against the observation space.

    Returns
    -------
    The filtered and normalized observations of all agents, as produced by
    ``create_training_observations``

    Raises
    ------
    ValueError
        If the observation is not contained in the observation space
    """

    # Sample parameter provider
    current_parameters, self._episode_id = self.config.epp.get_params(self.rng)
    self.local_variable_store = flatten_dict.unflatten({k: v.get_value(self.rng) for k, v in current_parameters.items()})
    for agent in self.agent_dict.values():
        agent.fill_parameters(self.rng)

    # 3. Reset the Done and Reward dictionaries for the next iteration
    self._make_rewards()
    self._make_dones()
    self._shared_done: DoneDict = self._make_shared_dones()

    self._reward: RewardDict = RewardDict()
    self._done: DoneDict = DoneDict()
    self._done_info.clear()

    self.set_default_done_reward()
    # 4. Reset the simulation/integration
    self._state, self._sim_reset_args = self._reset_simulator()
    self._episode_length = 0
    self._actions.clear()
    self._episode += 1

    self._agent_process_time.clear()

    #####################################################################
    # Make glue sections - Given the state of the simulation we need to
    # update the platform interfaces.
    #####################################################################
    self._make_glues()

    #####################################################################
    # get observations
    # For each configured agent read the observations/measurements
    #####################################################################
    agent_list = list(self.agent_dict.keys())
    self._obs_buffer.next_observation = self.__get_observations_from_glues(agent_list)
    self._obs_buffer.update_obs_pointer()
    # The following loop guarantees that during training the glue
    # states start with valid values for rates. The number of recommended
    # steps for sim is at least equal to the depth of the rate observation
    # tree (ex: speed - 2, acceleration - 3, jerk - 4) - recommend defaulting
    # to 4 as we do not go higher than jerk
    # 1 step is always added for the initial obs in reset
    warmup = self.config.sim_warmup_steps
    for _ in range(warmup):
        self._state = self._simulator.step()
        self._obs_buffer.next_observation = self.__get_observations_from_glues(agent_list)
        self._obs_buffer.update_obs_pointer()

    # The history containers are (re)created after the warmup steps
    self.__setup_state_history()

    for platform in self._state.sim_platforms:
        self._state.step_state[platform.name] = None
        self._state.episode_history[platform.name].clear()
        self._state.episode_state[platform.name] = OrderedDict()

    # Sanity Checks and Scale
    # The current deep sanity check will not raise an error if values from the sample differ from the space during reset
    if self.config.deep_sanity_check:
        try:
            self.__sanity_check(self._observation_space, self._obs_buffer.observation)
        except ValueError as err:
            # dumps the state to a pickle file and re-raises
            self._save_state_pickle(err)
    else:
        if not self._observation_space.contains(self._obs_buffer.observation):
            raise ValueError('obs not contained in obs space')

    self._create_actions(self.agent_dict, self._obs_buffer.observation)

    #####################################################################
    # return results to RLLIB - Note that RLLIB does not do a recursive
    # isinstance call and as such need to make sure items are
    # OrderedDicts
    #####################################################################
    trainable_observations, _ = self.create_training_observations(agent_list, self._obs_buffer)
    return trainable_observations

seed(self, seed=None) ¤

generates environment seed through rllib

Keyword arguments:

Name Type Description
seed {[int]} -- seed to set environment with (default

{None})

Returns:

Type Description

[int] -- [seed value]

Source code in corl/environment/multi_agent_env.py
def seed(self, seed=None):
    """seeds the environment random number generator through gym

    The generator is only created once: when ``self.rng`` already exists
    the ``seed`` argument is ignored and the stored seed is returned.

    Keyword Arguments:
        seed {[int]} -- seed to set environment with (default: {None})

    Returns:
        [int] -- single element list holding ``self.config.seed``
    """
    already_seeded = hasattr(self, "rng")
    if not already_seeded:
        generator, used_seed = gym.utils.seeding.np_random(seed)
        self.rng = generator
        self.config.seed = used_seed
    return [self.config.seed]

set_default_done_reward(self) ¤

Populate the done/rewards with default values

Source code in corl/environment/multi_agent_env.py
def set_default_done_reward(self):
    """
    Populate the done/rewards with default values

    Every agent starts as not done (individually and in the shared dones)
    with a reward of 0; the special "all" entry of both done dicts starts
    as False.
    """
    all_key = DoneFuncBase._ALL  # pylint: disable=protected-access
    for agent_id in self.agent_dict:
        self._done[agent_id] = False
        self._reward[agent_id] = 0  # pylint: disable=protected-access
        self._shared_done[agent_id] = False
    self._done[all_key] = False
    self._shared_done[all_key] = False

step(self, action_dict) ¤

Returns observations from ready agents.

The returns are dicts mapping from agent_id strings to values. The number of agents in the env can vary over time.

obs (StateDict): New observations for each ready agent.
rewards (StateDict): Reward values for each ready agent. If the
    episode is just started, the value will be None.
dones (StateDict): Done values for each ready agent. The
    special key "__all__" (required) is used to indicate env
    termination.
infos (StateDict): Optional info values for each agent id.
Source code in corl/environment/multi_agent_env.py
def step(self, action_dict: dict) -> typing.Tuple[OrderedDict, OrderedDict, OrderedDict, OrderedDict]:
    # pylint: disable=R0912, R0914, R0915
    """Apply the actions, step the simulator once and return results for the agents processed this step.

    The returns are dicts mapping from agent_id strings to values. The
    number of agents in the env can vary over time. Only agents selected
    for processing on this timestep appear in the returned dicts.

    Parameters
    ----------
    action_dict : dict
        Actions keyed by agent id. Ignored for control purposes when
        ``self._skip_action`` is set, but always appended to ``self._actions``.

    Returns
    -------
    obs (StateDict): New observations for each processed agent.
    rewards (StateDict): Reward values for each processed agent.
    dones (StateDict): Done values for each processed agent. The
        special key "__all__" (required) is used to indicate env
        termination.
    infos (StateDict): Info values for each processed agent; always carries
        the 'env' and 'platform_obs' entries added below.

    Raises
    ------
    RuntimeError
        If the local or shared done keys do not match the agents that were
        operable before the simulator step (plus "__all__"), or if a shared
        done shares a name with a local done.
    ValueError
        If ``deep_sanity_check`` is disabled and the observation is not
        contained in the observation space.
    """
    self._episode_length += 1

    # Agents operable BEFORE the simulator step; used for actions, dones and rewards.
    operable_agents = self._get_operable_agents()

    # look to add this bugsplat, but this check won't work for multi fps
    # if set(operable_agents.keys()) != set(action_dict.keys()):
    #     raise RuntimeError("Operable_agents and action_dict keys differ!"
    #                        f"operable={set(operable_agents.keys())} != act={set(action_dict.keys())} "
    #                        "If this happens that means either your platform is not setting non operable correctly"
    #                        " (if extra keys are in operable set) or you do not have a done condition covering "
    #                        "a condition where your platform is going non operable. (if extra keys in act)")

    if self._skip_action:
        raw_action_dict = {}
    else:
        raw_action_dict = self.__apply_action(operable_agents, action_dict)

    # Save current action for future debugging
    self._actions.append(action_dict)

    try:
        self._state = self._simulator.step()
    except ValueError as err:
        # NOTE(review): presumably _save_state_pickle persists debugging state and re-raises;
        # if it returns normally, self._state keeps its pre-step value - confirm.
        self._save_state_pickle(err)

    # MTB - Changing to not replace operable_agents variable
    #       We calculate observations on agents operable after sim step
    #       - This is done because otherwise observations would be invalid
    #       Calculate Dones/Rewards on agents operable before sim step
    #       - This is done because if an agent "dies" it needs to have a final done calculated
    operable_agents_after_step = self._get_operable_agents()

    #####################################################################
    # get next observations - For each configured platform read the
    # observations/measurements
    #####################################################################
    self._obs_buffer.next_observation = self.__get_observations_from_glues(operable_agents_after_step.keys())

    self._info.clear()
    self.__get_info_from_glue(operable_agents_after_step.keys())

    #####################################################################
    # Process the done conditions
    # 1. Evaluate the per-agent (local) dones for agents operable before
    #    the simulator step
    # 2. Evaluate the shared dones and merge them into the local dones
    #####################################################################

    agents_done = self.__get_done_from_agents(operable_agents.keys(), raw_action_dict=raw_action_dict)

    # Every pre-step operable agent, plus the aggregate key, must have a done entry.
    expected_done_keys = set(operable_agents.keys())
    expected_done_keys.add('__all__')
    if set(agents_done.keys()) != expected_done_keys:
        raise RuntimeError(
            f'Local dones do not match expected keys.  Received "{agents_done.keys()}".  Expected "{expected_done_keys}".'
        )

    # compute if done all (first pass, from local dones only, so the shared dones see it)
    if not agents_done['__all__']:
        agent_dones = [v for k, v in agents_done.items() if k != '__all__']
        if self.config.end_episode_on_first_agent_done:
            agents_done['__all__'] = any(agent_dones)
        else:
            agents_done['__all__'] = all(agent_dones)

    # Shared dones receive deep copies of the local results so they cannot mutate them.
    shared_dones, shared_done_info = self._shared_done(
        observation=self._obs_buffer.observation,
        action=raw_action_dict,
        next_observation=self._obs_buffer.next_observation,
        next_state=self._state,
        observation_space=self._observation_space,
        observation_units=self._observation_units,
        local_dones=copy.deepcopy(agents_done),
        local_done_info=copy.deepcopy(self._done_info)
    )

    # An empty shared_dones is accepted as-is; a non-empty one must cover every expected key.
    if shared_dones.keys():
        if set(shared_dones.keys()) != expected_done_keys:
            raise RuntimeError(
                f'Shared dones do not match expected keys.  Received "{shared_dones.keys()}".  Expected "{expected_done_keys}".'
            )
        for key in expected_done_keys:
            agents_done[key] |= shared_dones[key]

    assert shared_done_info is not None

    # Shared done names may not collide with the keys already present in the local done info.
    local_done_info_keys = set(self._done_info.keys())
    shared_done_info_keys = set(shared_done_info)
    common_keys = local_done_info_keys & shared_done_info_keys
    if common_keys:
        raise RuntimeError(f'Dones have common names: "{common_keys}"')

    # Record each shared done status under the agent, keyed by done name and then platform name.
    for done_name, done_keys in shared_done_info.items():
        for agent_name, done_status in done_keys.items():
            self._done_info[agent_name][done_name] = OrderedDict([(self.agent_dict[agent_name].platform_name, done_status)])

    # compute if done all (second pass, now that shared dones have been merged in)
    if not agents_done['__all__']:
        agent_dones = [v for k, v in agents_done.items() if k != '__all__']
        if self.config.end_episode_on_first_agent_done:
            agents_done['__all__'] = any(agent_dones)
        else:
            agents_done['__all__'] = all(agent_dones)

    # Tell the simulator to mark the episode complete
    if agents_done['__all__']:
        self._simulator.mark_episode_done(self._done_info, self._state.episode_state)

    self._reward.reset()

    # When the episode ends every pre-step operable agent is processed; otherwise only
    # the agents whose own period has elapsed since they were last processed.
    if agents_done['__all__']:
        agents_to_process_this_timestep = list(operable_agents.keys())
    else:

        def do_process_agent(self, agent_id) -> bool:
            # Despite its name, `frame_rate` is used as a time increment: it is added to the
            # agent's last process time and compared against sim_time (within timestep_epsilon).
            frame_rate = self._agent_periods[agent_id].numerator / self._agent_periods[agent_id].denominator
            return self._state.sim_time >= self._agent_process_time[agent_id] + frame_rate - self.config.timestep_epsilon

        agents_to_process_this_timestep = list(filter(partial(do_process_agent, self), operable_agents.keys()))

    # Remember when each selected agent was last processed.
    for agent_id in agents_to_process_this_timestep:
        self._agent_process_time[agent_id] = self._state.sim_time

    reward = self.__get_reward_from_agents(agents_to_process_this_timestep, raw_action_dict=raw_action_dict)

    self._simulator.save_episode_information(self.done_info, self.reward_info, self._obs_buffer.observation)
    # copy over observation from next to previous - There is no real reason to deep
    # copy here. The process of getting a new observation from the glue copies. All
    # we need to do is maintain the order of two buffers!!!.
    # Tested with: They are different and decreasing as expected
    #   print(f"C: {self._obs_buffer.observation['blue0']['ObserveSensor_Sensor_Fuel']}")
    #   print(f"N: {self._obs_buffer.next_observation['blue0']['ObserveSensor_Sensor_Fuel']}")
    self._obs_buffer.update_obs_pointer()
    # Sanity checks and Scale - ensure run first time and run only every N times...
    # Same as RLLIB - This can add a bit of time as we are exploring complex dictionaries
    # default to every time if not specified... Once the limits are good it is
    # recommended to increase this for training

    if self.config.deep_sanity_check:
        # Deep check only every `sanity_check_obs` steps.
        if self._episode_length % self.config.sanity_check_obs == 0:
            try:
                self.__sanity_check(self._observation_space, self._obs_buffer.observation)
            except ValueError as err:
                self._save_state_pickle(err)
    else:
        # Cheap containment check on every step.
        if not self._observation_space.contains(self._obs_buffer.observation):
            raise ValueError('obs not contained in obs space')

    # Build observations for all pre-step operable agents, then keep only the processed ones.
    complete_trainable_observations, complete_unnormalized_observations = self.create_training_observations(
        operable_agents, self._obs_buffer
    )
    trainable_observations = OrderedDict()
    for agent_id in agents_to_process_this_timestep:
        trainable_observations[agent_id] = complete_trainable_observations[agent_id]

    trainable_rewards = get_dictionary_subset(reward, agents_to_process_this_timestep)
    trainable_dones = get_dictionary_subset(agents_done, ["__all__"] + agents_to_process_this_timestep)
    trainable_info = get_dictionary_subset(self._info, agents_to_process_this_timestep)

    # add platform obs and env data to trainable_info (for use by custom policies)
    for agent_id in agents_to_process_this_timestep:
        if agent_id not in trainable_info:
            trainable_info[agent_id] = {}

        trainable_info[agent_id]['env'] = {'sim_period': self.sim_period}
        trainable_info[agent_id]['platform_obs'] = {}

        # Attach the unnormalized observations of every agent that shares this agent's platform.
        plat_name = self.agent_dict[agent_id].platform_name
        for platform_agent in self.agent_dict:
            if self.agent_dict[platform_agent].platform_name == plat_name:
                trainable_info[agent_id]['platform_obs'][platform_agent] = complete_unnormalized_observations[platform_agent]

    # if not done all, delete any platforms from simulation that are done, so they don't interfere
    platforms_deleted = set()
    if not agents_done['__all__']:
        for agent_key, value in agents_done.items():
            if agent_key != '__all__' and value:
                plat_name = self.agent_dict[agent_key].platform_name
                # Several agents may share a platform; delete it only once.
                if plat_name not in platforms_deleted:
                    self.simulator.delete_platform(plat_name)
                    platforms_deleted.add(plat_name)

    # if a platform has been deleted, we need to make sure that all agents on that platform
    # are also done
    for agent_key in trainable_dones.keys():
        if agent_key == '__all__':
            continue
        plat_name = self.agent_dict[agent_key].platform_name
        if plat_name in platforms_deleted:
            trainable_dones[agent_key] = True

    #####################################################################
    # return results to RLLIB - Note that RLLIB does not do a recursive
    # isinstance call and as such need to make sure items are
    # OrderedDicts
    #####################################################################
    return trainable_observations, trainable_rewards, trainable_dones, trainable_info

ACT3MultiAgentEnvEppParameters (BaseModel) pydantic-model ¤

world: typing.Dict[str, typing.Dict[str, typing.Any]] = {}

keys: done name, parameter name

task: typing.Dict[str, typing.Dict[str, typing.Dict[str, typing.Any]]] = {}

keys: agent name, done name, parameter name

shared: typing.Dict[str, typing.Dict[str, typing.Any]] = {}

keys: done name, parameter name

reference_store: typing.Dict[str, typing.Any] = {}

keys: reference name

simulator_reset: typing.Dict[str, typing.Any] = {}

keys: whatever the simulator wants, but it needs to be kwargs to simulator reset

Source code in corl/environment/multi_agent_env.py
class ACT3MultiAgentEnvEppParameters(BaseModel):
    """
    Nested parameter collections for the environment; every leaf must be a Parameter.

    world: typing.Dict[str, typing.Dict[str, typing.Any]] = {}
        keys: done name, parameter name

    task: typing.Dict[str, typing.Dict[str, typing.Dict[str, typing.Any]]] = {}
        keys: agent name, done name, parameter name

    shared: typing.Dict[str, typing.Dict[str, typing.Any]] = {}
        keys: done name, parameter name

    reference_store: typing.Dict[str, typing.Any] = {}
        keys: reference name

    simulator_reset: typing.Dict[str, typing.Any] = {}
        keys: whatever the simulator wants, but it needs to be kwargs to simulator reset
    """
    world: typing.Dict[str, typing.Dict[str, typing.Any]] = {}
    task: typing.Dict[str, typing.Dict[str, typing.Dict[str, typing.Any]]] = {}
    shared: typing.Dict[str, typing.Dict[str, typing.Any]] = {}
    reference_store: typing.Dict[str, typing.Any] = {}
    simulator_reset: typing.Dict[str, typing.Any] = {}

    @staticmethod
    def _validate_leaves_are_parameters(obj):
        """Walk nested dicts depth-first and raise TypeError on the first non-Parameter leaf."""
        if not isinstance(obj, dict):
            # Leaf node: it is either a Parameter or an error.
            if isinstance(obj, Parameter):
                return
            raise TypeError(f"Invalid type: {type(obj)} (required type: {Parameter.__qualname__})")

        for child in obj.values():
            ACT3MultiAgentEnvEppParameters._validate_leaves_are_parameters(child)

    @validator('world', 'task', 'shared', 'reference_store', 'simulator_reset')
    def validate_leaves_are_parameters(cls, v):
        """
        Field validator: every leaf of each nested mapping must be a Parameter.
        The value itself is returned unchanged.
        """
        ACT3MultiAgentEnvEppParameters._validate_leaves_are_parameters(v)
        return v

validate_leaves_are_parameters(v) classmethod ¤

checks to make sure outer most leaf nodes of config are parameters

Source code in corl/environment/multi_agent_env.py
@validator('world', 'task', 'shared', 'reference_store', 'simulator_reset')
def validate_leaves_are_parameters(cls, v):
    """
    checks to make sure outer most leaf nodes of config are parameters

    Delegates to ``ACT3MultiAgentEnvEppParameters._validate_leaves_are_parameters``
    and returns ``v`` unchanged when that call does not raise.
    """
    ACT3MultiAgentEnvEppParameters._validate_leaves_are_parameters(v)
    return v

ACT3MultiAgentEnvValidator (BaseModel) pydantic-model ¤

Validation model for the inputs of ACT3MultiAgentEnv

Source code in corl/environment/multi_agent_env.py
class ACT3MultiAgentEnvValidator(BaseModel):
    """Validation model for the inputs of ACT3MultiAgentEnv

    Several validators below read earlier fields through ``values`` (for
    example ``output_path`` reads ``TrialName`` and ``worker_index``), so the
    order in which the fields are declared is significant and should not be
    changed casually.
    """
    num_workers: NonNegativeInt = 0
    worker_index: NonNegativeInt = 0
    vector_index: typing.Optional[NonNegativeInt] = None
    remote: bool = False
    deep_sanity_check: bool = True

    # NOTE(review): the default 0 is not itself a PositiveInt, and get_seed is not declared
    # with always=True - confirm how an omitted or zero seed is expected to behave.
    seed: PositiveInt = 0
    horizon: PositiveInt = 1000
    sanity_check_obs: PositiveInt = 50
    sensors_grid: typing.Optional[typing.List]
    plugin_paths: typing.List[str] = []

    # Regex allows letters, numbers, underscore, dash, dot
    # Regex in output_path validator also allows forward slash
    # Empty string is not allowed
    TrialName: typing.Optional[Annotated[str, Field(regex=r'^[\w\.-]+$')]] = None
    output_date_string: typing.Optional[Annotated[str, Field(regex=r'^[\w\.-]+$')]] = None
    skip_pbs_date_update: bool = False
    # MyPy error ignored because it is handled by the pre-validator
    output_path: DirectoryPath = None  # type: ignore[assignment]

    agent_platforms: typing.Dict
    agents: typing.Dict[str, AgentParseInfo]

    simulator: Factory
    platforms: typing.Type[BaseAvailablePlatformTypes]
    other_platforms: typing.Dict[str, typing.Dict[str, typing.Any]] = {}

    reference_store: typing.Dict[str, ObjectStoreElem] = {}
    dones: EnvironmentDoneValidator = EnvironmentDoneValidator()
    end_episode_on_first_agent_done: bool = False
    simulator_reset_parameters: typing.Dict[str, typing.Any] = {}

    episode_parameter_provider: Factory
    # Both of the following are filled in by their always/pre validators below.
    episode_parameter_provider_parameters: ACT3MultiAgentEnvEppParameters = None  # type: ignore
    epp_registry: typing.Dict[str, EpisodeParameterProvider] = None  # type: ignore

    max_agent_rate: int = 20  # the maximum rate (in Hz) that an agent may be run at
    timestep_epsilon: float = 1e-3
    sim_warmup_steps: int = 0  # number of times to step simulator before getting initial obs

    @property
    def epp(self) -> EpisodeParameterProvider:
        """
        return the current episode parameter provider

        This is the registry entry stored under
        ``ACT3MultiAgentEnv.episode_parameter_provider_name``.
        """
        return self.epp_registry[ACT3MultiAgentEnv.episode_parameter_provider_name]  # pylint: disable=unsubscriptable-object

    class Config:
        """Allow arbitrary types for Parameter"""
        arbitrary_types_allowed = True

    @validator('seed', pre=True)
    def get_seed(cls, v):
        """Compute a valid seed by passing the raw value through gym's np_random and keeping the seed it returns"""
        _, seed = gym.utils.seeding.np_random(v)
        return seed

    @validator('plugin_paths')
    def add_plugin_paths(cls, v):
        """Use the plugin path attribute to initialize the plugin library."""
        # Side effect: registers the paths with PluginLibrary during validation.
        PluginLibrary.add_paths(v)
        return v

    @validator('output_path', pre=True, always=True)
    def create_output_path(cls, v, values):
        """Build the output path and create the directory.

        The path is assembled as
        ``<base>/[<prefix>-]<TrialName>/<output_date_string>/<worker_index>/<vector_index>``
        where the optional parts are skipped when their value is ``None`` and
        the two indices are zero-padded to four digits.
        """

        v = v or 'data/act3/ray_results'
        # The base path may additionally contain forward slashes.
        v = parse_obj_as(Annotated[str, Field(regex=r'^[\w/\.-]+$')], v)

        if values['TrialName'] is not None:
            if values["skip_pbs_date_update"]:
                trial_prefix = ''
            else:
                # PBS_JOBID takes precedence over TRIAL_NAME_PREFIX.
                trial_prefix = os.environ.get('PBS_JOBID', os.environ.get('TRIAL_NAME_PREFIX', ''))

            if trial_prefix:
                v = os.path.join(v, f'{trial_prefix}-{values["TrialName"]}')
            else:
                v = os.path.join(v, values['TrialName'])

        if values['output_date_string'] is not None and not values["skip_pbs_date_update"]:
            v = os.path.join(v, values['output_date_string'])

        v = os.path.abspath(v)
        v = os.path.join(v, str(values['worker_index']).zfill(4))
        if values['vector_index'] is not None:
            v = os.path.join(v, str(values['vector_index']).zfill(4))
        # Created here so that the DirectoryPath field type sees an existing directory.
        os.makedirs(v, exist_ok=True)

        return v

    @validator('simulator', pre=True)
    def resolve_simulator_plugin(cls, v):
        """Determine the simulator from the plugin library.

        Raises TypeError when the matched plugin does not subclass BaseSimulator.
        """
        try:
            v['type']
        except (TypeError, KeyError):
            # Let pydantic print out an error when there is no type field
            return v

        match = PluginLibrary.FindMatch(v['type'], {})
        if not issubclass(match, BaseSimulator):
            raise TypeError(f"Simulator must subclass BaseSimulator, but is of type {v['type']}")

        return {'type': match, 'config': v.get('config')}

    @validator('platforms', pre=True)
    def resolve_platforms(cls, v, values):
        """Determine the platforms from the plugin library.

        Non-string values are returned unchanged; strings are looked up with
        the already-validated simulator type as the match qualifier.
        """

        if not isinstance(v, str):
            return v
        return PluginLibrary.FindMatch(v, {'simulator': values['simulator'].type})

    @validator('agents')
    def agents_not_empty(cls, v, values):
        """Ensure that at least one agent exists and that each agent names a known platform"""
        if len(v) == 0:
            raise RuntimeError('No agents exist')

        for agent_name, agent in v.items():
            assert agent.platform_name in values['agent_platforms'], f"missing platform '{agent.platform_name}' for agent '{agent_name}'"
            # Agent ids must start with "<platform_name>_".
            assert agent.platform_name == agent_name.split("_")[0], f"invalid platform name {agent.platform_name} for agent {agent_name}"
        return v

    resolve_reference_store_factory = validator('reference_store', pre=True, each_item=True, allow_reuse=True)(Factory.resolve_factory)

    @validator('dones', always=True)
    def agents_match(cls, v, values):
        """Ensure that platform in task dones match provided platforms"""
        # No extra agents in task dones
        for platform in v.task.keys():
            if platform not in values['agent_platforms']:
                raise RuntimeError(f'Platform {platform} lists a done condition but is not an allowed platform')

        # Task dones exist for all agents.  Make empty ones if necessary
        # NOTE(review): the filler is an empty dict although EnvironmentDoneValidator.task
        # declares list values; both iterate as empty - confirm before changing.
        for platform in values['agent_platforms']:
            if platform not in v.task:
                v.task[platform] = {}

        return v

    @validator('simulator_reset_parameters', pre=True)
    def update_units_and_parameters(cls, v):
        """Update simulation reset parameters to meet base simulator requirements."""
        return validation_helper_units_and_parameters(v)

    @validator('episode_parameter_provider_parameters', always=True, pre=True)
    def build_episode_parameter_provider_parameters(cls, _v, values) -> ACT3MultiAgentEnvEppParameters:
        """Create the episode parameter provider parameters for this configuration

        Any user supplied value (``_v``) is ignored; the result is always
        derived from ``reference_store``, ``dones`` and
        ``simulator_reset_parameters``.
        """

        # These fields are absent from `values` when their own validation failed.
        for key in ['reference_store', 'dones', 'simulator_reset_parameters']:
            assert key in values

        # Only reference store entries that are Parameters are kept.
        reference_parameters: typing.Dict[str, Parameter] = {}
        for ref_name, ref_value in values['reference_store'].items():
            if isinstance(ref_value, Parameter):
                reference_parameters[ref_name] = ref_value

        world_parameters: typing.Dict[str, typing.Dict[str, Parameter]] = {}
        for functor in values['dones'].world:
            functor.add_to_parameter_store(world_parameters)

        task_parameters: typing.Dict[str, typing.Dict[str, typing.Dict[str, Parameter]]] = {}
        for agent, task_dones in values['dones'].task.items():
            agent_task_parameters: typing.Dict[str, typing.Dict[str, Parameter]] = {}
            for functor in task_dones:
                functor.add_to_parameter_store(agent_task_parameters)
            task_parameters[agent] = agent_task_parameters

        shared_parameters: typing.Dict[str, typing.Dict[str, Parameter]] = {}
        for functor in values['dones'].shared:
            functor.add_to_parameter_store(shared_parameters)

        # Flatten, keep only the Parameter leaves, then restore the nesting.
        sim_parameters_flat = {
            name: param
            for name, param in flatten_dict.flatten(values['simulator_reset_parameters']).items()
            if isinstance(param, Parameter)
        }
        sim_parameters = flatten_dict.unflatten(sim_parameters_flat)

        return ACT3MultiAgentEnvEppParameters(
            world=world_parameters,
            task=task_parameters,
            shared=shared_parameters,
            reference_store=reference_parameters,
            simulator_reset=sim_parameters
        )

    @validator('epp_registry', always=True, pre=True)
    def construct_epp_registry_if_necessary_and_validate(cls, epp_registry, values):
        """
        validates the Episode Parameter provider registry

        When no registry is supplied one is built: the environment provider is
        built from the flattened episode parameter provider parameters, and
        each agent contributes the provider found on its own config. A
        supplied registry is only checked for completeness and entry types.
        """
        if epp_registry is None:
            epp_registry = {}
            env_epp_parameters = dict(values['episode_parameter_provider_parameters'])
            flat_env_epp_parameters = flatten_dict.flatten(env_epp_parameters)
            env_epp = values['episode_parameter_provider'].build(parameters=flat_env_epp_parameters)
            epp_registry[ACT3MultiAgentEnv.episode_parameter_provider_name] = env_epp

            # An agent object is instantiated here only to read its provider.
            for agent_id, agent_info in values['agents'].items():
                agent = agent_info.class_config.agent(
                    agent_name=agent_id, platform_name=agent_info.platform_name, **agent_info.class_config.config
                )
                epp_registry[agent_id] = agent.config.epp

        if ACT3MultiAgentEnv.episode_parameter_provider_name not in epp_registry:
            raise ValueError(f"Missing EPP for '{ACT3MultiAgentEnv.episode_parameter_provider_name}'")

        for agent_id in values['agents']:
            if agent_id not in epp_registry:
                raise ValueError(f"Missing EPP for '{agent_id}'")

        for key, epp in epp_registry.items():
            if not isinstance(epp, EpisodeParameterProvider):
                raise TypeError(
                    f"Invalid type for epp_registry['{key}']: {type(epp)}, only {EpisodeParameterProvider.__qualname__} allowed"
                )

        return epp_registry

epp: EpisodeParameterProvider property readonly ¤

return the current episode parameter provider

Config ¤

Allow arbitrary types for Parameter

Source code in corl/environment/multi_agent_env.py
class Config:
    """Allow arbitrary types for Parameter"""
    # Lets the model declare fields whose types pydantic has no built-in validator for.
    arbitrary_types_allowed = True

add_plugin_paths(v) classmethod ¤

Use the plugin path attribute to initialize the plugin library.

Source code in corl/environment/multi_agent_env.py
@validator('plugin_paths')
def add_plugin_paths(cls, v):
    """Use the plugin path attribute to initialize the plugin library.

    The paths are handed to ``PluginLibrary.add_paths`` as a side effect of
    validation; the value itself is returned unchanged.
    """
    PluginLibrary.add_paths(v)
    return v

agents_match(v, values) classmethod ¤

Ensure that platform in task dones match provided platforms

Source code in corl/environment/multi_agent_env.py
@validator('dones', always=True)
def agents_match(cls, v, values):
    """Check task dones against the configured platforms and fill in the gaps.

    Raises RuntimeError for a task done keyed by a platform that is not in
    ``agent_platforms``; afterwards every configured platform is guaranteed
    to have an entry in ``v.task``.
    """
    allowed_platforms = values['agent_platforms']

    # Reject task dones that reference an unknown platform.
    for platform in v.task:
        if platform in allowed_platforms:
            continue
        raise RuntimeError(f'Platform {platform} lists a done condition but is not an allowed platform')

    # Platforms without any task done get an empty placeholder.
    # NOTE(review): the placeholder is a dict although the task field declares list values - confirm.
    for platform in allowed_platforms:
        v.task.setdefault(platform, {})

    return v

agents_not_empty(v, values) classmethod ¤

Ensure that at least one agent exists

Source code in corl/environment/multi_agent_env.py
@validator('agents')
def agents_not_empty(cls, v, values):
    """Ensure that at least one agent exists"""
    if len(v) == 0:
        raise RuntimeError('No agents exist')

    for agent_name, agent in v.items():
        assert agent.platform_name in values['agent_platforms'], f"missing platform '{agent.platform_name}' for agent '{agent_name}'"
        assert agent.platform_name == agent_name.split("_")[0], f"invalid platform name {agent.platform_name} for agent {agent_name}"
    return v

build_episode_parameter_provider_parameters(_v, values) classmethod ¤

Create the episode parameter provider for this configuration

Source code in corl/environment/multi_agent_env.py
@validator('episode_parameter_provider_parameters', always=True, pre=True)
def build_episode_parameter_provider_parameters(cls, _v, values) -> ACT3MultiAgentEnvEppParameters:
    """Collect every Parameter of the configuration into one ACT3MultiAgentEnvEppParameters.

    The incoming value ``_v`` is ignored. Parameters are gathered from the
    reference store, from the world/task/shared done functors, and from the
    simulator reset parameters.
    """
    required_fields = ('reference_store', 'dones', 'simulator_reset_parameters')
    for key in required_fields:
        assert key in values

    # Reference store: keep only the entries that are Parameters.
    reference_parameters: typing.Dict[str, Parameter] = {
        ref_name: ref_value
        for ref_name, ref_value in values['reference_store'].items() if isinstance(ref_value, Parameter)
    }

    # World dones register their parameters into one shared store.
    world_parameters: typing.Dict[str, typing.Dict[str, Parameter]] = {}
    for world_functor in values['dones'].world:
        world_functor.add_to_parameter_store(world_parameters)

    # Task dones get one store per key of the task mapping.
    task_parameters: typing.Dict[str, typing.Dict[str, typing.Dict[str, Parameter]]] = {}
    for agent, task_dones in values['dones'].task.items():
        per_agent_store: typing.Dict[str, typing.Dict[str, Parameter]] = {}
        for task_functor in task_dones:
            task_functor.add_to_parameter_store(per_agent_store)
        task_parameters[agent] = per_agent_store

    shared_parameters: typing.Dict[str, typing.Dict[str, Parameter]] = {}
    for shared_functor in values['dones'].shared:
        shared_functor.add_to_parameter_store(shared_parameters)

    # Simulator reset: flatten, drop non-Parameter leaves, rebuild the nesting.
    flat_reset = flatten_dict.flatten(values['simulator_reset_parameters'])
    flat_reset_parameters = {}
    for name, param in flat_reset.items():
        if isinstance(param, Parameter):
            flat_reset_parameters[name] = param
    sim_parameters = flatten_dict.unflatten(flat_reset_parameters)

    return ACT3MultiAgentEnvEppParameters(
        world=world_parameters,
        task=task_parameters,
        shared=shared_parameters,
        reference_store=reference_parameters,
        simulator_reset=sim_parameters
    )

construct_epp_registry_if_necessary_and_validate(epp_registry, values) classmethod ¤

validates the Episode Parameter provider registry

Source code in corl/environment/multi_agent_env.py
@validator('epp_registry', always=True, pre=True)
def construct_epp_registry_if_necessary_and_validate(cls, epp_registry, values):
    """
    Build the episode parameter provider registry when none is given, then validate it.

    A valid registry has an entry for the environment name and for every
    agent id, and every entry is an EpisodeParameterProvider.
    """
    if epp_registry is None:
        # Environment provider: built from the flattened environment parameters.
        flat_env_epp_parameters = flatten_dict.flatten(dict(values['episode_parameter_provider_parameters']))
        epp_registry = {
            ACT3MultiAgentEnv.episode_parameter_provider_name:
            values['episode_parameter_provider'].build(parameters=flat_env_epp_parameters)
        }

        # Agent providers: instantiate each agent and read the provider from its config.
        for agent_id, agent_info in values['agents'].items():
            agent_class = agent_info.class_config.agent
            built_agent = agent_class(agent_name=agent_id, platform_name=agent_info.platform_name, **agent_info.class_config.config)
            epp_registry[agent_id] = built_agent.config.epp

    if ACT3MultiAgentEnv.episode_parameter_provider_name not in epp_registry:
        raise ValueError(f"Missing EPP for '{ACT3MultiAgentEnv.episode_parameter_provider_name}'")

    for agent_id in values['agents']:
        if agent_id in epp_registry:
            continue
        raise ValueError(f"Missing EPP for '{agent_id}'")

    for key, epp in epp_registry.items():
        if isinstance(epp, EpisodeParameterProvider):
            continue
        raise TypeError(
            f"Invalid type for epp_registry['{key}']: {type(epp)}, only {EpisodeParameterProvider.__qualname__} allowed"
        )

    return epp_registry

create_output_path(v, values) classmethod ¤

Build the output path.

Source code in corl/environment/multi_agent_env.py
@validator('output_path', pre=True, always=True)
def create_output_path(cls, v, values):
    """Assemble the per-worker output directory, create it and return its absolute path.

    Layout: ``<base>/[<prefix>-]<TrialName>/<output_date_string>/<worker_index>/<vector_index>``.
    Optional parts are left out when their value is ``None``; the prefix and
    the date string are also left out when ``skip_pbs_date_update`` is set.
    """
    base = v or 'data/act3/ray_results'
    # Same character set as TrialName, plus the forward slash.
    path = parse_obj_as(Annotated[str, Field(regex=r'^[\w/\.-]+$')], base)

    trial_name = values['TrialName']
    if trial_name is not None:
        # PBS_JOBID wins over TRIAL_NAME_PREFIX; no prefix at all when updates are skipped.
        if not values["skip_pbs_date_update"]:
            trial_prefix = os.environ.get('PBS_JOBID', os.environ.get('TRIAL_NAME_PREFIX', ''))
        else:
            trial_prefix = ''

        trial_dir = f'{trial_prefix}-{trial_name}' if trial_prefix else trial_name
        path = os.path.join(path, trial_dir)

    date_string = values['output_date_string']
    if date_string is not None and not values["skip_pbs_date_update"]:
        path = os.path.join(path, date_string)

    # Worker and vector indices become zero-padded, four-character directories.
    path = os.path.join(os.path.abspath(path), str(values['worker_index']).zfill(4))
    vector_index = values['vector_index']
    if vector_index is not None:
        path = os.path.join(path, str(vector_index).zfill(4))

    os.makedirs(path, exist_ok=True)
    return path

get_seed(v) classmethod ¤

Compute a valid seed

Source code in corl/environment/multi_agent_env.py
@validator('seed', pre=True)
def get_seed(cls, v):
    """Compute a valid seed

    The raw value is passed to ``gym.utils.seeding.np_random``; the generator
    it returns is discarded and only the returned seed is kept.
    """
    _, seed = gym.utils.seeding.np_random(v)
    return seed

resolve_platforms(v, values) classmethod ¤

Determine the platforms from the plugin library.

Source code in corl/environment/multi_agent_env.py
@validator('platforms', pre=True)
def resolve_platforms(cls, v, values):
    """Determine the platforms from the plugin library.

    Values that are not strings are returned unchanged. A string is looked
    up through ``PluginLibrary.FindMatch`` with the type of the
    ``simulator`` field passed under the 'simulator' key.
    """

    if not isinstance(v, str):
        return v
    return PluginLibrary.FindMatch(v, {'simulator': values['simulator'].type})

resolve_reference_store_factory(v) classmethod ¤

Validator for converting a factory into the built object.

Usage in a pydantic model: resolve_factory = validator('name', pre=True, allow_reuse=True)(Factory.resolve_factory)

Source code in corl/environment/multi_agent_env.py
@classmethod
def resolve_factory(cls, v):
    """Build ``v`` through the factory when it looks like a factory description.

    A value that can be indexed with ``'type'`` is expanded into ``cls(**v)``
    and the built object is returned; anything else is passed through untouched.

    Usage in a pydantic model:
    resolve_factory = validator('name', pre=True, allow_reuse=True)(Factory.resolve_factory)
    """
    try:
        v['type']
    except (TypeError, KeyError):
        # Not something that should be built with the factory
        return v

    return cls(**v).build()

resolve_simulator_plugin(v) classmethod ¤

Determine the simulator from the plugin library.

Source code in corl/environment/multi_agent_env.py
@validator('simulator', pre=True)
def resolve_simulator_plugin(cls, v):
    """Determine the simulator from the plugin library.

    A value that cannot be indexed with ``'type'`` is returned unchanged.
    Otherwise the type is looked up through ``PluginLibrary.FindMatch`` and a
    dict with the matched class under 'type' and the original 'config' entry
    (``None`` when absent) is returned.

    Raises:
        TypeError -- the matched plugin does not subclass BaseSimulator
    """
    try:
        v['type']
    except (TypeError, KeyError):
        # Let pydantic print out an error when there is no type field
        return v

    match = PluginLibrary.FindMatch(v['type'], {})
    if not issubclass(match, BaseSimulator):
        raise TypeError(f"Simulator must subclass BaseSimulator, but is of type {v['type']}")

    return {'type': match, 'config': v.get('config')}

update_units_and_parameters(v) classmethod ¤

Update simulation reset parameters to meet base simulator requirements.

Source code in corl/environment/multi_agent_env.py
@validator('simulator_reset_parameters', pre=True)
def update_units_and_parameters(cls, v):
    """Update simulation reset parameters to meet base simulator requirements.

    The conversion itself is delegated to ``validation_helper_units_and_parameters``.
    """
    return validation_helper_units_and_parameters(v)

EnvironmentDoneValidator (BaseModel) pydantic-model ¤

Validation model for the dones of ACT3MultiAgentEnv

Source code in corl/environment/multi_agent_env.py
class EnvironmentDoneValidator(BaseModel):
    """Validation model for the done conditions configured on ACT3MultiAgentEnv.

    Every functor listed under ``world`` and ``task`` must wrap a DoneFuncBase subclass
    (and must not be EpisodeLengthDone); every functor under ``shared`` must wrap a
    SharedDoneFuncBase subclass.
    """
    # Done functors checked with check_done, one item at a time.
    world: typing.List[Functor] = []
    # Mapping of a string key to a list of done functors; each list is checked with check_done.
    task: typing.Dict[str, typing.List[Functor]] = {}
    # Done functors that must subclass SharedDoneFuncBase.
    shared: typing.List[Functor] = []

    @classmethod
    def check_done(cls, v) -> None:
        """Verify that a single done functor wraps an allowed DoneFuncBase subclass.

        Raises
        ------
        TypeError
            If ``v.functor`` does not subclass DoneFuncBase.
        ValueError
            If ``v.functor`` subclasses EpisodeLengthDone.
        """
        functor_type = v.functor
        if not issubclass(functor_type, DoneFuncBase):
            raise TypeError(f"Done functors must subclass DoneFuncBase, but done {v.name} is of type {v.functor}")
        if issubclass(functor_type, EpisodeLengthDone):
            raise ValueError("Cannot specify EpisodeLengthDone as it is automatically added")

    @validator('world', each_item=True)
    def check_world(cls, v):
        """Run check_done on one world done functor and hand it back unchanged."""
        cls.check_done(v)
        return v

    @validator('task', each_item=True)
    def check_task(cls, v):
        """Run check_done on every functor in ``v`` and hand ``v`` back unchanged."""
        for done_functor in v:
            cls.check_done(done_functor)
        return v

    @validator('shared', each_item=True)
    def check_shared(cls, v):
        """Accept a shared done functor only when it subclasses SharedDoneFuncBase."""
        if issubclass(v.functor, SharedDoneFuncBase):
            return v
        raise TypeError(f"Shared Done functors must subclass SharedDoneFuncBase, but done {v.name} is of type {v.functor}")

check_done(v) classmethod ¤

Check if dones subclass DoneFuncBase

Source code in corl/environment/multi_agent_env.py
@classmethod
def check_done(cls, v) -> None:
    """Verify that a single done functor wraps an allowed DoneFuncBase subclass.

    Parameters
    ----------
    v
        Object exposing ``functor`` (the class to check) and ``name`` (used in the error text).

    Raises
    ------
    TypeError
        If ``v.functor`` does not subclass DoneFuncBase.
    ValueError
        If ``v.functor`` subclasses EpisodeLengthDone.
    """
    functor_type = v.functor
    if not issubclass(functor_type, DoneFuncBase):
        raise TypeError(f"Done functors must subclass DoneFuncBase, but done {v.name} is of type {v.functor}")
    if issubclass(functor_type, EpisodeLengthDone):
        raise ValueError("Cannot specify EpisodeLengthDone as it is automatically added")

check_shared(v) classmethod ¤

Check if dones subclass SharedDoneFuncBase

Source code in corl/environment/multi_agent_env.py
@validator('shared', each_item=True)
def check_shared(cls, v):
    """Accept a shared done functor only when it subclasses SharedDoneFuncBase.

    Returns
    -------
    The functor ``v`` unchanged when the check passes.

    Raises
    ------
    TypeError
        If ``v.functor`` does not subclass SharedDoneFuncBase.
    """
    if issubclass(v.functor, SharedDoneFuncBase):
        return v
    raise TypeError(f"Shared Done functors must subclass SharedDoneFuncBase, but done {v.name} is of type {v.functor}")

check_task(v) classmethod ¤

Check if dones subclass DoneFuncBase

Source code in corl/environment/multi_agent_env.py
@validator('task', each_item=True)
def check_task(cls, v):
    """Run ``cls.check_done`` on every element of ``v`` and return ``v`` unchanged.

    ``v`` is iterated directly, so it is presumably the list of done functors
    belonging to one entry of the 'task' mapping (confirm against the field type).
    Any exception raised by ``check_done`` propagates to the caller.
    """
    for done_functor in v:
        cls.check_done(done_functor)
    return v

check_world(v) classmethod ¤

Check if dones subclass DoneFuncBase

Source code in corl/environment/multi_agent_env.py
@validator('world', each_item=True)
def check_world(cls, v):
    """Check if dones subclass DoneFuncBase

    Passes the value ``v`` to ``cls.check_done`` and, if that does not raise,
    returns ``v`` unchanged.
    """
    # check_done returns nothing; it signals a bad functor by raising.
    cls.check_done(v)
    return v