Yaml loader
Air Force Research Laboratory (AFRL) Autonomous Capabilities Team (ACT3) Reinforcement Learning (RL) Core.
This is a US Government Work not subject to copyright protection in the US.
The use, dissemination or disclosure of data in this file is subject to limitation or restriction. See accompanying README and LICENSE for details.
Loader (SafeLoader)
¤
YAML Loader with !include
constructor.
Source code in corl/parsers/yaml_loader.py
class Loader(yaml.SafeLoader): # pylint: disable=too-few-public-methods,W0223
"""YAML Loader with `!include` constructor."""
def __init__(self, stream: IO) -> None:
"""Initialise Loader."""
try:
self._root = os.path.split(stream.name)[0]
except AttributeError:
self._root = os.path.curdir
super().__init__(stream)
self._include_mapping: dict = {}
def construct_python_tuple(self, node):
"""Adds in the capability to process tuples in yaml files
"""
return tuple(self.construct_sequence(node))
def construct_sequence(self, node, deep=False):
"""Construct a sequence from a YAML sequence node
This method extends yaml.constructor.BaseConstructor.construct_sequence by adding support for children with the tag
`!include-extend`. Any object with this tag should be constructable to produce a sequence of objects. Even though
`!include-extend` is a tag on the child object, the sequence produced by this child is not added as a single element to the sequence
being produced by this method. Rather, the output sequence is extended with this list. Any children with other tags are appended
into the list in the same manner as yaml.constructor.BaseConstructor.construct_sequence.
Examples
--------
Loader.add_constructor("!include-extend", construct_include)
with open("primary.yml", "r") as fp:
config = yaml.load(fp, Loader)
<file primary.yml>
root:
tree1:
- apple
- banana
- cherry
tree2:
- type: int
value: 3
- type: float
value: 3.14
- type: str
value: pi
tree3:
- date
- elderberry
- !include-extend secondary.yml
- mango
<file secondary.yml>
- fig
- grape
- honeydew
- jackfruit
- kiwi
- lemon
The output of the code above is:
config = {
'root': {
'tree1': ['apple', 'banana', 'cherry'],
'tree2': [
{'type': 'int', 'value': 3},
{'type': 'float', 'value': 3.14},
{'type': 'str', 'value': 'pi'}
],
'tree3': ['date', 'elderberry', 'fig', 'grape', 'honeydew', 'jackfruit', 'kiwi', 'lemon', 'mango']
}
}
"""
if not isinstance(node, SequenceNode):
return super().construct_sequence(node, deep=deep)
output = []
for child in node.value:
this_output = self.construct_object(child, deep=deep)
if child.tag == '!include-extend':
if not isinstance(this_output, collections.abc.Sequence):
raise ConstructorError(
None,
None,
"expected a sequence returned by 'include-extend', but found %s" % type(this_output).__name__,
child.start_mark
)
output.extend(this_output)
else:
output.append(this_output)
return output
def flatten_mapping(self, node):
merge = []
index = 0
while index < len(node.value):
key_node, value_node = node.value[index]
if key_node.tag == 'tag:yaml.org,2002:merge':
del node.value[index]
if isinstance(value_node, yaml.MappingNode):
self.flatten_mapping(value_node)
merge.extend(value_node.value)
elif isinstance(value_node, yaml.SequenceNode):
submerge = []
for subnode in value_node.value:
if not isinstance(subnode, yaml.MappingNode):
raise yaml.ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a mapping for merging, but found %s" % subnode.id,
subnode.start_mark
)
self.flatten_mapping(subnode)
submerge.append(subnode.value)
submerge.reverse()
for value in submerge:
merge.extend(value)
else:
# TODO FIGURE OUT HOW TO DUMP AND ACCESS THE BASE NODE!!!!
# if value_node.tag == '!include-direct':
# filename = os.path.realpath(os.path.join(self._root, self.construct_scalar(value_node)))
# d = yaml.dump(self._include_mapping[filename])
# # for k, v in .items():
# # sk = yaml.ScalarNode(tag='tag:yaml.org,2002:str', value=str(k))
# # if isinstance(v, int):
# # sv = yaml.ScalarNode(tag='tag:yaml.org,2002:int', value=str(v))
# # else:
# # sv = yaml.ScalarNode(tag='tag:yaml.org,2002:seq', value=str(v))
# # merge.extend([(sk, sv)])
# else:
raise yaml.ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a mapping or list of mappings for merging, but found %s" % value_node.id,
value_node.start_mark
)
elif key_node.tag == 'tag:yaml.org,2002:value':
key_node.tag = 'tag:yaml.org,2002:str'
index += 1
else:
index += 1
if merge:
node.value = merge + node.value
__init__(self, stream)
special
¤
Initialise Loader.
Source code in corl/parsers/yaml_loader.py
def __init__(self, stream: IO) -> None:
"""Initialise Loader."""
try:
self._root = os.path.split(stream.name)[0]
except AttributeError:
self._root = os.path.curdir
super().__init__(stream)
self._include_mapping: dict = {}
construct_python_tuple(self, node)
¤
Adds in the capability to process tuples in yaml files
Source code in corl/parsers/yaml_loader.py
def construct_python_tuple(self, node):
"""Adds in the capability to process tuples in yaml files
"""
return tuple(self.construct_sequence(node))
construct_sequence(self, node, deep=False)
¤
Construct a sequence from a YAML sequence node
This method extends yaml.constructor.BaseConstructor.construct_sequence by adding support for children with the tag
!include-extend
. Any object with this tag should be constructable to produce a sequence of objects. Even though
!include-extend
is a tag on the child object, the sequence produced by this child is not added as a single element to the sequence
being produced by this method. Rather, the output sequence is extended with this list. Any children with other tags are appended
into the list in the same manner as yaml.constructor.BaseConstructor.construct_sequence.
Examples¤
Loader.add_constructor("!include-extend", construct_include) with open("primary.yml", "r") as fp: config = yaml.load(fp, Loader)
The output of the code above is: config = { 'root': { 'tree1': ['apple', 'banana', 'cherry'], 'tree2': [ {'type': 'int', 'value': 3}, {'type': 'float', 'value': 3.14}, {'type': 'str', 'value': 'pi'} ], 'tree3': ['date', 'elderberry', 'fig', 'grape', 'honeydew', 'jackfruit', 'kiwi', 'lemon', 'mango'] } }
Source code in corl/parsers/yaml_loader.py
def construct_sequence(self, node, deep=False):
"""Construct a sequence from a YAML sequence node
This method extends yaml.constructor.BaseConstructor.construct_sequence by adding support for children with the tag
`!include-extend`. Any object with this tag should be constructable to produce a sequence of objects. Even though
`!include-extend` is a tag on the child object, the sequence produced by this child is not added as a single element to the sequence
being produced by this method. Rather, the output sequence is extended with this list. Any children with other tags are appended
into the list in the same manner as yaml.constructor.BaseConstructor.construct_sequence.
Examples
--------
Loader.add_constructor("!include-extend", construct_include)
with open("primary.yml", "r") as fp:
config = yaml.load(fp, Loader)
<file primary.yml>
root:
tree1:
- apple
- banana
- cherry
tree2:
- type: int
value: 3
- type: float
value: 3.14
- type: str
value: pi
tree3:
- date
- elderberry
- !include-extend secondary.yml
- mango
<file secondary.yml>
- fig
- grape
- honeydew
- jackfruit
- kiwi
- lemon
The output of the code above is:
config = {
'root': {
'tree1': ['apple', 'banana', 'cherry'],
'tree2': [
{'type': 'int', 'value': 3},
{'type': 'float', 'value': 3.14},
{'type': 'str', 'value': 'pi'}
],
'tree3': ['date', 'elderberry', 'fig', 'grape', 'honeydew', 'jackfruit', 'kiwi', 'lemon', 'mango']
}
}
"""
if not isinstance(node, SequenceNode):
return super().construct_sequence(node, deep=deep)
output = []
for child in node.value:
this_output = self.construct_object(child, deep=deep)
if child.tag == '!include-extend':
if not isinstance(this_output, collections.abc.Sequence):
raise ConstructorError(
None,
None,
"expected a sequence returned by 'include-extend', but found %s" % type(this_output).__name__,
child.start_mark
)
output.extend(this_output)
else:
output.append(this_output)
return output
apply_patches(config)
¤
updates the base setup with patches
Returns:
Type | Description |
---|---|
The combined dict |
Source code in corl/parsers/yaml_loader.py
def apply_patches(config):
"""updates the base setup with patches
Arguments:
config [dict, list] -- The base and patch if list, else dict
Returns:
The combined dict
"""
def merge(source, destination):
"""
run me with nosetests --with-doctest file.py
>>> a = { 'first' : { 'all_rows' : { 'pass' : 'dog', 'number' : '1' } } }
>>> b = { 'first' : { 'all_rows' : { 'fail' : 'cat', 'number' : '5' } } }
>>> merge(b, a) == { 'first' : { 'all_rows' : { 'pass' : 'dog', 'fail' : 'cat', 'number' : '5' } } }
True
"""
for key, value in source.items():
if isinstance(value, dict):
# get node or create one
node = destination.setdefault(key, {})
merge(value, node)
else:
destination[key] = value
return destination
if isinstance(config, list):
config_new = copy.deepcopy(config[0])
for item in config[1:]:
if item is not None:
config_new = merge(item, config_new)
return config_new
return config
construct_include(loader, node)
¤
Include file referenced at node.
Source code in corl/parsers/yaml_loader.py
def construct_include(loader: Loader, node: yaml.Node) -> Any:
"""Include file referenced at node."""
filename = os.path.realpath(os.path.join(loader._root, loader.construct_scalar(node))) # type: ignore # pylint: disable=protected-access # noqa: E501
extension = os.path.splitext(filename)[1].lstrip(".")
with open(filename, "r") as fp:
if extension in ("yaml", "yml"): # pylint: disable=no-else-return
return yaml.load(fp, Loader)
elif extension in ("json", ):
return json.load(fp)
else:
return "".join(fp.readlines())
construct_include_arr(loader, node)
¤
Identical to above, but accepts an array and appends results as an array.
Source code in corl/parsers/yaml_loader.py
def construct_include_arr(loader: Loader, node: yaml.Node) -> Any:
"""Identical to above, but accepts an array and appends results as an array."""
sequence = loader.construct_sequence(node)
data: typing.List = []
for item in sequence:
filename = os.path.abspath(os.path.join(loader._root, item)) # pylint: disable=protected-access
extension = os.path.splitext(filename)[1].lstrip(".")
with open(filename, "r") as f:
if extension in ("yaml", "yml"): # pylint: disable=no-else-return
data = data + (yaml.load(f, Loader))
elif extension in ("json", ):
data = data + (json.load(f))
else:
data = data + ("".join(f.readlines())) # type: ignore
return data
construct_include_direct(loader, node)
¤
Include file referenced at node.
Source code in corl/parsers/yaml_loader.py
def construct_include_direct(loader: Loader, node: yaml.Node) -> Any:
"""Include file referenced at node."""
filename = os.path.realpath(os.path.join(loader._root, loader.construct_scalar(node))) # type: ignore # pylint: disable=protected-access # noqa: E501
extension = os.path.splitext(filename)[1].lstrip(".")
with open(filename, "r") as fp:
if extension in ("yaml", "yml"): # pylint: disable=no-else-return
temp = yaml.load(fp, Loader)
loader._include_mapping[filename] = temp # pylint: disable=protected-access
return temp
elif extension in ("json", ):
return json.load(fp)
else:
return "".join(fp.readlines())
construct_tune_function(loader, node)
¤
Include expression referenced at node.
Source code in corl/parsers/yaml_loader.py
def construct_tune_function(loader: Loader, node: yaml.Node) -> Any: # pylint: disable=unused-argument
"""Include expression referenced at node."""
if isinstance(node.value, str) and "tune" in node.value:
return eval(node.value) # pylint: disable=eval-used
return node.value
load_file(config_filename)
¤
Utility function to load in a specified yaml file
Source code in corl/parsers/yaml_loader.py
def load_file(config_filename: str):
"""
Utility function to load in a specified yaml file
"""
with open(config_filename, "r") as fp:
config = yaml.load(fp, Loader)
return config
separate_config(config)
¤
Utility function to separate the env specific configs from the tune configs
Source code in corl/parsers/yaml_loader.py
def separate_config(config: typing.Dict):
"""
Utility function to separate the env specific configs from the tune configs
"""
# we can call ray.init without any arguments so default is no arguments
ray_config = apply_patches(config.get("ray_config", {}))
# we must have a tune config or else we cannot call tune.run
if "tune_config" not in config:
raise ValueError(f"Could not find a tune_config in {config}")
tune_config = apply_patches(config["tune_config"])
# must also get an env_config or else we don't know which environment to run```
if "env_config" in config:
env_config = apply_patches(config["env_config"])
else:
raise ValueError(f"Could not find a env_config in {config} or rllib_config")
# must get a rllib config in some way or else we aren't going to run anything useful
rllib_configs = {}
if "rllib_configs" in config:
for key, value in config["rllib_configs"].items():
rllib_configs[key] = apply_patches(value)
else:
raise ValueError(f"Could not find a rllib_config in {config} or 'config' in tune_config")
for key in rllib_configs:
rllib_configs[key]["env_config"] = copy.deepcopy(env_config)
rllib_configs[key]["env_config"].setdefault("environment", {})["horizon"] = rllib_configs[key].get("horizon", 1000)
# a trainable config is not necessary
trainable_config = apply_patches(config.get("trainable_config", None))
return ray_config, rllib_configs, tune_config, trainable_config