Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 12, 2024
2 parents ae2bc40 + 71226b1 commit a874ea2
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 15 deletions.
25 changes: 25 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
import gc
import os.path
import random
import re
from collections import defaultdict
from functools import partial
Expand Down Expand Up @@ -114,6 +115,7 @@
DoubleToFloat,
EnvBase,
EnvCreator,
LLMHashingEnv,
ParallelEnv,
PendulumEnv,
SerialEnv,
Expand Down Expand Up @@ -3419,6 +3421,29 @@ def test_pendulum_env(self, device):
r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device))
assert r.shape == torch.Size((5, 10))

def test_llm_hashing_env(self):
vocab_size = 5

class Tokenizer:
def __call__(self, obj):
return torch.randint(vocab_size, (len(obj.split(" ")),)).tolist()

def decode(self, obj):
words = ["apple", "banana", "cherry", "date", "elderberry"]
return " ".join(random.choice(words) for _ in obj)

def batch_decode(self, obj):
return [self.decode(_obj) for _obj in obj]

def encode(self, obj):
return self(obj)

tokenizer = Tokenizer()
env = LLMHashingEnv(tokenizer=tokenizer, vocab_size=vocab_size)
td = env.make_tensordict("some sentence")
assert isinstance(td, TensorDict)
env.check_env_specs(tensordict=td)


@pytest.mark.parametrize("device", [None, *get_default_devices()])
@pytest.mark.parametrize("env_device", [None, *get_default_devices()])
Expand Down
15 changes: 10 additions & 5 deletions torchrl/data/map/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,35 +135,40 @@ def make_node(
def full_observation_spec(self):
"""The observation spec of the tree.
This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`."""
This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`.
"""
return self.specs["output_spec", "full_observation_spec"]

@property
def full_reward_spec(self):
"""The reward spec of the tree.
This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`."""
This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`.
"""
return self.specs["output_spec", "full_reward_spec"]

@property
def full_done_spec(self):
"""The done spec of the tree.
This is an alias for `Tree.specs['output_spec', 'full_done_spec']`."""
This is an alias for `Tree.specs['output_spec', 'full_done_spec']`.
"""
return self.specs["output_spec", "full_done_spec"]

@property
def full_state_spec(self):
"""The state spec of the tree.
This is an alias for `Tree.specs['input_spec', 'full_state_spec']`."""
This is an alias for `Tree.specs['input_spec', 'full_state_spec']`.
"""
return self.specs["input_spec", "full_state_spec"]

@property
def full_action_spec(self):
"""The action spec of the tree.
This is an alias for `Tree.specs['input_spec', 'full_action_spec']`."""
This is an alias for `Tree.specs['input_spec', 'full_action_spec']`.
"""
return self.specs["input_spec", "full_action_spec"]

@property
Expand Down
7 changes: 6 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ def auto_specs_(

@wraps(check_env_specs_func)
def check_env_specs(self, *args, **kwargs):
return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
kwargs["return_contiguous"] = return_contiguous
return check_env_specs_func(self, *args, **kwargs)

check_env_specs.__doc__ = check_env_specs_func.__doc__
Expand Down Expand Up @@ -3285,7 +3287,10 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
if self._simple_done:
done = tensordict._get_str("done", default=None)
any_done = done is not None and done.any()
if done is not None:
any_done = done.any()
else:
any_done = False
if any_done:
tensordict._set_str(
"_reset",
Expand Down
28 changes: 21 additions & 7 deletions torchrl/envs/custom/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Callable, List, Union

import torch
from tensordict import NestedKey, TensorDictBase
from tensordict import NestedKey, TensorDict, TensorDictBase
from tensordict.tensorclass import NonTensorData, NonTensorStack

from torchrl.data import (
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(
self.observation_key = observation_key
observation_spec = {
observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)),
"hash": Unbounded(shape=(1,), dtype=torch.int64),
"hashing": Unbounded(shape=(1,), dtype=torch.int64),
}
self.text_output = text_output
if not text_output:
Expand All @@ -117,6 +117,16 @@ def __init__(
self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,)))
_StepMDP(self)

def make_tensordict(self, input: str | List[str]) -> TensorDict:
"""Converts a string or list of strings in a TensorDict with appropriate shape and device."""
list_len = len(input) if isinstance(input, list) else 0
tensordict = TensorDict(
{self.observation_key: self._tokenizer(input)}, device=self.device
)
if list_len:
tensordict.batch_size = [list_len]
return self.reset(tensordict)

def _reset(self, tensordict: TensorDictBase):
"""Initializes the environment with a given observation.
Expand All @@ -128,7 +138,11 @@ def _reset(self, tensordict: TensorDictBase):
"""
out = tensordict.empty()
obs = tensordict.get(self.observation_key)
obs = tensordict.get(self.observation_key, None)
if obs is None:
raise RuntimeError(
f"Resetting the {type(self).__name__} environment requires a prompt."
)
if self.text_output:
if obs.ndim > 1:
text = self._tokenizer.batch_decode(obs)
Expand All @@ -139,9 +153,9 @@ def _reset(self, tensordict: TensorDictBase):
out.set(self.text_key, text)

if obs.ndim > 1:
out.set("hash", self._hashing_module(obs).unsqueeze(-1))
out.set("hashing", self._hashing_module(obs).unsqueeze(-1))
else:
out.set("hash", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1))
out.set("hashing", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1))

if not self.full_done_spec.is_empty():
out.update(self.full_done_spec.zero(tensordict.shape))
Expand All @@ -166,7 +180,7 @@ def _step(self, tensordict):
obs = torch.cat([tensordict.get(self.observation_key), action], -1)
kwargs = {self.observation_key: obs}

catval = torch.cat([tensordict.get("hash"), action], -1)
catval = torch.cat([tensordict.get("hashing"), action], -1)
if obs.ndim > 1:
new_hash = self._hashing_module(catval).unsqueeze(-1)
else:
Expand All @@ -182,7 +196,7 @@ def _step(self, tensordict):
kwargs[self.text_key] = text
kwargs.update(
{
"hash": new_hash,
"hashing": new_hash,
"done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool),
"terminated": torch.zeros(
(*tensordict.batch_size, 1), dtype=torch.bool
Expand Down
9 changes: 7 additions & 2 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,15 @@ def check_env_specs(
)
zeroing_err_msg = (
"zeroing the two tensordicts did not make them identical. "
"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
)
from torchrl.envs.common import _has_dynamic_specs

if _has_dynamic_specs(env.specs):
for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)):
for real, fake in zip(
real_tensordict.filter_non_tensor_data().unbind(-1),
fake_tensordict.filter_non_tensor_data().unbind(-1),
):
fake = fake.apply(lambda x, y: x.expand_as(y), real)
if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
raise AssertionError(zeroing_err_msg)
Expand Down Expand Up @@ -1367,6 +1370,8 @@ def _update_during_reset(
reset_keys: List[NestedKey],
):
"""Updates the input tensordict with the reset data, based on the reset keys."""
if not reset_keys:
return tensordict.update(tensordict_reset)
roots = set()
for reset_key in reset_keys:
# get the node of the reset key
Expand Down

0 comments on commit a874ea2

Please sign in to comment.