Commit afc0f24

Update

[ghstack-poisoned]

vmoens committed Dec 6, 2024
1 parent 80621aa commit afc0f24

Showing 5 changed files with 387 additions and 49 deletions.
torchrl/data/map/hash.py: 3 changes (2 additions & 1 deletion)

@@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
 class SipHash(Module):
     """A Module to Compute SipHash values for given tensors.
 
-    A hash function module based on SipHash implementation in python.
+    A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]``
+    and the output shape will be ``[batch_size]``.
 
     Args:
         as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
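The new docstring lines pin down SipHash's shape contract. A minimal usage sketch of what they promise (the import path is assumed from the repo layout, it is not shown in this diff):

import torch
from torchrl.data.map import SipHash

hash_fn = SipHash()              # as_tensor=True by default per the docstring
features = torch.randn(4, 8)     # [batch_size=4, num_features=8]
hashes = hash_fn(features)       # one hash per row
assert hashes.shape == (4,)      # output shape [batch_size]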
torchrl/data/map/tdstorage.py: 32 changes (28 additions & 4 deletions)

@@ -138,6 +138,10 @@ def __init__(
         self.collate_fn = collate_fn
         self.write_fn = write_fn
 
+    @property
+    def max_size(self):
+        return self.storage.max_size
+
     @property
     def out_keys(self) -> List[NestedKey]:
         out_keys = self.__dict__.get("_out_keys_and_lazy")
@@ -177,7 +181,7 @@ def from_tensordict_pair(
         collate_fn: Callable[[Any], Any] | None = None,
         write_fn: Callable[[Any, Any], Any] | None = None,
         consolidated: bool | None = None,
-    ):
+    ) -> TensorDictMap:
         """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
 
         Args:
@@ -238,7 +242,13 @@ def from_tensordict_pair(
         n_feat = 0
         hash_module = []
         for in_key in in_keys:
-            n_feat = source[in_key].shape[-1]
+            entry = source[in_key]
+            if entry.ndim == source.ndim:
+                # this is a good example of why td/tc are useful - carrying metadata
+                # allows us to know if there's a feature dim or not
+                n_feat = 0
+            else:
+                n_feat = entry.shape[-1]
             if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT:
                 _hash_module = RandomProjectionHash()
             else:
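The new branch leans on the metadata a TensorDict carries: the container knows its own batch dimensions, so an entry whose ndim equals the container's ndim has no trailing feature dimension. A small illustration with hypothetical keys:

import torch
from tensordict import TensorDict

source = TensorDict(
    {"scalar_key": torch.zeros(32), "feat_key": torch.zeros(32, 8)},
    batch_size=[32],
)
assert source["scalar_key"].ndim == source.ndim  # no feature dim -> n_feat stays 0
assert source["feat_key"].ndim > source.ndim     # trailing feature dim -> n_feat = 8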
@@ -308,7 +318,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
         if not self._has_lazy_out_keys():
             # TODO: make this work with pytrees and avoid calling select if keys match
             value = value.select(*self.out_keys, strict=False)
+        item, value = self._maybe_add_batch(item, value)
+        index = self._to_index(item, extend=True)
+        if index.unique().numel() < index.numel():
+            # If multiple values point to the same place in the storage, we cannot process them by batch
+            # There could be a better way to deal with this, using unique ids.
+            vals = []
+            for it, val in zip(item.split(1), value.split(1)):
+                self[it] = val
+                vals.append(val)
+            # __setitem__ may affect the content of the input data
+            value.update(TensorDictBase.lazy_stack(vals))
+            return
         if self.write_fn is not None:
+            # We use this block in the following context: the value written in the storage is already present,
+            # but it needs to be updated.
+            # We first check if the value is already there using `contains`. If so, we pass the new value and the
+            # previous one to write_fn. The values that are not present are passed alone.
             if len(self):
                 modifiable = self.contains(item)
                 if modifiable.any():
@@ -322,8 +348,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
                     value = self.write_fn(value)
             else:
                 value = self.write_fn(value)
-        item, value = self._maybe_add_batch(item, value)
-        index = self._to_index(item, extend=True)
         self.storage.set(index, value)
 
     def __len__(self):
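The per-item fallback added above guards against index collisions: a single batched write with duplicate indices keeps only one of the colliding values, and which one survives is unspecified. A minimal illustration with a plain tensor rather than the torchrl storage API:

import torch

storage = torch.zeros(4)
index = torch.tensor([1, 1])       # two items hash to the same slot
value = torch.tensor([10.0, 20.0])
storage[index] = value             # only one of the two writes survives
print(storage)                     # e.g. tensor([ 0., 20.,  0.,  0.])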
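The comments on the write_fn block describe its calling convention: write_fn(new, old) for entries already in storage, write_fn(new) for fresh ones. A hypothetical write_fn consistent with that description, keeping a visit counter (the "count" key is made up for illustration):

import torch

def running_count(new, old=None):
    # First insertion: initialize the counter; later writes: bump the stored one.
    if old is None:
        new.set("count", torch.ones(new.shape, dtype=torch.long))
    else:
        new.set("count", old.get("count") + 1)
    return new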