Source code for layeredrl.optimizers.cem

from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
from torch import Tensor

from .optimizer import Optimizer
from ..utils.distributions import sample_truncated_normal


class CEM(Optimizer):
    """Batched Cross Entropy Method (CEM) optimizer.

    For every batch entry the optimizer keeps an independent per-dimension
    Gaussian search distribution (``mu`` and ``sigma`` of shape
    ``(batch_size, dim)``). Each iteration draws ``n_samples`` candidates,
    evaluates their cost, keeps the ``n_elites`` cheapest ones and refits the
    distribution to them (optionally smoothed with momentum).
    """

    def __init__(
        self,
        n_iterations: int,
        n_samples: int,
        initial_sigma: torch.Tensor,
        elite_ratio: float = 0.2,
        lower_bound: Optional[torch.Tensor] = None,
        upper_bound: Optional[torch.Tensor] = None,
        clip: bool = False,
        momentum: float = 0.0,
        return_mode: str = "mean",
        record_samples: bool = False,
        *args,
        **kwargs,
    ):
        """Initialize the CEM optimizer.

        Args:
            n_iterations: The number of iterations to run CEM for.
            n_samples: The number of samples to draw in each iteration of CEM.
            initial_sigma: The initial standard deviation of the samples.
                Shape: (batch_size, x_dim)
            elite_ratio: The ratio of samples to keep per iteration.
            lower_bound: The lower bound of the samples.
            upper_bound: The upper bound of the samples. Either both
                lower_bound and upper_bound must be None or neither. If only
                one of them is given, no bound is applied at all.
            clip: Whether to clip the samples to the bounds. If False and
                bounds are given, samples are drawn from a truncated normal
                distribution instead.
            momentum: Momentum factor for updating the mean and the standard
                deviation (0.0 means the new elite statistics fully replace
                the old ones).
            return_mode: Whether to return the mean of the elite samples
                ("mean"), the best sample ("best"), or a random sample from
                the last Gaussian distribution ("random").
            record_samples: Whether to record the samples drawn during
                optimization.
            *args: Forwarded to the base ``Optimizer``.
            **kwargs: Forwarded to the base ``Optimizer``.

        Raises:
            AssertionError: If ``return_mode`` is unknown or ``elite_ratio``
                does not yield between 2 and ``n_samples`` elites.
        """
        # The base class is relied upon to provide self.device and
        # self.n_samples, both of which are read below.
        super().__init__(*args, n_samples=n_samples, **kwargs)
        self.n_iterations = n_iterations
        # Keep sigma on the optimizer's device, exactly like the bounds;
        # otherwise sampling mixes devices and fails.
        self.initial_sigma = initial_sigma.to(device=self.device)
        self.elite_ratio = elite_ratio
        self.lower_bound = (
            lower_bound.to(device=self.device) if lower_bound is not None else None
        )
        self.upper_bound = (
            upper_bound.to(device=self.device) if upper_bound is not None else None
        )
        self.clip = clip
        self.momentum = momentum
        assert return_mode in ["mean", "best", "random"], (
            f"Unknown return_mode {return_mode!r}; "
            "expected 'mean', 'best' or 'random'."
        )
        self.return_mode = return_mode
        self.record_samples = record_samples
        self.samples = []
        self.n_elites = int(self.n_samples * self.elite_ratio)
        # The standard deviation of the elites is computed in every
        # iteration, which needs at least two elites.
        assert self.n_elites > 1, (
            "n_samples * elite_ratio must yield at least 2 elite samples, "
            f"got {self.n_elites}."
        )
        # topk cannot select more elites than there are samples.
        assert self.n_elites <= self.n_samples, (
            f"n_elites ({self.n_elites}) must not exceed "
            f"n_samples ({self.n_samples})."
        )
        self.mu = None
        self.sigma = None
        # Bounds are only honoured when both of them are provided.
        self.truncated = lower_bound is not None and upper_bound is not None

    def reset(self, initial_x: torch.Tensor) -> None:
        """Reset the optimizer to the given initial guess.

        This has to be called at least once before ``optimize``. If it is not
        called again between two ``optimize`` calls, the second call starts
        from the mean left behind by the first one (the standard deviation
        always restarts from ``initial_sigma``).

        Args:
            initial_x: The initial guess for the optimal x.
                Shape: (batch_size, dim)
        """
        self.mu = initial_x.to(self.device)
        self.batch_size = initial_x.shape[0]
        self.dim = initial_x.shape[1]
        # Recorded samples belong to the previous optimization problem.
        self.samples = []

    def _sample(self) -> torch.Tensor:
        """Draw ``n_samples`` candidates per batch entry from the current distribution.

        Returns:
            In the Gaussian branch a tensor of shape
            (batch_size, n_samples, dim); in the truncated-normal branch
            whatever ``sample_truncated_normal`` returns (expected to have the
            same shape).
        """
        if self.truncated and not self.clip:
            # Respect the bounds by sampling from a truncated normal.
            return sample_truncated_normal(
                self.mu,
                self.sigma,
                self.lower_bound,
                self.upper_bound,
                self.n_samples,
                self.device,
            )

        shape = (self.batch_size, self.n_samples, self.dim)
        exp_mu = self.mu[:, None, :].expand(shape)
        exp_sigma = self.sigma[:, None, :].expand(shape)
        samples = exp_mu + exp_sigma * torch.randn(shape, device=self.device)
        if self.clip and self.truncated:
            # Respect the bounds by projecting the samples onto them.
            samples = torch.clamp(samples, self.lower_bound, self.upper_bound)
        return samples

    def optimize(
        self, cost: Callable[[Tensor], Tensor], verbose: bool = False
    ) -> Tuple[torch.Tensor, Dict]:
        """Optimize the given cost function using the Cross Entropy Method
        and return the optimal x.

        Note that everything is assumed to have a batch dimension. That
        includes x and the cost.

        Args:
            cost: A function that takes in a tensor x and returns a tensor
                with a scalar cost for each particle in each environment
                instance. Output shape: (batch_size, n_samples)
            verbose: Whether to print the mean and standard deviation of the
                elite costs after every iteration.

        Returns:
            The optimal x. Shape: (batch_size, dim)
            An info dict containing the keys "mean" and "std" for the mean
            and standard deviation of the final distribution.

        Raises:
            AssertionError: If ``reset`` has never been called.
        """
        assert (
            self.mu is not None
        ), "Please call reset before calling optimize for the first time."

        best_cost = torch.full(
            (self.batch_size,),
            float("inf"),
            device=self.device,
            dtype=torch.float32,
        )
        best_x = torch.zeros(
            (self.batch_size, self.dim), device=self.device, dtype=torch.float32
        )

        # Every optimization restarts from the initial spread. The .to() is a
        # no-op unless initial_sigma was replaced from outside with a tensor
        # living on another device.
        self.sigma = self.initial_sigma.to(device=self.device)
        keep = 1.0 - self.momentum

        for i in range(self.n_iterations):
            x_samples = self._sample()
            if self.record_samples:
                self.samples.append(x_samples)
            costs = cost(x_samples)

            # Lowest-cost samples first, so index 0 is the best elite.
            elite_costs, elite_indices = torch.topk(
                costs, self.n_elites, dim=1, largest=False, sorted=True
            )
            elite_samples = torch.gather(
                x_samples, 1, elite_indices[:, :, None].expand((-1, -1, self.dim))
            )

            # Refit the distribution to the elites, smoothed by momentum.
            self.mu = self.momentum * self.mu + keep * elite_samples.mean(dim=1)
            self.sigma = self.momentum * self.sigma + keep * elite_samples.std(dim=1)

            # Track the best sample seen over all iterations per batch entry.
            where_better = elite_costs[:, 0] < best_cost
            best_cost = torch.where(where_better, elite_costs[:, 0], best_cost)
            best_x = torch.where(
                where_better[:, None].expand((self.batch_size, self.dim)),
                elite_samples[:, 0, :],
                best_x,
            )

            if verbose:
                print(
                    f"CEM Iteration {i}: cost={elite_costs.mean():.3f}+-{elite_costs.std():.3f};"
                )

        info = {"mean": self.mu, "std": self.sigma}
        if self.return_mode == "mean":
            return self.mu, info
        elif self.return_mode == "best":
            return best_x, info
        else:
            # "random": one draw per batch entry from the final distribution.
            return self._sample()[:, 0, :], info