Skip to content

Arithmetic multi glue


Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.

This is a US Government Work not subject to copyright protection in the US.

The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.


ArithmeticMultiGlue implementation

ArithmeticMultiGlue (BaseMultiWrapperGlue) ¤

ArithmeticMultiGlue takes in a list of wrapped glues and performs some arithmetic operation on their output

Source code in corl/glues/common/arithmetic_multi_glue.py
class ArithmeticMultiGlue(BaseMultiWrapperGlue):
    """
    ArithmeticMultiGlue takes in a list of wrapped glues and performs some
    arithmetic operation on their output
    """

    def __init__(self, **kwargs) -> None:
        self.config: ArithmeticMultiGlueValidator
        super().__init__(**kwargs)
        self.operator = np.sum
        # if self.config.mode == "sub":
        #     self.operator = np.subtract
        # elif self.config.mode == "mult":
        #     self.operator = np.multiply
        # elif self.config.mode == "div":
        #     self.operator = np.divide
        self.field_names = []
        for glue in self.glues():
            space = glue.observation_space()
            if len(space.spaces) > 1:
                raise RuntimeError("ArithmeticMultiGlue can only wrap a glue with one output")
            self.field_names.append(list(space.spaces.keys())[0])

    class Fields:
        """
        Field data
        """
        RESULT = "result"

    @property
    def get_validator(self) -> typing.Type[ArithmeticMultiGlueValidator]:
        return ArithmeticMultiGlueValidator

    @lru_cache(maxsize=1)
    def get_unique_name(self):
        """Class method that retreives the unique name for the glue instance
        """
        tmp = [glue.get_unique_name() for glue in self.glues()]
        if any(tmp_str is None for tmp_str in tmp):
            return None
        wrapped_glue_names = "".join(tmp)
        return wrapped_glue_names + self.config.mode

    def invalid_value(self) -> OrderedDict:
        """When invalid return a value of 0

        TODO: this may need to be self.min in the case that the minimum is larger than 0 (i.e. a harddeck)

        Returns:
            OrderedDict -- Dictionary with <FIELD> entry containing 1D array
        """
        d = OrderedDict()
        d[f"{self.Fields.RESULT}"] = np.asarray(
            [(self.config.limit.maximum + self.config.limit.minimum) / 2], dtype=np.float32
        )  # type: ignore
        return d

    @lru_cache(maxsize=1)
    def observation_space(self):
        d = gym.spaces.dict.Dict()

        d.spaces[f"{self.Fields.RESULT}"] = gym.spaces.Box(
            self.config.limit.minimum, self.config.limit.maximum, shape=(1, ), dtype=np.float32
        )
        return d

    def get_observation(self):
        d = OrderedDict()
        tmp_output = [glue.get_observation()[field_name] for glue, field_name in zip(self.glues(), self.field_names)]
        d[self.Fields.RESULT] = np.array([self.operator(tmp_output)], dtype=np.float32)
        return d

    @lru_cache(maxsize=1)
    def action_space(self) -> gym.spaces.Space:
        return None

    def apply_action(self, action, observation):
        return None

get_validator: Type[corl.glues.common.arithmetic_multi_glue.ArithmeticMultiGlueValidator] property readonly ¤

returns the validator for this class

Returns:

Type Description
Type[corl.glues.common.arithmetic_multi_glue.ArithmeticMultiGlueValidator]

BaseAgentGlueValidator -- A pydantic validator to be used to validate kwargs

Fields ¤

Field data

Source code in corl/glues/common/arithmetic_multi_glue.py
class Fields:
    """
    Field data
    """
    RESULT = "result"

apply_action(self, action, observation) ¤

Apply the action for the controller, etc.

Parameters¤

action The action for the class to apply to the platform observation The current observations before appling the action

Source code in corl/glues/common/arithmetic_multi_glue.py
def apply_action(self, action, observation):
    return None

get_observation(self) ¤

Get the actual observation for the platform using the state of the platform, controller, sensors, etc.

Returns¤

EnvSpaceUtil.sample_type The actual observation for this platform from this glue class

Source code in corl/glues/common/arithmetic_multi_glue.py
def get_observation(self):
    d = OrderedDict()
    tmp_output = [glue.get_observation()[field_name] for glue, field_name in zip(self.glues(), self.field_names)]
    d[self.Fields.RESULT] = np.array([self.operator(tmp_output)], dtype=np.float32)
    return d

get_unique_name(self) ¤

Class method that retreives the unique name for the glue instance

Source code in corl/glues/common/arithmetic_multi_glue.py
@lru_cache(maxsize=1)
def get_unique_name(self):
    """Class method that retreives the unique name for the glue instance
    """
    tmp = [glue.get_unique_name() for glue in self.glues()]
    if any(tmp_str is None for tmp_str in tmp):
        return None
    wrapped_glue_names = "".join(tmp)
    return wrapped_glue_names + self.config.mode

invalid_value(self) ¤

When invalid return a value of 0

TODO: this may need to be self.min in the case that the minimum is larger than 0 (i.e. a harddeck)

Returns:

Type Description
OrderedDict

OrderedDict -- Dictionary with entry containing 1D array

Source code in corl/glues/common/arithmetic_multi_glue.py
def invalid_value(self) -> OrderedDict:
    """When invalid return a value of 0

    TODO: this may need to be self.min in the case that the minimum is larger than 0 (i.e. a harddeck)

    Returns:
        OrderedDict -- Dictionary with <FIELD> entry containing 1D array
    """
    d = OrderedDict()
    d[f"{self.Fields.RESULT}"] = np.asarray(
        [(self.config.limit.maximum + self.config.limit.minimum) / 2], dtype=np.float32
    )  # type: ignore
    return d

ArithmeticMultiGlueValidator (BaseMultiWrapperGlueValidator) pydantic-model ¤

mode: what arithmetic operation to run on the output of the wrapped glues limit: the expected limit for this glue

Source code in corl/glues/common/arithmetic_multi_glue.py
class ArithmeticMultiGlueValidator(BaseMultiWrapperGlueValidator):
    """
    mode: what arithmetic operation to run on the output of the wrapped glues
    limit: the expected limit for this glue
    """
    mode: typing.Literal["sum", "sub", "mult", "div"] = "sum"
    limit: LimitConfigValidator