Source code for layeredrl.nets.encoder_net

from typing import Tuple

import torch
from torch import nn

from ..utils.normalization import RunningBatchNorm
from .encoder import Encoder


[docs] class EncoderNet(Encoder): """A linear encoder mapping env obs to a latent state and a context variable."""
[docs] def __init__( self, mapped_env_obs_shape: Tuple[int, ...], latent_state_dim: int, context_dim: int, standardize: bool = False, bn_momentum: float = 0.1, freeze_after: int = None, device: torch.device = torch.device("cpu"), ): """Initialize the network. Args: mapped_env_obs_dim: The dimension of the mapped environment observation space. latent_state_dim: The dimension of the latent state space. standardize: Whether to standardize the input. bn_momentum: The momentum for the batch norm. freeze_after: The number of batches after which to freeze the normalization. device: The device to use. """ super().__init__() self._latent_state_dim = latent_state_dim self._context_dim = context_dim self.standardize = standardize self.device = device assert ( len(mapped_env_obs_shape) == 1 ), "Only 1D mapped observations are supported." if standardize: self.input_bn = RunningBatchNorm( num_features=mapped_env_obs_shape[0], momentum=bn_momentum, freeze_after=freeze_after, device=self.device, dtype=torch.float32, ) else: self.input_bn = None def get_net(output_dim): return nn.Linear( mapped_env_obs_shape[0], output_dim, bias=False, device=device, ) self.state_net = get_net(latent_state_dim) self.context_net = get_net(context_dim) self.mask = torch.nn.Parameter( torch.zeros(mapped_env_obs_shape[0], device=device, dtype=torch.bool), requires_grad=False, )
[docs] def forward( self, mapped_env_obs: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: if self.standardize: mapped_env_obs = self.input_bn(mapped_env_obs) context = self.context_net(~(self.mask) * mapped_env_obs) state = self.state_net(self.mask * mapped_env_obs) return state, context
@property def latent_state_dim(self) -> int: """Dimension of the latent state space.""" return self._latent_state_dim @property def context_dim(self) -> int: """Dimension of the context space.""" return self._context_dim