# Source code for layeredrl.hierarchies.hierarchy

from copy import copy
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter
import yaml

from ..levels import Level
from ..utils.misc import to_torch


def _filter_dict_active_instances(active_instances: torch.Tensor, info: dict) -> dict:
    """Filter the info dictionary for the given active instances.

    Args:
        active_instances: A boolean tensor indicating which instances to keep.
        info: The info dictionary to filter.
    Returns:
        The filtered info dictionary.
    """
    return {key: value[active_instances] for key, value in info.items()}


class Hierarchy:
    """A stack of levels acting jointly in a (possibly vectorized) environment.

    The hierarchy keeps track of which level is currently in control of each
    environment instance (``active_level``, index 0 being the highest level).
    Actions are obtained by descending through the levels (``get_action``) and
    control is handed back upwards after each environment transition
    (``process_transition``).
    """

    def __init__(
        self,
        levels: List[Level],
        env: gym.Env,
        env_obs_maps: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
        mapped_env_obs_shapes: Optional[List[Tuple]] = None,
        keep_params: bool = False,
        device: torch.device = torch.device("cpu"),
        writer: Optional[SummaryWriter] = None,
    ):
        """Initialize the hierarchy.

        Makes sure the action a level emits fits the input expected by the level below it.

        Args:
            levels: The levels of the hierarchy.
            env: The environment.
            env_obs_maps: A list of functions that map the environment observation to
                a vector that is provided to the corresponding level. This can be used to
                implement information hiding and for moving a trained level from one
                environment to another with a different observation space. If None, the
                identity map is used.
            mapped_env_obs_shapes: The shapes of the output of the env_obs_map of each
                level. If None, the dimension of the environment observation space is used.
                If negative, the negative dimension is added to the dimension of the
                environment observation space. This is useful if the map from the
                environment observation to the level observation is dropping some
                components. The passed list itself is not modified.
            keep_params: Whether to keep the parameters of the levels instead of resetting
                them. Setting this to True is only valid if the levels were already
                initialized before.
            device: The device to use.
            writer: The TensorBoard writer to use for logging. If None, no logging is done.
        """
        self.levels = levels
        self.env = env
        self.active_level = None
        self.n_levels = len(levels)
        assert (
            mapped_env_obs_shapes is None
            or len(mapped_env_obs_shapes) == self.n_levels
        ), "mapped_env_obs_shapes must be None or have the same length as levels"
        assert (
            env_obs_maps is None or len(env_obs_maps) == self.n_levels
        ), "env_obs_maps must be None or have the same length as levels"
        # Work on a copy: negative dimensions are resolved in place below and the
        # caller's list must not be changed as a side effect.
        self.mapped_env_obs_shapes = (
            [None] * self.n_levels
            if mapped_env_obs_shapes is None
            else list(mapped_env_obs_shapes)
        )
        # Total number of environment steps the hierarchy has seen during training (summed
        # over environment instances).
        self.n_total_env_steps = 0
        # if mapped env obs dimension is negative, add to dimension of the environment
        # observation space
        for i, shape in enumerate(self.mapped_env_obs_shapes):
            if shape is not None and len(shape) == 1 and shape[0] < 0:
                self.mapped_env_obs_shapes[i] = (
                    self.env.observation_space.shape[-1] + shape[0],
                )
        self.env_obs_maps = (
            [None] * self.n_levels if env_obs_maps is None else env_obs_maps
        )
        self.device = device
        self.writer = writer
        if isinstance(env, gym.vector.VectorEnv):
            self.action_space = env.single_action_space
            self.observation_space = env.single_observation_space
            self.n_env_instances = env.num_envs
        else:
            self.action_space = env.action_space
            self.observation_space = env.observation_space
            self.n_env_instances = 1
        # initialize levels such that their action space is the input space of the level
        # below (or the environment action space for the lowest level)
        for i, level in enumerate(levels):
            if i == self.n_levels - 1:
                action_space = self.action_space
            else:
                action_space = levels[i + 1].get_input_space()
            # The highest level has no parent; other levels receive the predictor of the
            # level above if that level has one.
            parent_predictor = (
                getattr(levels[i - 1], "predictor", None) if i > 0 else None
            )
            level.initialize(
                env_obs_space=self.observation_space,
                action_space=action_space,
                n_env_instances=self.n_env_instances,
                parent_predictor=parent_predictor,
                env_obs_map=self.env_obs_maps[i],
                mapped_env_obs_shape=self.mapped_env_obs_shapes[i],
                keep_params=keep_params,
            )
        # It is not necessary to set the number of environment instances for the
        # levels again because this already happens when calling Level.initialize(...).
        self.set_n_env_instances(self.n_env_instances, propagate_to_levels=False)
        self.training = True

    def reset(self) -> None:
        """Reset the hierarchy.

        Call this at the beginning of the session. The highest level is active at the
        beginning for all environment instances. Do not call at the end of episodes.
        """
        self.active_level = torch.zeros(
            self.n_env_instances, dtype=torch.int, device=self.device
        )
        for level in self.levels:
            level.reset()

    def soft_reset(self) -> None:
        """Soft reset the hierarchy.

        Call this when manually resetting the (vector) environment. This will not affect
        things like warm up periods etc. and can therefore be called without influencing
        the learning process.
        """
        self.active_level = torch.zeros(
            self.n_env_instances, dtype=torch.int, device=self.device
        )
        for level in self.levels:
            level.soft_reset()

    def get_copy(self) -> "Hierarchy":
        """Return a copy of the hierarchy.

        No models are copied, only the structure and state of the hierarchy. The copy of
        the hierarchy can be used for testing rollouts without influencing the state of
        the original hierarchy, for example. Learning with the copy will influence the
        original hierarchy, however, and is not recommended.

        Returns:
            A shallow copy of this hierarchy with copied levels, freshly allocated
            per-instance buffers, and a reset state.
        """
        hierarchy = copy(self)
        hierarchy.levels = [level.get_copy() for level in self.levels]
        # Re-allocate the per-instance buffers so they are not shared with the original.
        hierarchy.set_n_env_instances(self.n_env_instances, propagate_to_levels=False)
        hierarchy.reset()
        return hierarchy

    def set_n_env_instances(
        self, n_env_instances: int, propagate_to_levels: bool = True
    ) -> None:
        """Set the number of environment instances.

        Re-allocates the per-level buffers holding the mapped observations, the actions
        and the action infos.

        Args:
            n_env_instances: The number of environment instances.
            propagate_to_levels: Whether to also call ``set_n_env_instances`` on every
                level.
        """
        self.n_env_instances = n_env_instances
        self.mapped_env_obs = []
        self.level_action = []
        self.level_action_info = []
        for level in self.levels:
            if propagate_to_levels:
                level.set_n_env_instances(n_env_instances)
            self.mapped_env_obs.append(
                torch.zeros(
                    size=(self.n_env_instances,) + level.mapped_env_obs_shape,
                    dtype=torch.float,
                    device=self.device,
                )
            )
            action_shape, action_dtype = level.get_action_shape_and_type()
            level_action = torch.zeros(
                size=(self.n_env_instances,) + action_shape,
                dtype=action_dtype,
                device=self.device,
            )
            self.level_action.append(level_action)
            self.level_action_info.append(dict())

    def _update_level_action_info(
        self, level: int, info: dict, active_instances: torch.Tensor
    ) -> None:
        """Update the level action info dictionary for the given level.

        Args:
            level: The level for which to update the action info.
            info: The info dictionary to update the level action info with.
            active_instances: A boolean tensor indicating which instances to update.
        """
        for key, value in info.items():
            if key not in self.level_action_info[level]:
                # Lazily allocate a buffer covering all instances the first time a key
                # shows up; the trailing dimensions are taken from the given value.
                self.level_action_info[level][key] = torch.zeros(
                    (self.n_env_instances,) + value.shape[1:],
                    dtype=value.dtype,
                    device=self.device,
                )
            self.level_action_info[level][key][active_instances] = value

    @staticmethod
    def _expand_mask(
        active_instances: torch.Tensor, sub_mask: torch.Tensor
    ) -> torch.Tensor:
        """Scatter a mask over the active instances into a mask over all instances.

        Args:
            active_instances: Boolean tensor with one entry per environment instance.
            sub_mask: Boolean tensor with one entry per active instance.
        Returns:
            A boolean tensor shaped like ``active_instances`` that is True exactly for
            those active instances for which ``sub_mask`` is True.
        """
        full_mask = torch.zeros_like(active_instances)
        full_mask[active_instances] = sub_mask.to(dtype=full_mask.dtype)
        return full_mask

    def get_action(self, obs: torch.Tensor) -> torch.Tensor:
        """Get an action for the given observation.

        Note that obs and the returned action have a batch dimension corresponding to
        environment instances. The method descends the hierarchy from top to bottom,
        starting with the active level. From thereon, an action is obtained for each
        level which is then passed to the level below. The action of the lowest level is
        returned (to be executed in the environment).

        Args:
            obs: The environment observation.
        Returns:
            The action for the environment.
        """
        # Make sure obs is tensor
        obs = to_torch(obs, self.device)
        # downward pass for obtaining actions
        for i, level in enumerate(self.levels):
            # only get action for those instances for which the level is active
            active_instances = self.active_level == i
            if not active_instances.any():
                continue
            level_obs = obs[active_instances].float()
            if i > 0:
                higher_level_action = self.level_action[i - 1][active_instances]
                higher_level_action_info = _filter_dict_active_instances(
                    active_instances, self.level_action_info[i - 1]
                )
            else:
                higher_level_action = None
                higher_level_action_info = None
            mapped_level_obs = level.env_obs_map(level_obs)
            # Remember the mapped observation; it is needed in process_transition.
            self.mapped_env_obs[i][active_instances] = mapped_level_obs
            level_action, level_info = level.get_action(
                mapped_level_obs,
                higher_level_action,
                higher_level_action_info,
                active_instances,
            )
            self.level_action[i][active_instances] = level_action.type_as(
                self.level_action[i]
            )
            self._update_level_action_info(i, level_info, active_instances)
            if i < self.n_levels - 1:
                # after obtaining the action, the level below becomes active
                # except for the last level which stays active
                self.active_level[active_instances] = i + 1
        # return the action of the lowest level
        return self.level_action[-1]

    def process_transition(
        self,
        obs_next: torch.Tensor,
        rew: torch.Tensor,
        terminated: torch.Tensor,
        truncated: torch.Tensor,
    ) -> None:
        """Process the environment transition and return control to the higher levels
        where appropriate.

        Starting from the lowest level, the active level can pass control back to the
        level above. This can continue until a level stays active or the highest level is
        reached. While ascending the hierarchy, register the (semi-MDP) transitions with
        the levels.

        Args:
            obs_next: The next environment observation.
            rew: The reward of the environment transition.
            terminated: Whether the episode terminated. Tensor with one entry per
                environment instance.
            truncated: Whether the episode was truncated. Tensor with one entry per
                environment instance.
        """
        obs_next = to_torch(obs_next, self.device).float()
        rew = to_torch(rew, self.device)
        terminated = to_torch(terminated, self.device)
        truncated = to_torch(truncated, self.device)
        self.n_total_env_steps += self.n_env_instances
        # register environment reward
        for level in self.levels:
            level.register_env_reward(rew)
            level.n_total_env_steps = self.n_total_env_steps
        # upward pass for returning control to the higher levels
        for i in reversed(range(self.n_levels)):
            # only check those instances for which the level is active
            active_instances = self.active_level == i
            if not active_instances.any():
                continue
            next_level_obs = obs_next[active_instances].float()
            next_mapped_level_obs = self.levels[i].env_obs_map(next_level_obs)
            higher_level_action = (
                self.level_action[i - 1][active_instances] if i > 0 else None
            )
            level_action = self.level_action[i][active_instances]
            return_control = self.levels[i].process_transition(
                self.mapped_env_obs[i][active_instances],
                higher_level_action,
                level_action,
                next_mapped_level_obs,
                terminated[active_instances],
                truncated[active_instances],
                active_instances,
            )
            if i > 0:
                # return control to the level above if the level is done,
                # if the episode was terminated, or if the episode was truncated
                return_control_or_terminated = torch.logical_or(
                    return_control, terminated[active_instances]
                )
                # This generates somewhat incomplete transitions on the higher level
                # but it's still better than not having a transition that truncates
                # or terminates on the higher level
                rc_or_term_or_trunc = torch.logical_or(
                    return_control_or_terminated, truncated[active_instances]
                )
                # The masks above have one entry per *active* instance, whereas the
                # tensors indexed below have one entry per environment instance.
                # Expand them first; otherwise the indexing fails as soon as only a
                # subset of the instances is active at this level.
                self.active_level[
                    self._expand_mask(active_instances, rc_or_term_or_trunc)
                ] = (i - 1)
                # reset the number of steps the level has been in control in
                # these cases
                # NOTE(review): assumes n_steps_in_control has one entry per
                # environment instance -- confirm against Level.
                self.levels[i].n_steps_in_control[
                    self._expand_mask(active_instances, return_control_or_terminated)
                ] = 0
        # give control to the highest level if the episode truncated
        self.active_level[truncated] = 0

    def learn(self) -> None:
        """Learn from the collected transitions."""
        for level in self.levels:
            level.learn()

    def eval(self) -> None:
        """Set all levels of the hierarchy to evaluation mode."""
        self.training = False
        for level in self.levels:
            level.eval()

    def train(self) -> None:
        """Set all levels of the hierarchy to training mode."""
        self.training = True
        for level in self.levels:
            level.train()

    def save(self, path: Path) -> None:
        """Save the hierarchy to the given path.

        Writes ``hierarchy_state.yaml`` into ``path`` and lets every level save itself
        into a ``level_<index>`` subdirectory. Missing directories are created.

        Args:
            path: The directory to save to.
        """
        path.mkdir(parents=True, exist_ok=True)
        hierarchy_state = {"self.n_total_env_steps": self.n_total_env_steps}
        with open(path / "hierarchy_state.yaml", "w") as f:
            yaml.dump(hierarchy_state, f)
        for i, level in enumerate(self.levels):
            level_path = path / f"level_{i}"
            level_path.mkdir(parents=True, exist_ok=True)
            level.save(level_path)

    def load(self, path: Path) -> None:
        """Load the hierarchy from the given path.

        Restores the total number of environment steps from ``hierarchy_state.yaml`` and
        lets every level load itself from its ``level_<index>`` subdirectory.

        Args:
            path: The directory that was previously passed to ``save``.
        """
        with open(path / "hierarchy_state.yaml", "r") as f:
            hierarchy_state = yaml.safe_load(f)
        self.n_total_env_steps = hierarchy_state["self.n_total_env_steps"]
        for i, level in enumerate(self.levels):
            level.load(path / f"level_{i}")
            level.n_total_env_steps = self.n_total_env_steps

    def save_buffers(self, path: Path) -> None:
        """Save the replay buffers of all levels to the given path.

        Args:
            path: The directory to save to; one ``level_<index>`` subdirectory is
                created per level.
        """
        for i, level in enumerate(self.levels):
            level_path = path / f"level_{i}"
            level_path.mkdir(parents=True, exist_ok=True)
            level.save_buffers(level_path)

    def load_buffers(self, path: Path) -> None:
        """Load the replay buffers of all levels from the given path.

        Args:
            path: The directory that was previously passed to ``save_buffers``.
        """
        for i, level in enumerate(self.levels):
            level.load_buffers(path / f"level_{i}")