Source code for layeredrl.utils.loggers

from typing import Optional
from pathlib import Path

from gymnasium import Env
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip


[docs] class VideoLogger: """Class for logging video clips of rollouts."""
[docs] def __init__( self, log_dir: str, fps: int = 10, max_episodes: int = 10, camera_name: Optional[str] = None, ): """Initialize the logger. Args: log_dir: The directory to save the videos. fps: The frames per second of the video. """ self.log_dir = log_dir self.fps = fps self.frames = [] self.max_episodes = max_episodes self.camera_name = camera_name self._episode_counter = 0
[docs] def log_frame(self, env: Env, terminated: bool, truncated: bool): """Log a frame from the environment. Requires the environments to be in render mode 'rbg_array'. Args: env: The environment. terminated: Whether the episode is terminated. truncated: Whether the episode is truncated. """ if self._episode_counter < self.max_episodes: unwrapped = env.envs[0].unwrapped if self.camera_name is not None: env = unwrapped camera_name = self.camera_name # requires environment to have a bird view camera elif hasattr(unwrapped, "ant_env"): env = unwrapped.ant_env camera_name = "bird" else: env = unwrapped camera_name = None # render to make sure all markers are updated etc. render = env.render() # render again with the right camera render = env.mujoco_renderer.render( render_mode="rgb_array", camera_name=camera_name ) assert render is not None, "Environment must be in render mode 'rgb_array'." self.frames.append(render) if terminated[0] or truncated[0]: self._episode_counter += 1 print("Recorded an episode on video.") if self._episode_counter == self.max_episodes: self.save_video() self.frames = [] self._episode_counter += 1
[docs] def save_video(self): """Save video to disk.""" path = self.log_dir / Path("rollouts.mp4") print(f"Saving video to {path}") clip = ImageSequenceClip(self.frames, fps=self.fps) clip.write_videofile(str(path))