Source code for layeredrl.predictors

from functools import partial

import gymnasium as gym

from .predictor import Predictor
from .predictor_factory import PredictorFactory
from .reward_predictor import RewardPredictor
from .static_predictor import StaticPredictor
from ..nets import ProbFCDynamics, RewardNet, FixedEncoderNet, ValueNet
from ..models import ProbabilisticEnsemble
from ..utils.misc import get_obs_indices


def get_default_predictor_factory(
    env: gym.Env, sb_start_duration: float
) -> PredictorFactory:
    """Get a default predictor factory.

    The predictor factory creates a RewardPredictor object. The method assumes
    that the environment is goal-based and interprets the desired goal as the
    context and the achieved goal as the state for the planner level.

    Args:
        env: The environment for which to create the predictor factory. Its
            observation must provide the keys ``"achieved_goal"`` and
            ``"desired_goal"`` (as reported by ``get_obs_indices``).
        sb_start_duration: Duration (in env steps) for symmetry breaking start.

    Returns:
        The predictor factory.

    Raises:
        KeyError: If the observation of ``env`` has no ``"achieved_goal"`` or
            no ``"desired_goal"`` entry, i.e. the environment is not goal-based.
    """
    # Get index ranges in flattened observation for keys of observation dict
    obs_indices, _ = get_obs_indices(env)

    # Fail early with an explanatory message instead of a bare KeyError further
    # down; KeyError is kept so existing callers' exception handling still works.
    required_keys = ("achieved_goal", "desired_goal")
    missing_keys = [key for key in required_keys if key not in obs_indices]
    if missing_keys:
        raise KeyError(
            "get_default_predictor_factory requires a goal-based environment; "
            f"observation is missing key(s) {missing_keys}. "
            f"Available keys: {list(obs_indices)}"
        )

    # Networks for the predictor (partial because correct dimensions are
    # automatically set during assembly of the hierarchy)
    partial_net = partial(ProbFCDynamics)
    partial_model = partial(
        ProbabilisticEnsemble,
        partial_net=partial_net,
        symmetry_breaking_start=True,
        sb_start_duration=sb_start_duration,
        sb_start_factor=1.0,
        n_models=1,
        n_modes=4,
        n_particles_per_model=1,
        normalize_targets=True,
        target_bn_momentum=0.001,
    )
    partial_val_func = partial(ValueNet)
    partial_rew_func = partial(RewardNet)

    # Pick out desired goal as context and achieved goal as latent state
    partial_encoder = partial(
        FixedEncoderNet,
        latent_state_dims=range(*obs_indices["achieved_goal"]),
        context_dims=range(*obs_indices["desired_goal"]),
    )

    # Predictor factory. The encoder above only selects fixed observation
    # dimensions, so it is presumably not meant to be trained.
    partial_predictor = partial(
        RewardPredictor,
        learn_encoder=False,
    )
    predictor_factory = PredictorFactory(
        partial_model=partial_model,
        partial_val_func=partial_val_func,
        partial_rew_func=partial_rew_func,
        partial_encoder=partial_encoder,
        partial_predictor=partial_predictor,
    )
    return predictor_factory
# Public API of the ``layeredrl.predictors`` package: the predictor classes
# imported from the submodules above, plus the default factory helper.
__all__ = [
    "Predictor",
    "PredictorFactory",
    "RewardPredictor",
    "StaticPredictor",
    "get_default_predictor_factory",
]