diff --git a/docs/requirements.txt b/docs/requirements.txt index e212cd942f4..90efea35854 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -28,3 +28,7 @@ vmas onnxscript onnxruntime onnx +plotly +igraph +transformers +datasets diff --git a/docs/source/_static/img/rollout-llm.png b/docs/source/_static/img/rollout-llm.png new file mode 100644 index 00000000000..b2e63394de1 Binary files /dev/null and b/docs/source/_static/img/rollout-llm.png differ diff --git a/docs/source/index.rst b/docs/source/index.rst index 2eedc045416..6a448d61c41 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -105,6 +105,7 @@ Intermediate tutorials/dqn_with_rnn tutorials/rb_tutorial tutorials/export + tutorials/beam_search_with_gpt Advanced -------- diff --git a/test/mocking_classes.py b/test/mocking_classes.py index d78e2f27184..4fab4027431 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1776,14 +1776,18 @@ def __init__(self): tensor=Unbounded(3), non_tensor=NonTensor(shape=()), ) + self._saved_obs_spec = self.observation_spec.clone() self.state_spec = Composite( non_tensor=NonTensor(shape=()), ) + self._saved_state_spec = self.state_spec.clone() self.reward_spec = Unbounded(1) + self._saved_full_reward_spec = self.full_reward_spec.clone() self.action_spec = Unbounded(1) + self._saved_full_action_spec = self.full_action_spec.clone() def _reset(self, tensordict): - data = self.observation_spec.zero() + data = self._saved_obs_spec.zero() data.set_non_tensor("non_tensor", 0) data.update(self.full_done_spec.zero()) return data @@ -1792,10 +1796,10 @@ def _step( self, tensordict: TensorDictBase, ) -> TensorDictBase: - data = self.observation_spec.zero() + data = self._saved_obs_spec.zero() data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1) data.update(self.full_done_spec.zero()) - data.update(self.full_reward_spec.zero()) + data.update(self._saved_full_reward_spec.zero()) return data def _set_seed(self, seed: Optional[int]): diff --git a/test/test_env.py b/test/test_env.py index 81708b0b9a6..1e41b1d403d 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3526,8 +3526,13 @@ def test_single_env_spec(): assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) -def test_auto_spec(): - env = CountingEnv() +@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata]) +def test_auto_spec(env_type): + if env_type is EnvWithMetadata: + obs_vals = ["tensor", "non_tensor"] + else: + obs_vals = "observation" + env = env_type() td = env.reset() policy = lambda td, action_spec=env.full_action_spec.clone(): td.update( @@ -3550,7 +3555,7 @@ def test_auto_spec(): shape=env.full_state_spec.shape, device=env.full_state_spec.device ) env._action_keys = ["action"] - env.auto_specs_(policy, tensordict=td.copy()) + env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals) env.check_env_specs(tensordict=td.copy()) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index d37aebb862f..c81ffcc962b 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -829,6 +829,7 @@ def _can_be_pickled(obj): def _make_ordinal_device(device: torch.device): if device is None: return device + device = torch.device(device) if device.type == "cuda" and device.index is None: return torch.device("cuda", index=torch.cuda.current_device()) if device.type == "mps" and device.index is None: diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index 01988dc43be..59526628dbe 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -75,7 
+75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
 class SipHash(Module):
     """A Module to Compute SipHash values for given tensors.
 
-    A hash function module based on SipHash implementation in python.
+    A hash function module based on the SipHash implementation in Python. Input tensors should have shape ``[batch_size, num_features]``
+    and the output shape will be ``[batch_size]``.
 
     Args:
         as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py
index a601f1e3261..6ff17daaed5 100644
--- a/torchrl/data/map/tdstorage.py
+++ b/torchrl/data/map/tdstorage.py
@@ -177,7 +177,7 @@ def from_tensordict_pair(
         collate_fn: Callable[[Any], Any] | None = None,
         write_fn: Callable[[Any, Any], Any] | None = None,
         consolidated: bool | None = None,
-    ):
+    ) -> TensorDictMap:
         """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
 
         Args:
@@ -308,7 +308,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
         if not self._has_lazy_out_keys():
             # TODO: make this work with pytrees and avoid calling select if keys match
             value = value.select(*self.out_keys, strict=False)
+        item, value = self._maybe_add_batch(item, value)
+        index = self._to_index(item, extend=True)
+        if index.unique().numel() < index.numel():
+            # If multiple values point to the same place in the storage, we cannot process them by batch
+            # There could be a better way to deal with this, using unique ids.
+            vals = []
+            for it, val in zip(item.split(1), value.split(1)):
+                self[it] = val
+                vals.append(val)
+            # __setitem__ may affect the content of the input data
+            value.update(TensorDictBase.lazy_stack(vals))
+            return
         if self.write_fn is not None:
+            # We use this block in the following context: the value written in the storage is already present,
+            # but it needs to be updated.
+            # We first check if the value is already there using `contains`. If so, we pass the new value and the
+            # previous one to write_fn. The values that are not present are passed alone.
             if len(self):
                 modifiable = self.contains(item)
                 if modifiable.any():
@@ -322,8 +338,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
                     value = self.write_fn(value)
             else:
                 value = self.write_fn(value)
-        item, value = self._maybe_add_batch(item, value)
-        index = self._to_index(item, extend=True)
         self.storage.set(index, value)
 
     def __len__(self):
diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py
index 645f7704ddd..b88e6a4a2ec 100644
--- a/torchrl/data/map/tree.py
+++ b/torchrl/data/map/tree.py
@@ -15,6 +15,7 @@
     TensorClass,
     TensorDict,
     TensorDictBase,
+    unravel_key,
 )
 from torchrl.data.map.tdstorage import TensorDictMap
 from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree
@@ -94,7 +95,7 @@ def num_children(self) -> int:
 
     @property
     def is_terminal(self):
-        """Returns True if the the tree has no children nodes."""
+        """Returns True if the tree has no child nodes."""
         return self.subtree is None
 
     def get_vertex_by_id(self, id: int) -> Tree:
@@ -163,9 +164,6 @@ def vertices(
             if h in memo and not use_path:
                 continue
             memo.add(h)
-            r = tree.rollout
-            if r is not None:
-                r = r["next", "observation"]
             if use_path:
                 result[cur_path] = tree
             elif use_id:
@@ -206,6 +204,14 @@ def num_vertices(self, *, count_repeat: bool = False) -> int:
         )
 
     def edges(self) -> List[Tuple[int, int]]:
+        """Retrieves a list of edges in the tree.
+
+        Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID.
+        The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited.
+
+        Returns:
+            A list of tuples, where each tuple contains a parent node ID and a child node ID.
+        """
         result = []
         q = deque()
         parent = self.node_id
@@ -221,22 +227,62 @@ def edges(self) -> List[Tuple[int, int]]:
         return result
 
     def valid_paths(self):
+        """Generates all valid paths in the tree.
+
+        A valid path is a sequence of child indices that starts at the root node and ends at a leaf node.
+        Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node.
+
+        Yields:
+            tuple: A valid path in the tree.
+        """
+        # Initialize a queue with the current tree node and an empty path
        q = deque()
        cur_path = ()
        q.append((self, cur_path))
+        # Perform BFS traversal of the tree
        while len(q):
+            # Dequeue the next tree node and its current path
            tree, cur_path = q.popleft()
+            # Get the number of child nodes
            n = int(tree.num_children)
+            # If this is a leaf node, yield the current path
            if not n:
                yield cur_path
+            # Iterate over the child nodes
            for i in range(n):
                cur_path_tree = cur_path + (i,)
                q.append((tree.subtree[i], cur_path_tree))

    def max_length(self):
-        return max(*(len(path) for path in self.valid_paths()))
+        """Returns the maximum length of all valid paths in the tree.
+
+        The length of a path is the number of steps it contains, i.e., the
+        number of edges from the root to the corresponding leaf.
+        If the tree is empty, returns 0.
+
+        Returns:
+            int: The maximum length of all valid paths in the tree.
+
+        """
+        lengths = tuple(len(path) for path in self.valid_paths())
+        if len(lengths) == 0:
+            return 0
+        elif len(lengths) == 1:
+            return lengths[0]
+        return max(*lengths)

    def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None:
+        """Retrieves the rollout data along a given path in the tree.
+
+        The rollout data is concatenated along the last dimension (dim=-1) for each node in the path.
+        If no rollout data is found along the path, returns ``None``.
+
+        Args:
+            path: A tuple of integers representing the path in the tree.
+
+        Returns:
+            The concatenated rollout data along the path, or None if no data is found.
+
+        """
        r = self.rollout
        tree = self
        rollouts = []
@@ -272,8 +318,19 @@ def plot(
        backend: str = "plotly",
        figure: str = "tree",
        info: List[str] = None,
-        make_labels: Callable[[Any], Any] | None = None,
+        make_labels: Callable[..., Any] | None = None,
    ):
+        """Plots a visualization of the tree using the specified backend and figure type.
+
+        Args:
+            backend: The plotting backend to use. Currently only supports 'plotly'.
+            figure: The type of figure to plot. Can be either 'tree' or 'box'.
+            info: A list of additional information to include in the plot (not currently used).
+            make_labels: An optional function to generate custom labels for the plot.
+
+        Raises:
+            NotImplementedError: If an unsupported backend or figure type is specified.
+        """
        if backend == "plotly":
            if figure == "box":
                _plot_plotly_box(self)
@@ -284,7 +341,7 @@ def plot(
        else:
            pass
        raise NotImplementedError(
-            f"Unkown plotting backend {backend} with figure {figure}."
+            f"Unknown plotting backend {backend} with figure {figure}."
        )


@@ -423,47 +480,99 @@ def __init__(
        self.consolidated = consolidated

    @property
-    def done_keys(self):
+    def done_keys(self) -> List[NestedKey]:
+        """Done Keys.
+
+        Returns the keys used to indicate that an episode has ended.
+        The default done keys are "done", "terminated", and "truncated". These keys can be
+        used in the environment's output to signal the end of an episode.
+
+        Returns:
+            A list of strings or tuples representing the done keys.
+
+        """
         done_keys = getattr(self, "_done_keys", None)
         if done_keys is None:
-            self._done_keys = done_keys = ("done", "terminated", "truncated")
+            self._done_keys = done_keys = ["done", "terminated", "truncated"]
         return done_keys
 
     @done_keys.setter
     def done_keys(self, value):
+        if isinstance(value, (str, tuple)):
+            value = [value]
+        if value is not None:
+            value = [unravel_key(val) for val in value]
         self._done_keys = value
 
     @property
-    def reward_keys(self):
+    def reward_keys(self) -> List[NestedKey]:
+        """Reward Keys.
+
+        Returns the keys used to retrieve rewards from the environment's output.
+        The default reward key is "reward".
+
+        Returns:
+            A list of strings or tuples representing the reward keys.
+
+        """
         reward_keys = getattr(self, "_reward_keys", None)
         if reward_keys is None:
-            self._reward_keys = reward_keys = ("reward",)
+            self._reward_keys = reward_keys = ["reward"]
         return reward_keys
 
     @reward_keys.setter
     def reward_keys(self, value):
+        if isinstance(value, (str, tuple)):
+            value = [value]
+        if value is not None:
+            value = [unravel_key(val) for val in value]
         self._reward_keys = value
 
     @property
-    def action_keys(self):
+    def action_keys(self) -> List[NestedKey]:
+        """Action Keys.
+
+        Returns the keys used to retrieve actions from the environment's input.
+        The default action key is "action".
+
+        Returns:
+            A list of strings or tuples representing the action keys.
+
+        """
         action_keys = getattr(self, "_action_keys", None)
         if action_keys is None:
-            self._action_keys = action_keys = ("action",)
+            self._action_keys = action_keys = ["action"]
         return action_keys
 
     @action_keys.setter
     def action_keys(self, value):
+        if isinstance(value, (str, tuple)):
+            value = [value]
+        if value is not None:
+            value = [unravel_key(val) for val in value]
         self._action_keys = value
 
     @property
-    def observation_keys(self):
+    def observation_keys(self) -> List[NestedKey]:
+        """Observation Keys.
+
+        Returns the keys used to retrieve observations from the environment's output.
+        The default observation key is "observation".
+
+        Returns:
+            A list of strings or tuples representing the observation keys.
+        """
         observation_keys = getattr(self, "_observation_keys", None)
         if observation_keys is None:
-            self._observation_keys = observation_keys = ("observation",)
+            self._observation_keys = observation_keys = ["observation"]
         return observation_keys
 
     @observation_keys.setter
     def observation_keys(self, value):
+        if isinstance(value, (str, tuple)):
+            value = [value]
+        if value is not None:
+            value = [unravel_key(val) for val in value]
         self._observation_keys = value
 
     def get_keys_from_env(self, env: EnvBase):
@@ -482,8 +591,21 @@ def get_keys_from_env(self, env: EnvBase):
 
     @classmethod
     def _write_fn_stack(cls, new, old=None):
+        # This function extends the stored (old) values with the new ones,
+        # adding a new entry only if it is not already present.
+        # If no old value is provided, we assume there is none and `new`
+        # is simply prepared for storage.
+        # Preparation involves unsqueezing the last dim (since we'll be
+        # stacking tensors and calling unique).
+        # The update concatenates old and new along the last dim, then calls
+        # unique, which keeps only the values that were previously unknown
+        # to the storage.
+        # We use this method to track all the indices associated with an
+        # observation: every time a new index is obtained, it is stacked
+        # alongside the others.
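+        # For instance (a hypothetical illustration of the mechanism below):
+        # if an observation is already associated with storage indices [3, 7]
+        # and index 7 is written again, the concatenation yields [3, 7, 7],
+        # which ``unique`` collapses back to {3, 7}; a genuinely new index 9
+        # would instead give {3, 7, 9} (element order is not guaranteed,
+        # since ``sorted=False``).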
         if old is None:
-            result = new.apply(lambda x: x.unsqueeze(0), filter_empty=False)
+            # we unsqueeze the values to stack them along dim -1
+            result = new.apply(lambda x: x.unsqueeze(-1), filter_empty=False)
             result.set(
                 "count", torch.ones(result.shape, dtype=torch.int, device=result.device)
             )
@@ -493,8 +615,15 @@ def cat(name, x, y):
             if name == "count":
                 return x
             if y.ndim < x.ndim:
-                y = y.unsqueeze(0)
-            result = torch.cat([x, y], 0).unique(dim=0, sorted=False)
+                y = y.unsqueeze(-1)
+            result = torch.cat([x, y], -1)
+            # ``Tensor.unique`` along dim=-1 breaks on MPS; fall back to CPU
+            if result.device.type == "mps":
+                result = result.cpu()
+                result = result.unique(dim=-1, sorted=False)
+                result = result.to("mps")
+            else:
+                result = result.unique(dim=-1, sorted=False)
             return result
 
         result = old.named_apply(cat, new, default=None)
@@ -543,12 +672,35 @@ def extend(self, rollout):
         # # Set the action in the 'next'
         # dest[1:] = source[:-1].exclude(*self.done_keys)
+        # Add ('observation', 'action') -> ('next', 'observation')
         self.data_map[source] = dest
         value = source
         if self.node_map is None:
             self._make_storage_branches(source, dest)
+        # map ('observation',) -> ('indices',)
         self.node_map[source] = TensorDict.lazy_stack(value.unbind(0))
 
+    def add(self, step):
+        source, dest = (
+            step.exclude("next").copy(),
+            step.select("next", *self.action_keys).copy(),
+        )
+
+        if self.data_map is None:
+            self._make_storage(source, dest)
+
+        # We need to set the action somewhere to keep track of what action led to which child
+        # # Set the action in the 'next'
+        # dest[1:] = source[:-1].exclude(*self.done_keys)
+
+        # Add ('observation', 'action') -> ('next', 'observation')
+        self.data_map[source] = dest
+        value = source
+        if self.node_map is None:
+            self._make_storage_branches(source, dest)
+        # map ('observation',) -> ('indices',)
+        self.node_map[source] = value
+
     def get_child(self, root: TensorDictBase) -> TensorDictBase:
         return self.data_map[root]
 
@@ -582,6 +734,14 @@ def _make_local_tree(
                 if not compact:
                     break
             else:
+                # If the root is provided and not gathered from the storage, it could be that its
+                # device doesn't match the data_map storage device.
+                device = getattr(self.data_map.storage, "device", None)
+                if root.device != device:
+                    if device is not None:
+                        root = root.to(self.data_map.storage.device)
+                    else:
+                        root.clear_device_()
                 index = None
                 break
         rollout = None
diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py
index 570214f1cb2..d9588d79905 100644
--- a/torchrl/data/map/utils.py
+++ b/torchrl/data/map/utils.py
@@ -17,13 +17,13 @@ def _plot_plotly_tree(
 
     if make_labels is None:
 
-        def make_labels(tree):
+        def make_labels(tree, path, *args, **kwargs):
             return str((tree.node_id, tree.hash))
 
     nr_vertices = tree.num_vertices()
-    vertices = tree.vertices()
+    vertices = tree.vertices(key_type="path")
 
-    v_label = [make_labels(subtree) for subtree in vertices.values()]
+    v_label = [make_labels(subtree, path) for path, subtree in vertices.items()]
     G = Graph(nr_vertices, tree.edges())
 
     layout = G.layout_sugiyama(range(nr_vertices))
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 665cae254f5..ae0d97b7bab 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -246,8 +246,8 @@ def set(
         set_cursor: bool = True,
     ):
         if not isinstance(cursor, INT_CLASSES):
-            if (isinstance(cursor, torch.Tensor) and cursor.numel() <= 1) or (
-                isinstance(cursor, np.ndarray) and cursor.size <= 1
+            if (isinstance(cursor, torch.Tensor) and cursor.ndim == 0) or (
+                isinstance(cursor, np.ndarray) and cursor.ndim == 0
             ):
                 self.set(int(cursor), data, set_cursor=set_cursor)
                 return
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index ddf6ed41c99..dad0aaf69a6 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -41,7 +41,7 @@
     unravel_key,
 )
 from tensordict.base import NO_DEFAULT
-from tensordict.utils import _getitem_batch_size, NestedKey
+from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey
 
 from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for
 
 DEVICE_TYPING = Union[torch.device, str, int]
@@ -2466,10 +2466,10 @@ def one(self, shape=None):
             data=None, batch_size=(*shape, *self._safe_shape), device=self.device
         )
 
-    def is_in(self, val: torch.Tensor) -> bool:
+    def is_in(self, val: Any) -> bool:
         shape = torch.broadcast_shapes(self._safe_shape, val.shape)
         return (
-            isinstance(val, NonTensorData)
+            is_non_tensor(val)
             and val.shape == shape
             # We relax constrains on device as they're hard to enforce for non-tensor
             # tensordicts and pointless
@@ -4373,7 +4373,7 @@ def set(self, name, spec):
             shape = spec.shape
             if shape[: self.ndim] != self.shape:
                 if (
-                    isinstance(spec, Composite)
+                    isinstance(spec, (Composite, NonTensor))
                     and spec.ndim < self.ndim
                     and self.shape[: spec.ndim] == spec.shape
                 ):
@@ -4382,7 +4382,7 @@ def set(self, name, spec):
                     spec.shape = self.shape
                 else:
                     raise ValueError(
-                        "The shape of the spec and the Composite mismatch: the first "
+                        f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} do not match: the first "
                         f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                         f"Composite.shape={self.shape}."
) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index d5a062bc11e..3ce99232a6c 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -14,8 +14,14 @@ import numpy as np import torch import torch.nn as nn -from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key -from tensordict.utils import NestedKey +from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDictBase, + unravel_key, +) +from tensordict.base import _is_leaf_nontensor +from tensordict.utils import is_non_tensor, NestedKey from torchrl._utils import ( _ends_with, _make_ordinal_device, @@ -25,7 +31,13 @@ seed_generator, ) -from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded +from torchrl.data.tensor_specs import ( + Categorical, + Composite, + NonTensor, + TensorSpec, + Unbounded, +) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.utils import ( _make_compatible_policy, @@ -430,7 +442,6 @@ def auto_specs_( done_key: NestedKey | List[NestedKey] | None = None, observation_key: NestedKey | List[NestedKey] = "observation", reward_key: NestedKey | List[NestedKey] = "reward", - batch_size: torch.Size | None = None, ): """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy. @@ -484,6 +495,7 @@ def auto_specs_( tensordict2, named=True, nested_keys=True, + is_leaf=_is_leaf_nontensor, ) input_spec = Composite(input_spec_stack, batch_size=batch_size) if not self.batch_locked and batch_size != self.batch_size: @@ -501,6 +513,7 @@ def auto_specs_( nexts_1, named=True, nested_keys=True, + is_leaf=_is_leaf_nontensor, ) output_spec = Composite(output_spec_stack, batch_size=batch_size) @@ -523,7 +536,8 @@ def auto_specs_( full_observation_spec = output_spec.separates(*observation_key, default=None) if not output_spec.is_empty(recurse=True): raise RuntimeError( - f"Keys {list(output_spec.keys(True, True))} are unaccounted for." + f"Keys {list(output_spec.keys(True, True))} are unaccounted for. " + f"Make sure you have passed all the leaf names to the auto_specs_ method." ) if full_action_spec is not None: @@ -2995,6 +3009,52 @@ def add_truncated_keys(self) -> EnvBase: self.__dict__["_done_keys"] = None return self + def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase: + """Advances the environment state by one step using the provided `next_tensordict`. + + This method updates the environment's state by transitioning from the current + state to the next, as defined by the `next_tensordict`. The resulting tensordict + includes updated observations and any other relevant state information, with + keys managed according to the environment's specifications. + + Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently + handle the transition of state, observation, action, reward, and done keys. The + :class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and + exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance + is created with `exclude_action=False`, meaning that action keys are retained in + the root tensordict. + + Args: + next_tensordict (TensorDictBase): A tensordict containing the state of the + environment at the next time step. This tensordict should include keys + for observations, actions, rewards, and done flags, as defined by the + environment's specifications. 
+
+        Returns:
+            TensorDictBase: A new tensordict representing the environment state after
+            advancing by one step.
+
+        .. note:: The method ensures that the environment's key specifications are validated
+            against the provided `next_tensordict`, issuing warnings if discrepancies
+            are found.
+
+        .. note:: This method is designed to work efficiently with environments that have
+            consistent key specifications, leveraging the `_StepMDP` class to minimize
+            overhead.
+
+        Example:
+            >>> from torchrl.envs import GymEnv
+            >>> env = GymEnv("Pendulum-v1")
+            >>> data = env.reset()
+            >>> for i in range(10):
+            ...     # compute action
+            ...     env.rand_action(data)
+            ...     # Perform action
+            ...     next_data = env.step(data)
+            ...     data = env.step_mdp(next_data)
+        """
+        return self._step_mdp(next_tensordict)
+
     @property
     def _step_mdp(self):
         step_func = self.__dict__.get("_step_mdp_value")
@@ -3568,6 +3628,12 @@ def _has_dynamic_specs(spec: Composite):
 
 
 def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack):
+    if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)):
+        stack[name] = NonTensor(shape=())
+        return
+    elif is_non_tensor(leaf):
+        stack[name] = NonTensor(shape=leaf.shape)
+        return
     shape = leaf.shape
     if leaf_compare is not None:
         shape_compare = leaf_compare.shape
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 7454bce99b3..0e037ed9b0c 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -14,7 +14,7 @@
 import re
 import warnings
 from enum import Enum
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List
 
 import torch
 
@@ -76,7 +76,7 @@ def __get__(self, cls, owner):
 
 
 class _StepMDP:
-    """Stateful version of step_mdp.
+    """Stateful version of :func:`~torchrl.envs.step_mdp`.
 
     Precomputes the list of keys to include and exclude during a call to step_mdp
     to reduce runtime.
@@ -337,48 +337,47 @@ def step_mdp(
     exclude_reward: bool = True,
     exclude_done: bool = False,
     exclude_action: bool = True,
-    reward_keys: Union[NestedKey, List[NestedKey]] = "reward",
-    done_keys: Union[NestedKey, List[NestedKey]] = "done",
-    action_keys: Union[NestedKey, List[NestedKey]] = "action",
+    reward_keys: NestedKey | List[NestedKey] = "reward",
+    done_keys: NestedKey | List[NestedKey] = "done",
+    action_keys: NestedKey | List[NestedKey] = "action",
 ) -> TensorDictBase:
     """Creates a new tensordict that reflects a step in time of the input tensordict.
 
     Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict.
-    The arguments allow for a precise control over what should be kept and what
+    The arguments allow for precise control over what should be kept and what
     should be copied from the ``"next"`` entry. The default behavior is:
-    move the observation entries, reward and done states to the root, exclude
-    the current action and keep all extra keys (non-action, non-done, non-reward).
+    move the observation entries, reward, and done states to the root, exclude
+    the current action, and keep all extra keys (non-action, non-done, non-reward).
 
     Args:
-        tensordict (TensorDictBase): tensordict with keys to be renamed
-        next_tensordict (TensorDictBase, optional): destination tensordict
-        keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept.
+        tensordict (TensorDictBase): The tensordict with keys to be renamed.
+        next_tensordict (TensorDictBase, optional): The destination tensordict. If `None`, a new tensordict is created.
+ keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept. Default is ``True``. - exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded + exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``True``. - exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded + from the ``"next"`` entry (if present). Default is ``True``. + exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``False``. - exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will + from the ``"next"`` entry (if present). Default is ``False``. + exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will be discarded from the resulting tensordict. If ``False``, it will be kept in the root tensordict (since it should not be present in - the ``"next"`` entry). - Default is ``True``. - reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults + the ``"next"`` entry). Default is ``True``. + reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults to "reward". - done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults + done_keys (NestedKey or list of NestedKey, optional): The keys where the done is written. Defaults to "done". - action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults + action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults to "action". Returns: - A new tensordict (or next_tensordict) containing the tensors of the t+1 step. + TensorDictBase: A new tensordict (or `next_tensordict` if provided) containing the tensors of the t+1 step. + + .. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the + key values to reduce the overhead of making a step in the MDP. 
     Examples:
-        This funtion allows for this kind of loop to be used:
         >>> from tensordict import TensorDict
         >>> import torch
         >>> td = TensorDict({
@@ -781,7 +780,9 @@ def check_env_specs(
     from torchrl.envs.common import _has_dynamic_specs
 
     if _has_dynamic_specs(env.specs):
-        for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)):
+        for real, fake in zip(
+            real_tensordict_select.unbind(-1), fake_tensordict_select.unbind(-1)
+        ):
             fake = fake.apply(lambda x, y: x.expand_as(y), real)
             if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
                 raise AssertionError(zeroing_err_msg)
@@ -1365,6 +1366,8 @@ def _update_during_reset(
     reset_keys: List[NestedKey],
 ):
     """Updates the input tensordict with the reset data, based on the reset keys."""
+    if not reset_keys:
+        return tensordict.update(tensordict_reset)
     roots = set()
     for reset_key in reset_keys:
         # get the node of the reset key
diff --git a/tutorials/sphinx-tutorials/beam_search_with_gpt.py b/tutorials/sphinx-tutorials/beam_search_with_gpt.py
new file mode 100644
index 00000000000..a8ff399f96e
--- /dev/null
+++ b/tutorials/sphinx-tutorials/beam_search_with_gpt.py
@@ -0,0 +1,365 @@
+"""
+Beam Search with TorchRL
+========================
+
+Key learnings
+-------------
+
+In this tutorial, you will learn how to use TorchRL to implement beam search for efficient text generation.
+You will understand how to define a policy, build an environment, and run the policy using a beam search algorithm.
+
+Introduction
+------------
+Text generation is a fundamental task in natural language processing (NLP) with numerous applications in chatbots,
+language translation, and content creation. One of the challenges in text generation is efficiently exploring the vast
+space of possible sequences to find the most coherent and relevant output. Beam search is a popular heuristic search
+algorithm that addresses this challenge by maintaining a set of candidate solutions (or "beams") at each step and
+selecting the top-scoring candidates to move forward to the next step.
+
+
+How Beam Search Works
+---------------------
+
+At each generation step, beam search expands every candidate sequence with its most likely next tokens, scores the
+resulting hypotheses (typically by their cumulative log-probability), and keeps only a fixed number of them, the beam
+width (set with the ``--beta`` argument in this tutorial).
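+
+To make this concrete, here is a minimal, framework-free sketch of the algorithm
+(``candidates_fn`` is a hypothetical callable returning ``(token, log_prob)`` pairs
+for a given prefix)::
+
+    def beam_search(prefix, candidates_fn, beam_width, n_steps):
+        beams = [(prefix, 0.0)]  # (sequence, cumulative log-probability)
+        for _ in range(n_steps):
+            expanded = [
+                (seq + [tok], lp + tok_lp)
+                for seq, lp in beams
+                for tok, tok_lp in candidates_fn(seq)
+            ]
+            # keep only the top-scoring hypotheses
+            beams = sorted(expanded, key=lambda b: b[1], reverse=True)[:beam_width]
+        return beams
+
+The rest of this tutorial implements this same pattern with TorchRL components: the
+LLM provides the candidate scores, and a :class:`~torchrl.data.MCTSForest` stores the
+explored tree of hypotheses.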
+
+"""
+import argparse
+
+import torch
+
+import torchrl.data
+import tqdm
+from tensordict import NonTensorStack, TensorDict
+from tensordict.nn import (
+    ProbabilisticTensorDictModule as Prob,
+    TensorDictModule as Mod,
+    TensorDictSequential as Seq,
+)
+from tensordict.tensorclass import NonTensorData
+from torch.distributions import Categorical
+
+from torchrl._utils import _make_ordinal_device
+from torchrl.data import MCTSForest, SipHash
+from torchrl.envs import EnvBase
+from torchrl.envs.common import _StepMDP
+from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", choices=["llama3.1", "gpt2"], default="gpt2")
+parser.add_argument("--beta", type=int, default=3)
+parser.add_argument("--pool", type=int, default=1000)
+parser.add_argument("--nsteps", type=int, default=10)
+parser.add_argument("--device", type=str, default=None)
+parser.add_argument("--device_map", type=str, default="auto")
+
+args = parser.parse_args()
+
+################################################
+# Build the model
+# ---------------
+# In this example, we use a pre-trained GPT-2 model as our language model.
+# We define an ``LLMWrapper`` class (below) to wrap the language model so that its output is returned as a TensorDict.
+
+if args.model == "gpt2":
+    tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+    llm = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").eval().requires_grad_(False)
+
+    if torch.cuda.is_available():
+        device = "cuda:0"
+    else:
+        device = "cpu"
+
+elif args.model == "llama3.1":
+    model_id = "meta-llama/Llama-3.1-8B"
+
+    if args.device:
+        args.device_map = None
+    pipeline = pipeline(
+        "text-generation",
+        model=model_id,
+        model_kwargs={"torch_dtype": torch.bfloat16},
+        device_map=args.device_map,
+        device=args.device,
+    )
+
+    tokenizer = pipeline.tokenizer
+    llm = pipeline.model.eval().requires_grad_(False)
+    if args.device:
+        device = _make_ordinal_device(args.device)
+    elif torch.cuda.is_available():
+        device = "cuda:0"
+    elif torch.mps.is_available():
+        torch.mps.empty_cache()
+        device = "mps:0"
+    else:
+        device = "cpu"
+
+
+torch.set_default_device(device)
+
+text_to_tensor = Seq(
+    Mod(tokenizer, in_keys=["query"], out_keys=["out"]),
+    # A renaming layer
+    Mod(lambda x: x, in_keys=[("out", "input_ids")], out_keys=["observation"]),
+).select_out_keys("observation")
+td = TensorDict(
+    query=NonTensorStack.from_list(["hello world! Give me a high five"] * 4),
+    batch_size=[4],
+)
+print(text_to_tensor(td))
+
+################################################
+# Define the policy
+# -----------------
+# Our policy takes the observation (i.e., the current text) as input and outputs an action (i.e., the next token to
+# generate). It consists of a sequence of modules: first, the ``LLMWrapper`` runs the language model and returns its
+# output as a TensorDict; then, candidate next tokens are sampled without replacement from a categorical distribution,
+# and only the top-scoring ones are kept.
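+#
+# Roughly, the per-step data flow is as follows (shapes are indicative, with a beam
+# size of ``beta`` and a candidate pool of size ``pool``): observations of shape
+# ``[beta, seq_len]`` are first deduplicated (identical prefixes share a hash); the
+# LLM maps the unique prefixes to logits of shape ``[n_unique, seq_len, vocab]``;
+# only the last position's logits are kept; ``pool`` candidate tokens are sampled
+# without replacement; the hypotheses are flattened into a single batch; and the
+# ``beta`` highest-scoring ones are retained.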
+
+
+class LLMWrapper(torch.nn.Module):
+    def __init__(self, gpt):
+        super().__init__()
+        self.gpt = gpt
+
+    def forward(self, x) -> CausalLMOutputWithCrossAttentions:
+        result = TensorDict.from_dataclass(self.gpt(x, return_dict=True), device=device)
+        return result
+
+
+class CategoricalWithoutReplacement(Categorical):
+    def sample(self, sample_shape=torch.Size(())):
+        n = sample_shape.numel()
+        probs = self.probs
+        probs_shape = probs.shape
+        if len(probs_shape) > 2:
+            probs = probs.flatten(0, -2)
+        samples = torch.multinomial(probs, n, replacement=False)
+        # multinomial returns one row of draws per distribution: move the
+        # sample dim first before reshaping to (*sample_shape, *batch_shape)
+        return samples.t().reshape((*sample_shape, *probs_shape[:-1]))
+
+
+prob_module = Prob(
+    in_keys=["logits"],
+    out_keys=["action"],
+    default_interaction_type="random",
+    distribution_class=CategoricalWithoutReplacement,
+    return_log_prob=True,
+    log_prob_key="logits_select",
+    num_samples=args.pool,
+)
+
+
+def select_unique_obs(td):
+    # Get the obs (the hash)
+    hashes = td["hash"]
+    hashes = hashes.squeeze()
+    assert hashes.ndim == 1
+    _, unique_hashes = torch.unique(hashes, dim=0, return_inverse=True)
+    unique_hashes = unique_hashes.unique()
+    return td[unique_hashes]
+
+
+def select_top_k(td, top_k=args.beta):
+    logits = td["logits_select"]
+    topk = logits.topk(top_k, dim=0)
+    topk_indices = topk.indices.squeeze(-1)
+    return td[topk_indices].set("topk_indices", topk_indices)
+
+
+policy = Seq(
+    # Only get the unique obs
+    select_unique_obs,
+    # Call to the LLM
+    Mod(LLMWrapper(llm), in_keys=["observation"], out_keys=["data"]),
+    # Select last logit
+    Mod(lambda x: x[:, -1:], in_keys=[("data", "logits")], out_keys=["logits"]),
+    # Sample
+    prob_module,
+    # Reshape to -1
+    lambda td: td.reshape(-1),
+    # Top-k
+    select_top_k,
+)
+
+################################################
+# Build the hash module
+# ---------------------
+#
+# We are going to build a hash module to mark each step in the dataset. In theory, observations could be used directly
+# but the shape of each observation in the rollout will differ because the number of tokens is different at each step
+# of the trajectory.
+#
+# Using a hashing module, we can reduce every observation to an integer. Although we cannot recover the prompt directly
+# from the hash, we can easily recover it by concatenating the previous actions with the initial prompt.
+#
+#
+# .. figure:: /_static/img/rollout-llm.png
+#    :alt: Data collection loop with our LLM environment.
+#
+siphash = SipHash()
+
+
+################################################
+# Build the environment
+# ---------------------
+#
+# We define an environment that simulates the text generation process.
+# The environment has two main methods: ``_reset``, which initializes the environment with a given observation, and
+# ``_step``, which takes an action (i.e., the next token to generate) and returns the next observation and reward.
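+#
+# Concretely, ``_step`` appends the sampled token to the running sequence (the
+# observation grows by one token at each step) and updates the running hash from
+# the previous hash and the new action. The reward is zero and the done flags stay
+# ``False``: the environment only enumerates continuations, while the scoring of
+# hypotheses happens in the policy through the sampled log-probabilities.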
+ + +class LLMEnv(EnvBase): + def __init__(self): + super().__init__() + self._batch_locked = False + _StepMDP(self) + + def _reset(self, tensordict): + out = tensordict.copy() + obs = out["observation"] + if obs.ndim > 1: + text = tokenizer.batch_decode(obs) + text = NonTensorStack.from_list(text) + else: + text = tokenizer.decode(obs) + text = NonTensorData(text) + out["text"] = text + + if obs.ndim > 1: + out["hash"] = siphash(out["observation"]).unsqueeze(-1) + else: + out["hash"] = siphash(out["observation"].unsqueeze(0)).transpose(0, -1) + + if not self.full_done_spec.is_empty(): + out.update(self.full_done_spec.zero(tensordict.shape)) + else: + out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)) + out.set( + "terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool) + ) + return out + + def _step(self, tensordict): + action = tensordict.get("action") + obs = torch.cat([tensordict.get("observation"), action], -1) + + catval = torch.cat([tensordict.get("hash"), action], -1) + if obs.ndim > 1: + new_hash = siphash(catval).unsqueeze(-1) + else: + new_hash = siphash(catval.unsqueeze(0)).transpose(0, -1) + + if obs.ndim > 1: + text = tokenizer.batch_decode(obs) + text = NonTensorStack.from_list(text) + else: + text = tokenizer.decode(obs) + text = NonTensorData(text) + return TensorDict( + observation=obs, + hash=new_hash, + text=text, + reward=torch.zeros((*tensordict.batch_size, 1)), + done=torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool), + terminated=torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool), + batch_size=tensordict.batch_size, + ) + + def _set_seed(self, *args): + pass + + +env = LLMEnv() + +################################################ +# Define specs +# ------------ +# + +policy = policy.select_out_keys("action") + +x = tokenizer(["Check out TorchRL!"])["input_ids"] +td = TensorDict(observation=x, batch_size=[1]).repeat_interleave(args.beta) +td = env.reset(td) +print("data after reset", td) +print("action", policy(td)) +# We must indicate what the observations are +env.auto_specs_(policy, tensordict=td, observation_key=["observation", "text", "hash"]) +print(env.specs) +# Reset out keys - we want them all +policy.reset_out_keys() +policy = policy.select_out_keys("action", "logits_select") + +td = TensorDict(observation=x, batch_size=[1]).repeat_interleave(args.beta, dim=0) +td = env.reset(td) +env.action_spec = torchrl.data.Categorical(n=tokenizer.vocab_size, shape=(1,)) +env.check_env_specs(tensordict=td, return_contiguous=False) + +################################################ +# Create a forest to store the data +# --------------------------------- +# + +forest = MCTSForest(observation_keys=["hash"], action_keys=["action", "logits_select"]) + +################################################ +# Run policy +# ---------- +# + +with torch.no_grad(): + # Total number of candidates + pool = args.pool + # Number of selected beams + beta = args.beta + x = tokenizer(["Check out TorchRL!"])["input_ids"] + reset_td = env.reset( + TensorDict(observation=x, batch_size=[1]).repeat_interleave(args.beta) + ) + tds = [] + # beam search + td = reset_td + reset_td = reset_td[0].clone() + + pbar = tqdm.tqdm(range(args.nsteps)) + for _ in pbar: + td = policy(td) + next_td = env.step(td) + + tds.append(next_td) + next_td_filtered = next_td.exclude( + "observation", "text", ("next", "observation"), ("next", "text") + ) + forest.extend(next_td_filtered) + pbar.set_description(f"Forest length: {len(forest)}") + + print("action", 
next_td["action"]) + td = env.step_mdp(next_td) + print("hash", td["hash"]) + + tds = TensorDict.lazy_stack(tds, -1) + for i in range(tds.shape[0]): + print(tds[i, -1]["next", "text"]) + + tree = forest.get_tree(reset_td) + valid_paths = list(tree.valid_paths()) + print("valid paths", valid_paths) + + for path in valid_paths: + rollout = tree.rollout_from_path(path) + print("Check out TorchRL!", tokenizer.decode(rollout["action"].squeeze(-1))) + print(rollout["logits_select"].sum()) + + def make_labels(local_tree, path): + if path: + r = tree.rollout_from_path(path) + actions = r["action"] + return "Check out TorchRL! " + tokenizer.decode(actions.squeeze(-1)) + return "Check out TorchRL!" + + tree.plot(make_labels=make_labels)
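+
+    # To single out the best hypothesis, we can rank the valid paths by their
+    # cumulative log-probability accumulated in ``logits_select`` (a small
+    # sketch building on the loop above):
+    best_path = max(
+        valid_paths,
+        key=lambda p: tree.rollout_from_path(p)["logits_select"].sum(),
+    )
+    print("best path", best_path)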