Source code for layeredrl.nets.random_dynamics

from typing import Tuple, Union

from gymnasium.spaces import Box, Discrete
from scipy.stats import special_ortho_group
import torch
import torch.nn as nn


class RandomDynamics(nn.Module):
    """Random but fixed dynamics from random linear transformation.

    The predicted mean is a fixed linear map of the action only; the current
    state does not enter the prediction.

    Note: Assumes only state deltas are to be predicted.
    """

    def __init__(
        self,
        state_space: Box,
        action_space: Union[Box, Discrete],
        std: float = 0.3,
        device: torch.device = torch.device("cpu"),
        n_modes: int = 1,
    ):
        """Initialize the model.

        Args:
            state_space: The state space. Must be a ``Box``.
            action_space: The action space. Despite the wider annotation,
                only ``Box`` is accepted (enforced by an assertion).
            std: Constant standard deviation reported for every predicted
                state dimension.
            device: The device on which the fixed matrices are created.
            n_modes: Number of identical, equally weighted modes in the
                returned prediction.

        Note:
            The action-to-state matrix consists of the first ``action_dim``
            columns of a ``state_dim x state_dim`` rotation matrix, so it has
            at most ``state_dim`` columns; ``forward`` expects the action's
            last dimension to match that column count.
        """
        super().__init__()
        self.state_space = state_space
        self.action_space = action_space
        assert isinstance(state_space, Box), "state_space must be a Box"
        assert isinstance(action_space, Box), "action_space must be a Box"
        self.std = std
        self.device = device
        self.n_modes = n_modes

        # Generate matrix with orthonormal columns for mapping action to
        # state (delta): take leading columns of a random rotation matrix.
        action_dim = action_space.shape[0]
        state_dim = state_space.shape[0]
        basis = special_ortho_group.rvs(state_dim)[:, :action_dim]
        # Cast to float32 while constructing so that no float64 tensor is
        # ever allocated on the target device.
        self.A = nn.Parameter(
            torch.tensor(basis, dtype=torch.float32, device=device),
            requires_grad=False,
        )
        self.transform = nn.Parameter(
            torch.eye(state_dim, device=device), requires_grad=False
        )

    def forward(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        """Predict the next state (delta) given the current state and action.

        Args:
            state: The current state. Not used by the computation.
            action: The action; its last dimension is contracted with the
                action-to-state matrix, leading dimensions are kept.

        Returns:
            A tuple ``((s_mean, s_std), weights, reward)`` where, with
            ``B`` denoting the leading dimensions of ``action``:

            * ``s_mean`` has shape ``(*B, n_modes, state_dim)`` and holds the
              same mean repeated for every mode,
            * ``s_std`` has the shape of ``s_mean`` and is filled with
              ``self.std``,
            * ``weights`` has shape ``(*B, n_modes)`` and is uniform
              (``1 / n_modes``),
            * ``reward`` has shape ``B`` and is all zeros.
        """
        s_mean = torch.einsum("ij,jk,...k->...i", self.transform, self.A, action)
        # Add mode dimension (a broadcast view, every mode shares the mean).
        s_mean = s_mean[..., None, :].expand(
            *s_mean.shape[:-1], self.n_modes, s_mean.shape[-1]
        )
        # Allocate auxiliary outputs on the device of the computed mean rather
        # than on the device cached at construction time: the module may have
        # been moved with ``.to(...)`` since then.
        out_device = s_mean.device
        weights = torch.ones(s_mean.shape[:-1], device=out_device) / self.n_modes
        s_std = self.std * torch.ones_like(s_mean)
        reward = torch.zeros(s_std.shape[:-2], device=out_device)
        return (s_mean, s_std), weights, reward