Source code for layeredrl.utils.buffers

from typing import List, Tuple, Union

import numpy as np
from tianshou.data import Batch, VectorReplayBuffer
import torch


class ToDeviceReplayBuffer(VectorReplayBuffer):
    """Replay buffer that moves batch to target device after sampling.

    Every batch returned by ``sample`` or by indexing (``buffer[index]``) has
    its data converted to ``torch`` tensors on ``self.target_device``.
    """

    def __init__(self, target_device=torch.device("cpu"), *args, **kwargs):
        """Initialize the wrapper.

        Args:
            target_device: The target device to move the batch to. This is the
                first positional parameter, so pass it by keyword
                (``target_device=...``) whenever the ``VectorReplayBuffer``
                arguments are given positionally. Otherwise the first of those
                arguments would be bound to ``target_device``.
            *args: Positional arguments forwarded to ``VectorReplayBuffer``.
            **kwargs: Keyword arguments forwarded to ``VectorReplayBuffer``.
        """
        super().__init__(*args, **kwargs)
        self.target_device = target_device

    def _to_target_device(self, batch: Batch) -> Batch:
        """Return ``batch`` with its data converted to tensors on the target device."""
        # Tianshou versions differ here. Older ``Batch.to_torch`` converts in
        # place and returns None; newer releases return a converted copy and
        # leave the original untouched. Using the return value when present
        # keeps this correct for both.
        converted = batch.to_torch(device=self.target_device)
        return batch if converted is None else converted

    def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
        """Sample from replay buffer and move batch to target device.

        If batch_size is 0, return all the data in the buffer.

        Args:
            batch_size: The batch size.

        Returns:
            Sample data and its corresponding indices inside the buffer.
        """
        batch, indices = super().sample(batch_size)
        # NOTE(review): the parent ``sample`` may already build the batch via
        # ``self[indices]`` (i.e. through ``__getitem__`` below). In that case
        # this second conversion only re-targets tensors already on the device.
        return self._to_target_device(batch), indices

    def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
        """Return a data batch: self[index], converted to the target device."""
        batch = super().__getitem__(index)
        return self._to_target_device(batch)