Source code for layeredrl.models.model

from abc import ABC, abstractmethod
from typing import Iterator, List, Optional, Tuple

import torch
from tianshou.data import Batch


class Model(ABC, torch.nn.Module):
    """Abstract base class for dynamics models.

    Subclasses must implement :meth:`get_parameters`, :meth:`loss`,
    :meth:`rollout`, :meth:`predict`, :meth:`sample` and
    :meth:`get_log_prob`. :meth:`learn` defaults to a no-op, and
    :meth:`get_prob` is derived from :meth:`get_log_prob`.
    """

    def __init__(
        self,
        n_models: int = 1,
        n_particles_per_model: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """Initialize the model.

        Args:
            n_models: The number of models in the ensemble.
            n_particles_per_model: The number of particles per ensemble
                member (this is the third axis of the ``state`` output of
                :meth:`rollout`).
            device: The device to use. It is only stored as ``self.device``
                here; the ``n_total_env_steps`` buffer below is created
                without an explicit device.
        """
        super().__init__()
        self.n_models = n_models
        self.n_particles_per_model = n_particles_per_model
        self.device = device
        # Registered as a buffer rather than a plain attribute so the counter
        # is part of ``state_dict`` and follows ``Module.to``.
        self.register_buffer(
            "n_total_env_steps", torch.tensor(0, dtype=torch.long)
        )

    @abstractmethod
    def get_parameters(self) -> Iterator[torch.nn.Parameter]:
        """Get the parameters of the model.

        Returns:
            An iterator over the parameters of the model.
        """

    def set_n_total_env_steps(self, n_total_env_steps: int) -> None:
        """Set the total number of environment steps.

        The value is written in place into the existing buffer, so the buffer
        tensor object (and its device) is kept.

        Args:
            n_total_env_steps: The total number of environment steps.
        """
        self.n_total_env_steps[...] = n_total_env_steps

    @abstractmethod
    def loss(self, batch: Batch) -> torch.Tensor:
        """Compute the loss for the given batch.

        Args:
            batch: The batch. The first dimension corresponds to the batch
                dimension (e.g. environments). For example,
                ``batch.state.shape = (batch_size, per_env_size, state_dim)``.

        Returns:
            The loss.
        """

    def learn(self, batch_lst: List[Batch]) -> None:
        """Learn from the given batches.

        The base implementation does nothing and returns ``None``;
        subclasses may override it.

        Args:
            batch_lst: A list of training batches. The first dimension of
                each batch corresponds to the transitions.
        """
        # NOTE(review): an earlier docstring said this returns "the loss
        # after the updates", but the signature is ``-> None``. Confirm the
        # intended contract for overriding subclasses.

    @abstractmethod
    def rollout(
        self,
        initial_state: torch.Tensor,
        context: torch.Tensor,
        actions: torch.Tensor,
        deterministic: bool = False,
    ) -> Batch:
        """Roll out the given actions from the given initial state (open loop).

        Note that everything is assumed to have a 'batch' dimension, useful
        for parallelizing, e.g. for vectorized environments.

        Args:
            initial_state: The initial state. Shape: (batch_size, state_dim)
            context: The context, i.e., information that is constant over the
                whole rollout.
            actions: The actions. Shape: (batch_size, horizon, action_dim)
            deterministic: Whether to use the mean of the predicted
                distribution or to sample from it.

        Returns:
            Batch containing the resulting states, termination probabilities
            and aleatoric and epistemic uncertainties.
            state shape:
                (batch_size, n_models, n_particles_per_model, horizon + 1,
                state_dim)
            state_mean shape: (batch_size, n_models, horizon + 1, state_dim)
            term_prob shape:
                (batch_size, n_models, n_particles_per_model, horizon)
        """

    @abstractmethod
    def predict(
        self,
        state: torch.Tensor,
        context: torch.Tensor,
        action: torch.Tensor,
        std: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """Predict the next state given the current state and action.

        Note that everything is assumed to have a 'batch' dimension, useful
        for parallelizing, e.g. for vectorized environments. This uses the
        mean of the predicted distribution and does not sample.

        Args:
            state: The current state.
            context: The context, i.e., information that is constant over
                timesteps.
            action: The action.
            std: Overwrites the standard deviation that the model predicts if
                provided.

        Returns:
            Mean, weights and standard deviations of the modes of the mixture
            of Gaussians that make up the ensemble. Also the averaged
            termination probability.
            Shape for mean (predicted next state) and std:
                (batch_size, n_models, n_modes, state_dim)
            Shape for weights: (batch_size, n_models, n_modes)
            Shape for term_prob: (batch_size)
        """

    @abstractmethod
    def sample(
        self, state: torch.Tensor, context: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """Sample the next state given the current state and action.

        Note that everything is assumed to have a 'batch' dimension, useful
        for parallelizing, e.g. for vectorized environments.

        Args:
            state: The current state.
            context: The context, i.e., information that is constant over
                timesteps.
            action: The action.

        Returns:
            The sampled next state.
        """

    def get_prob(
        self,
        state: torch.Tensor,
        context: torch.Tensor,
        action: torch.Tensor,
        next_state: torch.Tensor,
    ) -> Tuple:
        """Get the probability (density) of the next state.

        Also returns the termination probability and an info dict. This calls
        :meth:`get_log_prob` and exponentiates its first output. Note that
        everything is assumed to have a 'batch' dimension, useful for
        parallelization.

        Args:
            state: The current state.
            context: The context, i.e., information that is constant over
                timesteps.
            action: The action.
            next_state: The next state.

        Returns:
            A tuple containing:
            - The probability (density) of the next state given the current
              state and action under the model.
            - The termination probability given the current state and action
              under the model.
            - An info dict with additional information.
        """
        log_prob, term_prob, info = self.get_log_prob(
            state, context, action, next_state
        )
        # Clamp from below before exponentiating. In float64, exp(-700) is
        # about 1e-304, so densities stay strictly positive (this presumably
        # guards later divisions against 0/0 = NaN). In float32, exp(-700)
        # still underflows to 0, and NaN inputs pass through ``clamp``
        # unchanged.
        log_prob = log_prob.clamp(min=-700.0)
        return torch.exp(log_prob), term_prob, info

    @abstractmethod
    def get_log_prob(
        self,
        state: torch.Tensor,
        context: torch.Tensor,
        action: torch.Tensor,
        next_state: torch.Tensor,
    ) -> Tuple:
        """Get the log probability (density) of the next state.

        Also returns the termination probability and an info dict. Note that
        everything is assumed to have a 'batch' dimension, useful for
        parallelization.

        Args:
            state: The current state.
            context: The context, i.e., information that is constant over
                timesteps.
            action: The action.
            next_state: The next state.

        Returns:
            A tuple containing:
            - The log probability (density) of the next state given the
              current state and action under the model.
            - The termination probability given the current state and action
              under the model.
            - An info dict with additional information.
        """