Collectors API
- class layeredrl.collectors.Collector(hierarchy: Hierarchy, env: Env, test_env: Env | None = None, device=device(type='cpu'), writer: SummaryWriter | None = None, checkpoint_dir: Path | None = None, checkpoint_interval: int | None = None)
Bases: object
- __init__(hierarchy: Hierarchy, env: Env, test_env: Env | None = None, device=device(type='cpu'), writer: SummaryWriter | None = None, checkpoint_dir: Path | None = None, checkpoint_interval: int | None = None)
Initialize the collector.
- Parameters:
hierarchy – The hierarchy to collect data with.
env – The environment to collect data/train in.
test_env – The environment to test in.
device – The device to use.
writer – The TensorBoard writer to use for logging. If None, no logging is done.
checkpoint_dir – The directory to save checkpoints to. If None, no checkpoints are saved.
checkpoint_interval – The interval in steps between checkpoints. If None, only the final checkpoint is saved.
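How checkpoint_dir and checkpoint_interval interact can be pictured with a small schedule helper. This is a sketch under stated assumptions, not layeredrl code: `checkpoint_steps` is hypothetical, and it assumes a checkpoint is written every `checkpoint_interval` steps plus a final one at the end of the run. That is consistent with the parameter descriptions above but not confirmed by them.

```python
from typing import List, Optional


def checkpoint_steps(total_steps: int, checkpoint_interval: Optional[int]) -> List[int]:
    """Hypothetical helper: steps at which a checkpoint would be written.

    Assumption: one checkpoint every ``checkpoint_interval`` steps plus a
    final one at ``total_steps``. With ``None``, only the final checkpoint
    is written, as the ``checkpoint_interval`` description states.
    """
    if checkpoint_interval is None:
        return [total_steps]
    if checkpoint_interval <= 0:
        raise ValueError("checkpoint_interval must be positive")
    steps = list(range(checkpoint_interval, total_steps + 1, checkpoint_interval))
    # Assumed: the run always ends with a final checkpoint.
    if not steps or steps[-1] != total_steps:
        steps.append(total_steps)
    return steps


# The constructor call this mirrors (not executed here; `hierarchy` and `env`
# must first be built with layeredrl):
# collector = Collector(hierarchy, env, checkpoint_dir=Path("checkpoints"),
#                       checkpoint_interval=300)
print(checkpoint_steps(1000, 300))   # [300, 600, 900, 1000]
print(checkpoint_steps(1000, None))  # [1000]
```

Whether "steps" here means single-environment steps or vector environment steps is not specified above, so the helper treats it as an opaque step counter.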
- collect(n_steps: int, env_expects_numpy: bool = True, record_transitions: bool = False, learn: bool = False, n_steps_start: int = 0, log_interval: int = 100, test_interval: int | None = None, n_test_steps: int = 1000, verbose: bool = False, post_step_callback: Callable | None = None, video_logger: VideoLogger | None = None) → Tuple | Batch
Collect transitions from the environment with the hierarchical policy.
Each level of the hierarchy receives its own transitions, because the higher levels operate on semi-MDPs.
- Parameters:
n_steps – The number of steps to collect in each environment instance. The total number of collected steps is therefore n_steps * n_envs.
env_expects_numpy – Whether the environment expects numpy arrays as input.
record_transitions – Whether to record the environment transitions and return them in a batch. Note that the first dimension of the batch indexes the step, not the environment instance.
learn – Whether to learn after each step.
n_steps_start – The number of steps that have already been collected. This is useful for resuming an experiment.
log_interval – The interval in steps between logging.
test_interval – The interval in vector environment steps between testing the hierarchy. If None, no testing is done.
n_test_steps – The number of vector environment steps to test the hierarchy for at each test interval.
verbose – Whether to print progress.
post_step_callback – A callback function that is called after each step. The callback function should take the current step and the next observation as arguments.
video_logger – A VideoLogger object to log videos of the rollouts.
- Returns:
The statistics of the rollouts and (if record_transitions is True) the collected transitions in a Batch object.
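To illustrate the step accounting, the record_transitions layout, and the post_step_callback convention, here is a toy loop that mimics collect on a fake vector environment. Everything in it is hypothetical: `toy_collect` is not part of layeredrl, it returns a NumPy array instead of a Batch, it counts steps from 0 (ignoring n_steps_start), and it assumes the callback is called positionally as `callback(step, next_obs)`.

```python
from typing import Callable, List, Optional

import numpy as np


def toy_collect(
    n_steps: int,
    n_envs: int,
    obs_dim: int = 3,
    record_transitions: bool = False,
    post_step_callback: Optional[Callable[[int, np.ndarray], None]] = None,
    seed: int = 0,
):
    """Hypothetical stand-in for Collector.collect on a fake vector env.

    Each loop iteration is one vector environment step, i.e. one transition
    per environment instance.
    """
    rng = np.random.default_rng(seed)
    obs = rng.normal(size=(n_envs, obs_dim))
    recorded: List[np.ndarray] = []
    n_transitions = 0
    for t in range(n_steps):
        # Fake environment step: every instance returns a new observation.
        next_obs = obs + rng.normal(size=(n_envs, obs_dim))
        n_transitions += n_envs
        if record_transitions:
            recorded.append(next_obs)
        if post_step_callback is not None:
            # Assumed calling convention: (current step, next observation).
            post_step_callback(t, next_obs)
        obs = next_obs
    stats = {"n_transitions": n_transitions}
    if record_transitions:
        # Step-major layout: axis 0 indexes the step, axis 1 the env instance.
        return stats, np.stack(recorded)
    return stats


seen_steps = []
stats, batch = toy_collect(
    n_steps=5,
    n_envs=4,
    record_transitions=True,
    post_step_callback=lambda step, next_obs: seen_steps.append(step),
)
print(stats["n_transitions"])  # 20, i.e. n_steps * n_envs
print(batch.shape)             # (5, 4, 3)
print(seen_steps)              # [0, 1, 2, 3, 4]
```

Because the recorded array is step-major, `batch[t]` holds all environment instances at step t, which is the layout the record_transitions note describes.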
- save_checkpoint(t: int, n_steps: int)
Save a checkpoint of the hierarchy.
- Parameters:
t – The current step.
- test(t: int, test_hierarchy: Hierarchy, n_steps: int, env_expects_numpy: bool = True, video_logger: VideoLogger | None = None) → dict
Test the hierarchy in the environment.
Note: This resets the test environment and the test hierarchy.
- Parameters:
t – The current training step.
test_hierarchy – The hierarchy to test.
n_steps – The number of steps for which to test the hierarchy.
env_expects_numpy – Whether the environment expects numpy arrays as input.
video_logger – A VideoLogger object to log videos of the rollouts.
- Returns:
The statistics of the test run.
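The kind of statistics dict that test returns can be sketched with a toy vector environment and a policy function. The `ToyVectorEnv` class, `toy_test`, and the keys `n_episodes` and `mean_return` are all hypothetical; the real method also resets the test hierarchy, which this sketch does not model. Assuming n_steps counts vector environment steps (as n_test_steps of collect does), 7 steps on 2 instances yield 14 transitions and, with 3-step episodes, 4 completed episodes.

```python
import numpy as np


class ToyVectorEnv:
    """Hypothetical vector env: each instance gives reward 1 per step, ends its
    episode after ``episode_len`` steps, and then auto-resets."""

    def __init__(self, n_envs: int, episode_len: int):
        self.n_envs = n_envs
        self.episode_len = episode_len
        self.t = np.zeros(n_envs, dtype=int)

    def reset(self) -> np.ndarray:
        self.t[:] = 0
        return self.t.copy()

    def step(self, actions: np.ndarray):
        self.t += 1
        rewards = np.ones(self.n_envs)
        dones = self.t >= self.episode_len
        self.t[dones] = 0  # auto-reset finished instances
        return self.t.copy(), rewards, dones


def toy_test(env: ToyVectorEnv, policy, n_steps: int) -> dict:
    """Hypothetical analogue of Collector.test: reset, roll out for n_steps
    vector env steps, and summarize completed episodes in a dict."""
    obs = env.reset()  # mirrors "this resets the test environment"
    running = np.zeros(env.n_envs)
    finished = []
    for _ in range(n_steps):
        obs, rewards, dones = env.step(policy(obs))
        running += rewards
        # Record returns of episodes that just ended, then restart their sums.
        finished.extend(running[dones].tolist())
        running[dones] = 0.0
    return {
        "n_episodes": len(finished),
        "mean_return": float(np.mean(finished)) if finished else float("nan"),
    }


stats = toy_test(ToyVectorEnv(n_envs=2, episode_len=3), lambda obs: np.zeros_like(obs), n_steps=7)
print(stats)  # {'n_episodes': 4, 'mean_return': 3.0}
```

Only episodes that finish within the n_steps window are counted here; the partial episodes still running at step 7 are dropped, which is one of several reasonable choices for such a summary.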