Source code for layeredrl.utils.distributions

import math

import torch


def cdf_normal(value, loc, scale):
    """Evaluate the cumulative distribution function of a normal distribution.

    Uses the identity ``Phi(z) = 0.5 * (1 + erf(z / sqrt(2)))`` where
    ``z = (value - loc) / scale``. The arguments are combined with ordinary
    broadcasting.

    Args:
        value: Point(s) at which the CDF is evaluated.
        loc: Mean of the normal distribution.
        scale: Standard deviation of the normal distribution. Must be a
            tensor, since ``scale.reciprocal()`` is called on it.

    Returns:
        Tensor holding the CDF values.
    """
    # Standardise by multiplying with 1 / scale.
    standardized = (value - loc) * scale.reciprocal()
    return 0.5 * (1 + torch.erf(standardized / math.sqrt(2)))
def sample_truncated_normal(loc, scale, lower_bound, upper_bound, n_samples, device):
    """Sample from a truncated normal distribution.

    Note that everything is assumed to have a batch dimension. Sampling is
    done via the inverse CDF method rather than rejection sampling: a uniform
    variate is drawn between the CDF values of the two bounds and mapped back
    through the inverse normal CDF.

    When both bounds lie far out in the same tail, the CDF values saturate to
    exactly 0 or 1 in floating point. To keep the result finite and inside the
    truncation interval in that regime, the uniform variate is kept strictly
    inside (0, 1) and the final samples are clamped to the bounds.

    Args:
        loc: Mean, indexed as ``(batch, dim)``.
        scale: Standard deviation, indexed as ``(batch, dim)``.
        lower_bound: Lower truncation bound, broadcastable to ``loc``.
        upper_bound: Upper truncation bound, broadcastable to ``loc``.
        n_samples: Number of samples to draw per batch element.
        device: Device on which the uniform random numbers are created.

    Returns:
        Tensor of shape ``(batch, n_samples, dim)`` with samples inside
        ``[lower_bound, upper_bound]``.
    """
    batch_size, dim = loc.shape[0], loc.shape[1]
    sample_shape = (batch_size, n_samples, dim)

    exp_loc = loc[:, None, :].expand(sample_shape)
    exp_scale = scale[:, None, :].expand(sample_shape)

    # Probability mass below each bound; u is then uniform on [a, b).
    a = cdf_normal(lower_bound, loc, scale)
    b = cdf_normal(upper_bound, loc, scale)
    diff = (b - a)[:, None, :].expand(sample_shape)
    a = a[:, None, :].expand(sample_shape)
    u = a + diff * torch.rand(sample_shape, device=device)

    # erfinv(+-1) is +-inf, so keep u strictly inside (0, 1).
    eps = torch.finfo(u.dtype).eps
    u = u.clamp(min=eps, max=1.0 - eps)

    samples = exp_loc + exp_scale * torch.erfinv(2 * u - 1) * math.sqrt(2)

    # Round-off (or saturated CDFs) can push samples marginally outside the
    # interval; project them back so the truncation contract always holds.
    lower = torch.as_tensor(
        lower_bound, dtype=samples.dtype, device=samples.device
    ).expand_as(loc)[:, None, :]
    upper = torch.as_tensor(
        upper_bound, dtype=samples.dtype, device=samples.device
    ).expand_as(loc)[:, None, :]
    return torch.minimum(torch.maximum(samples, lower), upper)
def get_normal_prob(
    x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
) -> torch.Tensor:
    """Get the probability (density) of x under a Normal distribution with
    the given mean and standard deviation.

    The last dimension is treated as the event dimension of a Gaussian with
    diagonal covariance, so the density is the product of the per-dimension
    densities and the last dimension is reduced away.

    Args:
        x: The value.
        mean: The mean of the Gaussian.
        std: The standard deviation of the Normal distribution.

    Returns:
        The probability (density) of x under the Normal distribution.
    """
    z = (x - mean) / std
    n_dims = z.shape[-1] if z.dim() > 0 else 1

    # Work in log space so that the product of many standard deviations
    # cannot under- or overflow. std is expanded first so that a broadcast
    # std is counted once per event dimension. abs() keeps the result
    # independent of the sign of std.
    log_std_sum = torch.log(std.expand_as(z).abs()).sum(dim=-1)

    # Normalisation of a D-dimensional diagonal Gaussian:
    # (2 * pi) ** (-D / 2) / prod(std).
    log_norm = -0.5 * n_dims * math.log(2 * math.pi) - log_std_sum
    exponent = -0.5 * (z**2).sum(dim=-1)
    return torch.exp(log_norm + exponent)
def get_normal_log_prob(
    x: torch.Tensor,
    mean: torch.Tensor,
    std: torch.Tensor,
    beta: float = 0.0,
) -> torch.Tensor:
    """Get the log probability of x under a 1D Normal distribution with the
    given mean and standard deviation.

    The log probability is computed element-wise; no sum over the last
    dimension is performed.

    If ``beta > 0``, the log probability is multiplied by ``std ** (2 * beta)``
    with the gradient through this weight stopped (beta-NLL, see
    https://arxiv.org/abs/2203.09168). With ``beta == 0`` the plain log
    probability is returned.

    Args:
        x: The value.
        mean: The mean of the Gaussian.
        std: The standard deviation of the Normal distribution.
        beta: The beta parameter for beta-NLL.

    Returns:
        The (optionally beta-weighted) log probability of x under the Normal
        distribution, with the same shape as the broadcast inputs.
    """
    first_term = -0.5 * torch.log(2 * torch.pi * std**2)
    second_term = -0.5 * (((x - mean) / std) ** 2)
    log_prob = first_term + second_term
    if beta > 0.0:
        # The weight is detached so that it only rescales the gradient and
        # does not itself contribute a gradient with respect to std.
        log_prob = torch.pow(std, 2.0 * beta).detach() * log_prob
    return log_prob