Source code for layeredrl.utils.normalization

import torch
import torch.nn as nn
from torch.nn import Module


class RunningBatchNorm(Module):
    """Batch normalization using running estimates of mean and variance.

    While the module is in training mode (and not yet frozen), every call
    to :meth:`forward` updates the stored estimates from the statistics of
    the batch, reduced over dimension 0. The very first batch initializes
    the estimates; later batches are blended in with an exponential moving
    average. The input is always normalized with the *stored* estimates
    (after any update), in training as well as in evaluation mode.
    """

    def __init__(
        self,
        num_features: int,
        eps=1.0e-5,
        momentum=0.1,
        device=None,
        dtype=None,
        track_mean=True,
        freeze_after=None,
    ):
        """Set up the running estimates.

        Args:
            num_features: Length of the stored mean and variance vectors.
            eps: Constant added to the variance before taking the square
                root.
            momentum: Weight given to the statistics of the current batch
                in the moving-average update.
            device: Device on which the estimates and bookkeeping buffers
                are created.
            dtype: Data type of the stored mean and variance.
            track_mean: If False, the mean estimate is never updated and
                keeps its initial value of zero.
            freeze_after: Number of statistic updates after which the
                estimates are no longer changed; None means never freeze.
        """
        super().__init__()
        self._num_features = num_features
        self._eps = eps
        self._momentum = momentum
        self._track_mean = track_mean
        self._freeze_after = freeze_after
        # Stored as parameters without gradients so that they are part of
        # the state dict but are never touched by an optimizer.
        self.mean = torch.nn.Parameter(
            torch.zeros(num_features, device=device, dtype=dtype),
            requires_grad=False,
        )
        self.var = torch.nn.Parameter(
            torch.ones(num_features, device=device, dtype=dtype),
            requires_grad=False,
        )
        # Bookkeeping buffers live on the requested device as well, so the
        # whole module state starts out on one device.
        self.register_buffer(
            "_initialized", torch.tensor(0, dtype=torch.uint8, device=device)
        )
        self.register_buffer(
            "_counter", torch.tensor(0, dtype=torch.int64, device=device)
        )

    def _updates_enabled(self) -> bool:
        """Return True if the running estimates may still be changed."""
        if not self.training:
            return False
        if self._freeze_after is None:
            return True
        return bool(self._counter < self._freeze_after)

    @torch.no_grad()
    def _update_statistics(self, x: torch.Tensor) -> None:
        """Fold the statistics of batch ``x`` into the running estimates."""
        batch_var = x.var(dim=0)
        batch_mean = x.mean(dim=0) if self._track_mean else None
        if not self._initialized:
            # Use the statistics of the first batch for initialization.
            if batch_mean is not None:
                self.mean.data = batch_mean
            self.var.data = batch_var
            # In-place update keeps the buffer on its current device.
            self._initialized.fill_(1)
        else:
            keep = 1 - self._momentum
            if batch_mean is not None:
                self.mean.data = keep * self.mean + self._momentum * batch_mean
            self.var.data = keep * self.var + self._momentum * batch_var
        self._counter.add_(1)

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` with the running mean and variance.

        Args:
            x: Input batch; statistics are computed over dimension 0.

        Returns:
            ``(x - mean) / sqrt(var + eps)`` using the stored estimates.
        """
        # The sample variance of fewer than two samples is NaN and would
        # permanently corrupt the running variance, so such batches are
        # normalized with the current estimates but do not update them.
        has_enough_samples = x.dim() > 0 and x.shape[0] > 1
        if has_enough_samples and self._updates_enabled():
            self._update_statistics(x)
        return (x - self.mean) / (self.var + self._eps).sqrt()
class Standardizer(Module):
    """Standardize a random vector (trained externally).

    The module holds a square matrix parameter, initialized to the
    identity, and multiplies its input by a symmetric matrix derived from
    the upper triangle of that parameter.
    """

    def __init__(self, latent_state_dim: int, device: torch.device):
        """Create the matrix parameter as an identity matrix.

        Args:
            latent_state_dim: Number of rows and columns of the matrix.
            device: Device on which the parameter is allocated.
        """
        super().__init__()
        identity = torch.eye(latent_state_dim, device=device)
        self.mat = nn.Parameter(identity)

    def get_symm_mat(self):
        """Return the symmetric matrix built from the upper triangle.

        Only the upper triangle (including the diagonal) of ``self.mat``
        contributes: the result is the average of that triangle and its
        transpose.
        """
        upper = self.mat.triu()
        return 0.5 * (upper + upper.t())

    def forward(self, x):
        """Right-multiply ``x`` by the symmetric matrix."""
        symm = self.get_symm_mat()
        return torch.matmul(x, symm)