Source code for layeredrl.utils.misc

from contextlib import contextmanager
from copy import deepcopy
from typing import Tuple, Dict, Union

import gymnasium as gym
import numpy as np
import torch


[docs] def to_torch( array: Union[np.ndarray, torch.Tensor], device: torch.device = torch.device("cpu") ) -> torch.Tensor: """Convert an array to a torch tensor. Args: array: The array. Returns: The torch tensor. """ if isinstance(array, torch.Tensor): return array else: return torch.as_tensor(array, device=device)
[docs] def to_numpy( array: Union[np.ndarray, torch.Tensor], device: torch.device = torch.device("cpu") ) -> np.ndarray: """Convert an array to a numpy array. Args: array: The array. Returns: The numpy array. """ if isinstance(array, np.ndarray): return array else: return array.cpu().numpy()
[docs] def copy_torch_or_numpy(array: Union[np.ndarray, torch.Tensor]) -> np.ndarray: """Copy a torch tensor or numpy array. Args: array: The array. Returns: The copy. """ if isinstance(array, np.ndarray): return array.copy() else: return array.clone()
[docs] @contextmanager def temp_eval_mode(net): """Temporarily set the net to eval mode.""" training = net.training try: net.eval() yield net finally: if training: net.train()
[docs] def get_key_indices(input_space: gym.spaces.Dict) -> Tuple[Dict, Dict]: """Get index ranges that correspond to keys of a Dict space in flattened vector. Also returns a dictionary containing the shapes of these observation components. Returns: - A dictionary with keys corresponding to the observation components and values being tuples of the form (start, end), where start and end are the indices at which the observation component starts and ends. The nested dictionary structure of the observation is preserved. - A dictionary of the same structure but with values being the shapes of the observation components.""" def _construct_dummy_obs(spaces_dict, counter=[0]): """Construct dummy observation which has an array repeating a different integer as the value of each component.""" dummy_obs = {} for i, (k, v) in enumerate(spaces_dict.items()): if isinstance(v, gym.spaces.Dict): dummy_obs[k] = _construct_dummy_obs(v.spaces, counter) else: dummy_obs[k] = counter * np.ones(v.shape, dtype=np.int32) counter[0] += 1 return dummy_obs dummy_obs = _construct_dummy_obs(input_space.spaces) flat_dummy_obs = gym.spaces.flatten(input_space, dummy_obs) def _get_indices_and_shape(dummy_obs, flat_dummy_obs): indices = {} shape = {} for k, v in dummy_obs.items(): if isinstance(v, dict): indices[k], shape[k] = _get_indices_and_shape(v, flat_dummy_obs) else: where = np.where(flat_dummy_obs == v.flatten()[0])[0] indices[k] = (int(where[0]), int(where[-1]) + 1) shape[k] = v.shape return indices, shape return _get_indices_and_shape(dummy_obs, flat_dummy_obs)
[docs] def get_obs_indices(env: gym.Env) -> Tuple[Dict, Dict]: """Get index ranges that correspond to keys of the observation Dict space in flattened vector. Also returns a dictionary containing the shapes of these observation components. Args: env: The environment with Dict observation space. Returns: - A dictionary with keys corresponding to the observation components and values being tuples of the form (start, end), where start and end are the indices at which the observation component starts and ends. The nested dictionary structure of the observation is preserved. - A dictionary of the same structure but with values being the shapes of the observation components.""" if isinstance(env, gym.vector.VectorEnv): single_env = env.env_fns[0]() single_obs_space = deepcopy(single_env.unwrapped.observation_space) single_env.close() del single_env else: single_obs_space = env.unwrapped.observation_space assert isinstance( single_obs_space, gym.spaces.Dict ), "Environment observation space must be of type Dict." return get_key_indices(single_obs_space)