Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 2, 2024
1 parent 5535a05 commit 0f6187a
Show file tree
Hide file tree
Showing 15 changed files with 693 additions and 69 deletions.
4 changes: 4 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ vmas
onnxscript
onnxruntime
onnx
plotly
igraph
transformers
datasets
Binary file added docs/source/_static/img/rollout-llm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Intermediate
tutorials/dqn_with_rnn
tutorials/rb_tutorial
tutorials/export
tutorials/beam_search_with_gpt

Advanced
--------
Expand Down
10 changes: 7 additions & 3 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,14 +1776,18 @@ def __init__(self):
tensor=Unbounded(3),
non_tensor=NonTensor(shape=()),
)
self._saved_obs_spec = self.observation_spec.clone()
self.state_spec = Composite(
non_tensor=NonTensor(shape=()),
)
self._saved_state_spec = self.state_spec.clone()
self.reward_spec = Unbounded(1)
self._saved_full_reward_spec = self.full_reward_spec.clone()
self.action_spec = Unbounded(1)
self._saved_full_action_spec = self.full_action_spec.clone()

def _reset(self, tensordict):
data = self.observation_spec.zero()
data = self._saved_obs_spec.zero()
data.set_non_tensor("non_tensor", 0)
data.update(self.full_done_spec.zero())
return data
Expand All @@ -1792,10 +1796,10 @@ def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
data = self.observation_spec.zero()
data = self._saved_obs_spec.zero()
data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1)
data.update(self.full_done_spec.zero())
data.update(self.full_reward_spec.zero())
data.update(self._saved_full_reward_spec.zero())
return data

def _set_seed(self, seed: Optional[int]):
Expand Down
11 changes: 8 additions & 3 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3526,8 +3526,13 @@ def test_single_env_spec():
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))


def test_auto_spec():
env = CountingEnv()
@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata])
def test_auto_spec(env_type):
if env_type is EnvWithMetadata:
obs_vals = ["tensor", "non_tensor"]
else:
obs_vals = "observation"
env = env_type()
td = env.reset()

policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
Expand All @@ -3550,7 +3555,7 @@ def test_auto_spec():
shape=env.full_state_spec.shape, device=env.full_state_spec.device
)
env._action_keys = ["action"]
env.auto_specs_(policy, tensordict=td.copy())
env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals)
env.check_env_specs(tensordict=td.copy())


Expand Down
1 change: 1 addition & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ def _can_be_pickled(obj):
def _make_ordinal_device(device: torch.device):
if device is None:
return device
device = torch.device(device)
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
if device.type == "mps" and device.index is None:
Expand Down
3 changes: 2 additions & 1 deletion torchrl/data/map/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
class SipHash(Module):
"""A Module to Compute SipHash values for given tensors.
A hash function module based on SipHash implementation in python.
A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]``
and the output shape will be ``[batch_size]``.
Args:
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
Expand Down
20 changes: 17 additions & 3 deletions torchrl/data/map/tdstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def from_tensordict_pair(
collate_fn: Callable[[Any], Any] | None = None,
write_fn: Callable[[Any, Any], Any] | None = None,
consolidated: bool | None = None,
):
) -> TensorDictMap:
"""Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
Args:
Expand Down Expand Up @@ -308,7 +308,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
if not self._has_lazy_out_keys():
# TODO: make this work with pytrees and avoid calling select if keys match
value = value.select(*self.out_keys, strict=False)
item, value = self._maybe_add_batch(item, value)
index = self._to_index(item, extend=True)
if index.unique().numel() < index.numel():
# If multiple values point to the same place in the storage, we cannot process them by batch
# There could be a better way to deal with this, using unique ids.
vals = []
for it, val in zip(item.split(1), value.split(1)):
self[it] = val
vals.append(val)
# __setitem__ may affect the content of the input data
value.update(TensorDictBase.lazy_stack(vals))
return
if self.write_fn is not None:
# We use this block in the following context: the value written in the storage is already present,
# but it needs to be updated.
# We first check if the value is already there using `contains`. If so, we pass the new value and the
# previous one to write_fn. The values that are not present are passed alone.
if len(self):
modifiable = self.contains(item)
if modifiable.any():
Expand All @@ -322,8 +338,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
value = self.write_fn(value)
else:
value = self.write_fn(value)
item, value = self._maybe_add_batch(item, value)
index = self._to_index(item, extend=True)
self.storage.set(index, value)

def __len__(self):
Expand Down
Loading

0 comments on commit 0f6187a

Please sign in to comment.