From 15f1f0d489af84c0a2b186c24d62e80fe2fed6d1 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 14 Jun 2024 09:16:45 +0100
Subject: [PATCH] amend

---
 torchrl/data/replay_buffers/samplers.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index 822fbc59bc9..bcacbbcbbc7 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -1806,6 +1806,7 @@ def _padded_indices(self, shapes, arange) -> torch.Tensor:  # noqa: F811
         )
 
     def _preceding_stop_idx(self, storage, lengths, seq_length):
+        print('lengths', lengths)
         preceding_stop_idx = self._cache.get("preceding_stop_idx")
         if preceding_stop_idx is not None:
             return preceding_stop_idx
@@ -1841,6 +1842,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
         seq_length, num_slices = self._adjusted_batch_size(batch_size)
 
         preceding_stop_idx = self._preceding_stop_idx(storage, lengths, seq_length)
+        preceding_stop_idx = (preceding_stop_idx + start_idx[0, 0]) % storage._len_along_dim0
         if storage.ndim > 1:
             # we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
             # This is because the lengths come as they would for a permuted storage
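For readers skimming the patch, here is a minimal, self-contained sketch (not part of the patch itself) of what the added modulo line in the second hunk does: it shifts stop indices expressed relative to the logical start of a circular storage into physical slot positions, wrapping around the buffer's capacity. The names `stop_offsets`, `start`, and `capacity` below are hypothetical stand-ins for `preceding_stop_idx`, `start_idx[0, 0]`, and `storage._len_along_dim0`; the values are made up for illustration.

```python
# Illustrative sketch only -- not the patched torchrl code.
import torch

capacity = 10                            # physical size of the circular buffer (stand-in for storage._len_along_dim0)
start = 7                                # physical slot holding the oldest element (stand-in for start_idx[0, 0])
stop_offsets = torch.tensor([2, 5, 8])   # stop indices relative to the oldest element (stand-in for preceding_stop_idx)

# Shift the logical offsets by the buffer's start position and wrap around the
# capacity so they address the correct physical slots.
physical = (stop_offsets + start) % capacity
print(physical)  # tensor([9, 2, 5])
```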