Torch frame stack
Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.
This is a US Government Work not subject to copyright protection in the US.
The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.
TorchFrameStack (TorchModelV2, Module)
¤
Generic fully connected network.
Source code in corl/models/torch_frame_stack.py
class TorchFrameStack(TorchModelV2, nn.Module): # type: ignore
"""Generic fully connected network."""
PREV_N_OBS = "prev_n_obs"
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
num_frames: int = 1
):
TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
nn.Module.__init__(self)
hiddens = list(model_config.get("fcnet_hiddens", []))
post_fcnet_hiddens = list(model_config.get("post_fcnet_hiddens", []))
activation = model_config.get("fcnet_activation")
if not model_config.get("fcnet_hiddens", []):
activation = model_config.get("post_fcnet_activation")
no_final_linear = model_config.get("no_final_linear")
self.vf_share_layers = model_config.get("vf_share_layers")
self.free_log_std = model_config.get("free_log_std")
num_frames = model_config["custom_model_config"].get("num_frames", 1)
self.view_requirements[TorchFrameStack.PREV_N_OBS
] = ViewRequirement(data_col="obs", shift="-{}:0".format(num_frames - 1), space=obs_space)
# Generate free-floating bias variables for the second half of
# the outputs.
if self.free_log_std:
assert num_outputs % 2 == 0, ("num_outputs must be divisible by two", num_outputs)
num_outputs = num_outputs // 2
layers = []
prev_layer_size = int(obs_space.shape[-1])
self._logits = None
# Create layers 0 to second-last.
for size in hiddens[:-1]:
layers.append(SlimFC(in_size=prev_layer_size, out_size=size, initializer=normc_initializer(1.0), activation_fn=activation))
prev_layer_size = size
layers.append(nn.Flatten())
prev_layer_size = size * num_frames
for size in post_fcnet_hiddens:
layers.append(SlimFC(in_size=prev_layer_size, out_size=size, initializer=normc_initializer(1.0), activation_fn=activation))
prev_layer_size = size
# The last layer is adjusted to be of size num_outputs, but it's a
# layer with activation.
if no_final_linear and num_outputs:
layers.append(
SlimFC(in_size=prev_layer_size, out_size=num_outputs, initializer=normc_initializer(1.0), activation_fn=activation)
)
prev_layer_size = num_outputs
# Finish the layers with the provided sizes (`hiddens`), plus -
# iff num_outputs > 0 - a last linear layer of size num_outputs.
else:
if len(hiddens) > 0:
layers.append(
SlimFC(in_size=prev_layer_size, out_size=hiddens[-1], initializer=normc_initializer(1.0), activation_fn=activation)
)
prev_layer_size = hiddens[-1]
if num_outputs:
self._logits = SlimFC(
in_size=prev_layer_size, out_size=num_outputs, initializer=normc_initializer(0.01), activation_fn=None
)
else:
self.num_outputs = ([int(np.product(obs_space.shape))] + hiddens[-1:])[-1]
# Layer to add the log std vars to the state-dependent means.
if self.free_log_std and self._logits:
self._append_free_log_std = AppendBiasLayer(num_outputs)
self._hidden_layers = nn.Sequential(*layers)
self._value_branch_separate = None
if not self.vf_share_layers:
# Build a parallel set of hidden layers for the value net.
prev_vf_layer_size = int(obs_space.shape[-1])
vf_layers = []
for size in hiddens[:-1]:
vf_layers.append(
SlimFC(in_size=prev_vf_layer_size, out_size=size, activation_fn=activation, initializer=normc_initializer(1.0))
)
prev_vf_layer_size = size
vf_layers.append(nn.Flatten())
prev_vf_layer_size = size * num_frames
vf_layers.append(
SlimFC(in_size=prev_vf_layer_size, out_size=hiddens[-1], activation_fn=activation, initializer=normc_initializer(1.0))
)
prev_vf_layer_size = hiddens[-1]
# for size in hiddens:
# vf_layers.append(
# SlimFC(
# in_size=prev_vf_layer_size,
# out_size=size,
# activation_fn=activation,
# initializer=normc_initializer(1.0)))
# prev_vf_layer_size = size
self._value_branch_separate = nn.Sequential(*vf_layers)
self._value_branch = SlimFC(in_size=prev_layer_size, out_size=1, initializer=normc_initializer(0.01), activation_fn=None)
# print("*************************************************")
# print(model_config)
# print("************************************************")
# print(num_frames)
# print("************************************************")
# print(self._hidden_layers)
# print(self._logits)
# print("***********************************************")
# print(self._value_branch_separate)
# print(self._value_branch)
# # exit(1)
# Holds the current "base" output (before logits layer).
self._features = None
# Holds the last input, in case value branch is separate.
self._last_flat_in = None
@override(TorchModelV2)
def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]): # type: ignore
self._last_flat_in = input_dict[TorchFrameStack.PREV_N_OBS].float()
self._features = self._hidden_layers(self._last_flat_in)
logits = self._logits(self._features) if self._logits else \
self._features
if self.free_log_std:
logits = self._append_free_log_std(logits)
return logits, state
@override(TorchModelV2)
def value_function(self) -> TensorType:
assert self._features is not None, "must call forward() first"
if self._value_branch_separate:
return self._value_branch(self._value_branch_separate(self._last_flat_in)).squeeze(1)
else:
return self._value_branch(self._features).squeeze(1)
forward(self, input_dict, state, seq_lens)
¤
Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by call before being passed to forward(). To access the flattened observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution, each call to forward() will eagerly evaluate the model. In symbolic execution, each call to forward creates a computation graph that operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of call.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_dict |
Dict[str, Any] |
dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training", "eps_id", "agent_id", "infos", and "t". |
required |
state |
List[Any] |
list of state tensors with sizes matching those returned by get_initial_state + the batch dimension |
required |
seq_lens |
Any |
1d tensor holding input sequence lengths |
required |
Returns:
Type | Description |
---|---|
(Any, List[Any]) |
A tuple consisting of the model output tensor of size [BATCH, num_outputs] and the list of new RNN state(s) if any. |
Examples:
>>> import numpy as np
>>> from ray.rllib.models.modelv2 import ModelV2
>>> class MyModel(ModelV2):
... # ...
>>> def forward(self, input_dict, state, seq_lens):
>>> model_out, self._value_out = self.base_model(
... input_dict["obs"])
>>> return model_out, state
Source code in corl/models/torch_frame_stack.py
@override(TorchModelV2)
def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType],
seq_lens: TensorType) -> (TensorType, List[TensorType]): # type: ignore
self._last_flat_in = input_dict[TorchFrameStack.PREV_N_OBS].float()
self._features = self._hidden_layers(self._last_flat_in)
logits = self._logits(self._features) if self._logits else \
self._features
if self.free_log_std:
logits = self._append_free_log_std(logits)
return logits, state
value_function(self)
¤
Returns the value function output for the most recent forward pass.
Note that a forward
call has to be performed first, before this
methods can return anything and thus that calling this method does not
cause an extra forward pass through the network.
Returns:
Type | Description |
---|---|
Any |
Value estimate tensor of shape [BATCH]. |
Source code in corl/models/torch_frame_stack.py
@override(TorchModelV2)
def value_function(self) -> TensorType:
assert self._features is not None, "must call forward() first"
if self._value_branch_separate:
return self._value_branch(self._value_branch_separate(self._last_flat_in)).squeeze(1)
else:
return self._value_branch(self._features).squeeze(1)