Source code for layeredrl.nets.concat_net

from typing import Any, Dict, Tuple

import numpy as np
import torch

from tianshou.utils.net.common import Net


[docs] class ConcatNet(Net): """A network that concatenates mapped_env_obs, level_input, action, n_remaining_states, and level_state along the last dimension. For use with Tianshou level."""
[docs] def __init__( self, mapped_env_obs_shape: int, level_input_dim: int, level_state_dims: Dict[str, int], scale_remaining_steps: float = 1.0, *args, **kwargs, ): """Initialize the network.""" self.mapped_env_obs_shape = mapped_env_obs_shape self.mapped_env_obs_rank = len(mapped_env_obs_shape) self.scale_remaining_steps = scale_remaining_steps state_shape = ( np.prod(mapped_env_obs_shape) + level_input_dim + sum(level_state_dims.values()), ) super().__init__(state_shape=state_shape, *args, **kwargs)
[docs] def forward( self, obs: Dict, state: Any = None, **kwargs: Any ) -> Tuple[torch.Tensor, Any]: """Concatenate the inputs along the last dimension.""" element_lst = [ obs["mapped_env_obs"].flatten(start_dim=-self.mapped_env_obs_rank) ] if obs["level_input"] is not None: element_lst.append(obs["level_input"]) if "action" in obs: element_lst.append(obs["action"]) if "n_remaining_steps" in obs["level_state"]: # rescale the remaining steps, usually to [0, 1] element_lst.append( obs["level_state"]["n_remaining_steps"] * self.scale_remaining_steps ) element_lst.extend( v for k, v in obs["level_state"].items() if k != "n_remaining_steps" ) total_obs = torch.cat(element_lst, dim=-1) return super().forward(total_obs, state, **kwargs)