diff --git a/test/test_env.py b/test/test_env.py index b48b1a1cf8f..cef7a507f2a 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -8,6 +8,7 @@ import functools import gc import os.path +import random import re from collections import defaultdict from functools import partial @@ -114,6 +115,7 @@ DoubleToFloat, EnvBase, EnvCreator, + LLMHashingEnv, ParallelEnv, PendulumEnv, SerialEnv, @@ -3419,6 +3421,29 @@ def test_pendulum_env(self, device): r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device)) assert r.shape == torch.Size((5, 10)) + def test_llm_hashing_env(self): + vocab_size = 5 + + class Tokenizer: + def __call__(self, obj): + return torch.randint(vocab_size, (len(obj.split(" ")),)).tolist() + + def decode(self, obj): + words = ["apple", "banana", "cherry", "date", "elderberry"] + return " ".join(random.choice(words) for _ in obj) + + def batch_decode(self, obj): + return [self.decode(_obj) for _obj in obj] + + def encode(self, obj): + return self(obj) + + tokenizer = Tokenizer() + env = LLMHashingEnv(tokenizer=tokenizer, vocab_size=vocab_size) + td = env.make_tensordict("some sentence") + assert isinstance(td, TensorDict) + env.check_env_specs(tensordict=td) + @pytest.mark.parametrize("device", [None, *get_default_devices()]) @pytest.mark.parametrize("env_device", [None, *get_default_devices()]) diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 513a7b94e58..c09db75aa5b 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -135,35 +135,40 @@ def make_node( def full_observation_spec(self): """The observation spec of the tree. - This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`. + """ return self.specs["output_spec", "full_observation_spec"] @property def full_reward_spec(self): """The reward spec of the tree. 
- This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`. + """ return self.specs["output_spec", "full_reward_spec"] @property def full_done_spec(self): """The done spec of the tree. - This is an alias for `Tree.specs['output_spec', 'full_done_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_done_spec']`. + """ return self.specs["output_spec", "full_done_spec"] @property def full_state_spec(self): """The state spec of the tree. - This is an alias for `Tree.specs['input_spec', 'full_state_spec']`.""" + This is an alias for `Tree.specs['input_spec', 'full_state_spec']`. + """ return self.specs["input_spec", "full_state_spec"] @property def full_action_spec(self): """The action spec of the tree. - This is an alias for `Tree.specs['input_spec', 'full_action_spec']`.""" + This is an alias for `Tree.specs['input_spec', 'full_action_spec']`. + """ return self.specs["input_spec", "full_action_spec"] @property diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 004f2958a0d..f5d102a6279 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -555,6 +555,8 @@ def auto_specs_( @wraps(check_env_specs_func) def check_env_specs(self, *args, **kwargs): + return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs) + kwargs["return_contiguous"] = return_contiguous return check_env_specs_func(self, *args, **kwargs) check_env_specs.__doc__ = check_env_specs_func.__doc__ @@ -3285,7 +3287,10 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """ if self._simple_done: done = tensordict._get_str("done", default=None) - any_done = done is not None and done.any() + if done is not None: + any_done = done.any() + else: + any_done = False if any_done: tensordict._set_str( "_reset", diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 67880f08ee7..2f456482147 100644 --- a/torchrl/envs/custom/llm.py 
+++ b/torchrl/envs/custom/llm.py @@ -5,7 +5,7 @@ from typing import Callable, List, Union import torch -from tensordict import NestedKey, TensorDictBase +from tensordict import NestedKey, TensorDict, TensorDictBase from tensordict.tensorclass import NonTensorData, NonTensorStack from torchrl.data import ( @@ -103,7 +103,7 @@ def __init__( self.observation_key = observation_key observation_spec = { observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)), - "hash": Unbounded(shape=(1,), dtype=torch.int64), + "hashing": Unbounded(shape=(1,), dtype=torch.int64), } self.text_output = text_output if not text_output: @@ -117,6 +117,16 @@ def __init__( self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,))) _StepMDP(self) + def make_tensordict(self, input: str | List[str]) -> TensorDict: + """Converts a string or list of strings into a TensorDict with appropriate shape and device.""" + list_len = len(input) if isinstance(input, list) else 0 + tensordict = TensorDict( + {self.observation_key: self._tokenizer(input)}, device=self.device + ) + if list_len: + tensordict.batch_size = [list_len] + return self.reset(tensordict) + def _reset(self, tensordict: TensorDictBase): """Initializes the environment with a given observation. @@ -128,7 +138,11 @@ def _reset(self, tensordict: TensorDictBase): """ out = tensordict.empty() - obs = tensordict.get(self.observation_key) + obs = tensordict.get(self.observation_key, None) + if obs is None: + raise RuntimeError( + f"Resetting the {type(self).__name__} environment requires a prompt." 
+ ) if self.text_output: if obs.ndim > 1: text = self._tokenizer.batch_decode(obs) @@ -139,9 +153,9 @@ def _reset(self, tensordict: TensorDictBase): out.set(self.text_key, text) if obs.ndim > 1: - out.set("hash", self._hashing_module(obs).unsqueeze(-1)) + out.set("hashing", self._hashing_module(obs).unsqueeze(-1)) else: - out.set("hash", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1)) + out.set("hashing", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1)) if not self.full_done_spec.is_empty(): out.update(self.full_done_spec.zero(tensordict.shape)) @@ -166,7 +180,7 @@ def _step(self, tensordict): obs = torch.cat([tensordict.get(self.observation_key), action], -1) kwargs = {self.observation_key: obs} - catval = torch.cat([tensordict.get("hash"), action], -1) + catval = torch.cat([tensordict.get("hashing"), action], -1) if obs.ndim > 1: new_hash = self._hashing_module(catval).unsqueeze(-1) else: @@ -182,7 +196,7 @@ def _step(self, tensordict): kwargs[self.text_key] = text kwargs.update( { - "hash": new_hash, + "hashing": new_hash, "done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool), "terminated": torch.zeros( (*tensordict.batch_size, 1), dtype=torch.bool diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 423b71e316e..d2ec66475ab 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -778,12 +778,15 @@ def check_env_specs( ) zeroing_err_msg = ( "zeroing the two tensordicts did not make them identical. 
" - "Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" + f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" ) from torchrl.envs.common import _has_dynamic_specs if _has_dynamic_specs(env.specs): - for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)): + for real, fake in zip( + real_tensordict.filter_non_tensor_data().unbind(-1), + fake_tensordict.filter_non_tensor_data().unbind(-1), + ): fake = fake.apply(lambda x, y: x.expand_as(y), real) if (torch.zeros_like(real) != torch.zeros_like(fake)).any(): raise AssertionError(zeroing_err_msg) @@ -1367,6 +1370,8 @@ def _update_during_reset( reset_keys: List[NestedKey], ): """Updates the input tensordict with the reset data, based on the reset keys.""" + if not reset_keys: + return tensordict.update(tensordict_reset) roots = set() for reset_key in reset_keys: # get the node of the reset key