"""Critic network used by layeredrl (module ``layeredrl.nets.critic``)."""

from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from tianshou.data import Batch
from tianshou.utils.net.continuous import Critic as TianshouCritic


class Critic(TianshouCritic):
    """Critic network for Tianshou level.

    Thin subclass of Tianshou's continuous ``Critic`` that differs in two ways:

    * the weights of the final layer of ``self.last`` are zeroed right after
      construction, and
    * ``forward`` does not flatten ``obs``/``act`` before calling the
      preprocessing network; instead, when an action is given, it is inserted
      into (a copy of) the observation mapping under the ``"action"`` key.

    All constructor arguments are forwarded unchanged to Tianshou's ``Critic``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set weights of last layer to zero. Assuming ``model[-1]`` is a linear
        # layer (TODO confirm), its output right after construction is then
        # just its bias, independent of the input. ``init.zeros_`` runs under
        # ``no_grad`` and avoids the discouraged ``.data`` access.
        torch.nn.init.zeros_(self.last.model[-1].weight)

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor, Batch, Dict[str, Any]],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Same as Tianshou's critic, but without flattening ``obs`` and ``act``
        before passing them to ``preprocess_net``.

        :param obs: observation. When ``act`` is given it must be a mapping
            (e.g. a ``Batch`` or ``dict``) that supports ``items()``;
            otherwise it is passed to ``self.preprocess`` unchanged.
        :param act: optional action. If given, it is stored under the
            ``"action"`` key of a shallow copy of ``obs``, and the result is
            wrapped in a ``Batch``.
        :param info: unused; kept for interface compatibility with Tianshou.
            Defaults to ``None`` instead of a shared mutable ``{}``.
        :return: ``self.last`` applied to the first output of
            ``self.preprocess``.
        """
        if act is not None:
            # Shallow copy so the caller's observation mapping is not mutated.
            obs_dict = dict(obs.items())
            obs_dict["action"] = act
            obs = Batch(obs_dict)
        # The preprocessing net returns (features, hidden_state); only the
        # features are needed here.
        logits, _ = self.preprocess(obs)
        return self.last(logits)