Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 14, 2024
1 parent b0ea50d commit 15f1f0d
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,6 +1806,7 @@ def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811
)

def _preceding_stop_idx(self, storage, lengths, seq_length):
print('lengths', lengths)
preceding_stop_idx = self._cache.get("preceding_stop_idx")
if preceding_stop_idx is not None:
return preceding_stop_idx
Expand Down Expand Up @@ -1841,6 +1842,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
seq_length, num_slices = self._adjusted_batch_size(batch_size)

preceding_stop_idx = self._preceding_stop_idx(storage, lengths, seq_length)
preceding_stop_idx = (preceding_stop_idx + start_idx[0, 0]) % storage._len_along_dim0
if storage.ndim > 1:
# we need to convert indices of the permuted, flatten storage to indices in a flatten storage (not permuted)
# This is because the lengths come as they would for a permuted storage
Expand Down

0 comments on commit 15f1f0d

Please sign in to comment.