From afc0f24f2f0462ee7acf93d18126148aef527cb7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 6 Dec 2024 21:20:11 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/data/map/hash.py | 3 +- torchrl/data/map/tdstorage.py | 32 +- torchrl/data/map/tree.py | 391 +++++++++++++++++++++--- torchrl/data/map/utils.py | 6 +- torchrl/data/replay_buffers/storages.py | 4 +- 5 files changed, 387 insertions(+), 49 deletions(-) diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index 01988dc43be..59526628dbe 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: class SipHash(Module): """A Module to Compute SipHash values for given tensors. - A hash function module based on SipHash implementation in python. + A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]`` + and the output shape will be ``[batch_size]``. Args: as_tensor (bool, optional): if ``True``, the bytes will be turned into integers diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index a601f1e3261..9413033bac4 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -138,6 +138,10 @@ def __init__( self.collate_fn = collate_fn self.write_fn = write_fn + @property + def max_size(self): + return self.storage.max_size + @property def out_keys(self) -> List[NestedKey]: out_keys = self.__dict__.get("_out_keys_and_lazy") @@ -177,7 +181,7 @@ def from_tensordict_pair( collate_fn: Callable[[Any], Any] | None = None, write_fn: Callable[[Any, Any], Any] | None = None, consolidated: bool | None = None, - ): + ) -> TensorDictMap: """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb. Args: @@ -238,7 +242,13 @@ def from_tensordict_pair( n_feat = 0 hash_module = [] for in_key in in_keys: - n_feat = source[in_key].shape[-1] + entry = source[in_key] + if entry.ndim == source.ndim: + # this is a good example of why td/tc are useful - carrying metadata + # allows us to know if there's a feature dim or not + n_feat = 0 + else: + n_feat = entry.shape[-1] if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT: _hash_module = RandomProjectionHash() else: @@ -308,7 +318,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase): if not self._has_lazy_out_keys(): # TODO: make this work with pytrees and avoid calling select if keys match value = value.select(*self.out_keys, strict=False) + item, value = self._maybe_add_batch(item, value) + index = self._to_index(item, extend=True) + if index.unique().numel() < index.numel(): + # If multiple values point to the same place in the storage, we cannot process them by batch + # There could be a better way to deal with this, using unique ids. + vals = [] + for it, val in zip(item.split(1), value.split(1)): + self[it] = val + vals.append(val) + # __setitem__ may affect the content of the input data + value.update(TensorDictBase.lazy_stack(vals)) + return if self.write_fn is not None: + # We use this block in the following context: the value written in the storage is already present, + # but it needs to be updated. + # We first check if the value is already there using `contains`. If so, we pass the new value and the + # previous one to write_fn. The values that are not present are passed alone. if len(self): modifiable = self.contains(item) if modifiable.any(): @@ -322,8 +348,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase): value = self.write_fn(value) else: value = self.write_fn(value) - item, value = self._maybe_add_batch(item, value) - index = self._to_index(item, extend=True) self.storage.set(index, value) def __len__(self): diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 645f7704ddd..f30162c9e65 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import weakref from collections import deque from typing import Any, Callable, Dict, List, Literal, Tuple @@ -15,10 +16,13 @@ TensorClass, TensorDict, TensorDictBase, + unravel_key, LazyStackedTensorDict, ) from torchrl.data.map.tdstorage import TensorDictMap from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree from torchrl.data.replay_buffers.storages import ListStorage + +from torchrl.data.tensor_specs import TensorSpec from torchrl.envs.common import EnvBase @@ -69,7 +73,9 @@ class Tree(TensorClass["nocast"]): """ - count: int = None + count: int | torch.Tensor = None + wins: int | torch.Tensor = None + index: torch.Tensor | None = None # The hash is None if the node has more than one action associated hash: int | None = None @@ -84,6 +90,71 @@ class Tree(TensorClass["nocast"]): # Stack of subtrees. A subtree is produced when an action is taken. subtree: "Tree" = None + _parent: weakref.ref | None = None + + @classmethod + def make_node( + cls, + data, + *, + parent=None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, + ) -> Tree: + if "next" in data.keys(): + rollout = data + if not rollout.ndim: + rollout = rollout.unsqueeze(0) + subtree = TensorDict.lazy_stack([cls.make_node(data["next"][..., 0])]) + else: + rollout = None + subtree = None + if device is None: + device = data.device + return cls( + count=torch.zeros(()), + wins=torch.zeros(()), + node=data.exclude("action", "next"), + rollout=rollout, + _parent=parent, + subtree=subtree, + device=device, + batch_size=batch_size, + ) + + def __post_init__(self): + if (self.subtree is None) ^ (self.rollout is None): + raise ValueError("A node was created with only a subtree or a rollout but not both.") + + # @property + # def children(self) -> Tree: + # return self.subtree + + @property + def visits(self) -> int | torch.Tensor: + return self.count + @visits.setter + def visits(self, count): + self.count = count + + def fully_expanded(self, *, action_spec: TensorSpec | None = None) -> bool: + ... + + @property + def parent(self) -> Tree | None: + parent = self._parent + if parent is not None: + # Check that all parents match + if isinstance(parent, list): + parent = [p() for p in parent] + for p in parent[1:]: + if p is not parent[0]: + raise ValueError( + "All parents of a given node level must match." + ) + return parent[0] + return parent() + @property def num_children(self) -> int: """Number of children of this node. @@ -93,9 +164,19 @@ def num_children(self) -> int: return len(self.subtree) if self.subtree is not None else 0 @property - def is_terminal(self): - """Returns True if the the tree has no children nodes.""" - return self.subtree is None + def is_terminal(self) -> bool | torch.Tensor: + """Returns True if the tree has no children nodes.""" + if self.rollout is not None: + return self.rollout[..., -1]["next", "done"].squeeze(-1) + # Here we have two options: (1) subtree is None, in which case this is a node that + # has no child yet. Therefore, it can be explored further. + # (2) subtree is there, in which case it's not terminal. + return True + + def fully_expanded(self, env: EnvBase) -> bool: + cardinality = env.cardinality(self.node) + num_actions = self.num_children + return cardinality == num_actions def get_vertex_by_id(self, id: int) -> Tree: """Goes through the tree and returns the node corresponding the given id.""" @@ -163,9 +244,6 @@ def vertices( if h in memo and not use_path: continue memo.add(h) - r = tree.rollout - if r is not None: - r = r["next", "observation"] if use_path: result[cur_path] = tree elif use_id: @@ -206,6 +284,14 @@ def num_vertices(self, *, count_repeat: bool = False) -> int: ) def edges(self) -> List[Tuple[int, int]]: + """Retrieves a list of edges in the tree. + + Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. + The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited. + + Returns: + A list of tuples, where each tuple contains a parent node ID and a child node ID. + """ result = [] q = deque() parent = self.node_id @@ -221,22 +307,62 @@ def edges(self) -> List[Tuple[int, int]]: return result def valid_paths(self): + """Generates all valid paths in the tree. + + A valid path is a sequence of child indices that starts at the root node and ends at a leaf node. + Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node. + + Yields: + tuple: A valid path in the tree. + """ + # Initialize a queue with the current tree node and an empty path q = deque() cur_path = () q.append((self, cur_path)) + # Perform BFS traversal of the tree while len(q): + # Dequeue the next tree node and its current path tree, cur_path = q.popleft() + # Get the number of child nodes n = int(tree.num_children) + # If this is a leaf node, yield the current path if not n: yield cur_path + # Iterate over the child nodes for i in range(n): cur_path_tree = cur_path + (i,) q.append((tree.subtree[i], cur_path_tree)) def max_length(self): - return max(*(len(path) for path in self.valid_paths())) + """Returns the maximum length of all valid paths in the tree. + + The length of a path is defined as the number of nodes in the path. + If the tree is empty, returns 0. + + Returns: + int: The maximum length of all valid paths in the tree. + + """ + lengths = tuple(len(path) for path in self.valid_paths()) + if len(lengths) == 0: + return 0 + elif len(lengths) == 1: + return lengths[0] + return max(*lengths) def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None: + """Retrieves the rollout data along a given path in the tree. + + The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. + If no rollout data is found along the path, returns ``None``. + + Args: + path: A tuple of integers representing the path in the tree. + + Returns: + The concatenated rollout data along the path, or None if no data is found. + + """ r = self.rollout tree = self rollouts = [] @@ -272,8 +398,19 @@ def plot( backend: str = "plotly", figure: str = "tree", info: List[str] = None, - make_labels: Callable[[Any], Any] | None = None, + make_labels: Callable[[Any, ...], Any] | None = None, ): + """Plots a visualization of the tree using the specified backend and figure type. + + Args: + backend: The plotting backend to use. Currently only supports 'plotly'. + figure: The type of figure to plot. Can be either 'tree' or 'box'. + info: A list of additional information to include in the plot (not currently used). + make_labels: An optional function to generate custom labels for the plot. + + Raises: + NotImplementedError: If an unsupported backend or figure type is specified. + """ if backend == "plotly": if figure == "box": _plot_plotly_box(self) @@ -284,33 +421,48 @@ def plot( else: pass raise NotImplementedError( - f"Unkown plotting backend {backend} with figure {figure}." + f"Unknown plotting backend {backend} with figure {figure}." ) class MCTSForest: """A collection of MCTS trees. + .. warning:: This class is currently under active development. Expect frequent API changes. + The class is aimed at storing rollouts in a storage, and produce trees based on a given root in that dataset. Keyword Args: data_map (TensorDictMap, optional): the storage to use to store the data (observation, reward, states etc). If not provided, it is lazily - initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair`. - node_map (TensorDictMap, optional): TODO - done_keys (list of NestedKey): the done keys of the environment. If not provided, + initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair` + using the list of :attr:`observation_keys` and :attr:`action_keys` as ``in_keys``. + node_map (TensorDictMap, optional): a map from the observation space to the index space. + Internally, the node map is used to gather all possible branches coming out of + a given node. For example, if an observation has two associated actions and outcomes + in the data map, then the :attr:`node_map` will return a data structure containing the + two indices in the :attr:`data_map` that correspond to these two outcomes. + If not provided, it is lazily initialized using + :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair` using the list of + :attr:`observation_keys` as ``in_keys`` and the :class:`~torchrl.data.QueryModule` as + ``out_keys``. + max_size (int, optional): the size of the maps. + If not provided, defaults to ``data_map.max_size`` if this can be found, then + ``node_map.max_size``. If none of these are provided, defaults to `1000`. + done_keys (list of NestedKey, optional): the done keys of the environment. If not provided, defaults to ``("done", "terminated", "truncated")``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - action_keys (list of NestedKey): the action keys of the environment. If not provided, + action_keys (list of NestedKey, optional): the action keys of the environment. If not provided, defaults to ``("action",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - reward_keys (list of NestedKey): the reward keys of the environment. If not provided, + reward_keys (list of NestedKey, optional): the reward keys of the environment. If not provided, defaults to ``("reward",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - observation_keys (list of NestedKey): the observation keys of the environment. If not provided, + observation_keys (list of NestedKey, optional): the observation keys of the environment. If not provided, defaults to ``("observation",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + excluded_keys (list of NestedKey, optional): a list of keys to exclude from the data storage. consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk. Defaults to ``False``. @@ -405,10 +557,12 @@ def __init__( *, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, + max_size: int | None = None, done_keys: List[NestedKey] | None = None, reward_keys: List[NestedKey] = None, observation_keys: List[NestedKey] = None, action_keys: List[NestedKey] = None, + excluded_keys: List[NestedKey] = None, consolidated: bool | None = None, ): @@ -416,55 +570,125 @@ def __init__( self.node_map = node_map + if max_size is None: + if data_map is not None: + max_size = data_map.max_size + if max_size != getattr(node_map, "max_size", max_size): + raise ValueError( + f"Conflicting max_size: got data_map.max_size={data_map.max_size} and node_map.max_size={node_map.max_size}." + ) + elif node_map is not None: + max_size = node_map.max_size + else: + max_size = None + elif data_map is not None and max_size != getattr( + data_map, "max_size", max_size + ): + raise ValueError( + f"Conflicting max_size: got data_map.max_size={data_map.max_size} and max_size={max_size}." + ) + elif node_map is not None and max_size != getattr( + node_map, "max_size", max_size + ): + raise ValueError( + f"Conflicting max_size: got node_map.max_size={node_map.max_size} and max_size={max_size}." + ) + self.max_size = max_size + self.done_keys = done_keys self.action_keys = action_keys self.reward_keys = reward_keys self.observation_keys = observation_keys + self.excluded_keys = excluded_keys self.consolidated = consolidated @property - def done_keys(self): + def done_keys(self) -> List[NestedKey]: + """Done Keys. + + Returns the keys used to indicate that an episode has ended. + The default done keys are "done", "terminated", and "truncated". These keys can be + used in the environment's output to signal the end of an episode. + + Returns: + A list of strings representing the done keys. + + """ done_keys = getattr(self, "_done_keys", None) if done_keys is None: - self._done_keys = done_keys = ("done", "terminated", "truncated") + self._done_keys = done_keys = ["done", "terminated", "truncated"] return done_keys @done_keys.setter def done_keys(self, value): - self._done_keys = value + self._done_keys = _make_list_of_nestedkeys(value, "done_keys") @property - def reward_keys(self): + def reward_keys(self) -> List[NestedKey]: + """Reward Keys. + + Returns the keys used to retrieve rewards from the environment's output. + The default reward key is "reward". + + Returns: + A list of strings or tuples representing the reward keys. + + """ reward_keys = getattr(self, "_reward_keys", None) if reward_keys is None: - self._reward_keys = reward_keys = ("reward",) + self._reward_keys = reward_keys = ["reward"] return reward_keys @reward_keys.setter def reward_keys(self, value): - self._reward_keys = value + self._reward_keys = _make_list_of_nestedkeys(value, "reward_keys") @property - def action_keys(self): + def action_keys(self) -> List[NestedKey]: + """Action Keys. + + Returns the keys used to retrieve actions from the environment's input. + The default action key is "action". + + Returns: + A list of strings or tuples representing the action keys. + + """ action_keys = getattr(self, "_action_keys", None) if action_keys is None: - self._action_keys = action_keys = ("action",) + self._action_keys = action_keys = ["action"] return action_keys @action_keys.setter def action_keys(self, value): - self._action_keys = value + self._action_keys = _make_list_of_nestedkeys(value, "action_keys") @property - def observation_keys(self): + def observation_keys(self) -> List[NestedKey]: + """Observation Keys. + + Returns the keys used to retrieve observations from the environment's output. + The default observation key is "observation". + + Returns: + A list of strings or tuples representing the observation keys. + """ observation_keys = getattr(self, "_observation_keys", None) if observation_keys is None: - self._observation_keys = observation_keys = ("observation",) + self._observation_keys = observation_keys = ["observation"] return observation_keys @observation_keys.setter def observation_keys(self, value): - self._observation_keys = value + self._observation_keys = _make_list_of_nestedkeys(value, "observation_keys") + + @property + def excluded_keys(self) -> List[NestedKey] | None: + return self._excluded_keys + + @excluded_keys.setter + def excluded_keys(self, value): + self._excluded_keys = _make_list_of_nestedkeys(value, "excluded_keys") def get_keys_from_env(self, env: EnvBase): """Writes missing done, action and reward keys to the Forest given an environment. @@ -482,8 +706,21 @@ def get_keys_from_env(self, env: EnvBase): @classmethod def _write_fn_stack(cls, new, old=None): + # This function updates the old values by adding the new ones + # if and only if the new ones are not there. + # If the old value is not provided, we assume there are none and the + # `new` is just prepared. + # This involves unsqueezing the last dim (since we'll be stacking tensors + # and calling unique). + # The update involves calling cat along the last dim + unique + # which will keep only the new values that were unknown to + # the storage. + # We use this method to track all the indices that are associated with + # an observation. Every time a new index is obtained, it is stacked alongside + # the others. if old is None: - result = new.apply(lambda x: x.unsqueeze(0), filter_empty=False) + # we unsqueeze the values to stack them along dim -1 + result = new.apply(lambda x: x.unsqueeze(-1), filter_empty=False) result.set( "count", torch.ones(result.shape, dtype=torch.int, device=result.device) ) @@ -493,28 +730,44 @@ def cat(name, x, y): if name == "count": return x if y.ndim < x.ndim: - y = y.unsqueeze(0) - result = torch.cat([x, y], 0).unique(dim=0, sorted=False) + y = y.unsqueeze(-1) + result = torch.cat([x, y], -1) + # Breaks on mps + if result.device.type == "mps": + result = result.cpu() + result = result.unique(dim=-1, sorted=False) + result = result.to("mps") + else: + result = result.unique(dim=-1, sorted=False) return result result = old.named_apply(cat, new, default=None) result.set_("count", old.get("count") + 1) return result - def _make_storage(self, source, dest): + def _make_data_map(self, source, dest): try: + kwargs = {} + if self.max_size is not None: + kwargs["max_size"] = self.max_size self.data_map = TensorDictMap.from_tensordict_pair( source, dest, in_keys=[*self.observation_keys, *self.action_keys], consolidated=self.consolidated, + **kwargs, ) + if self.max_size is None: + self.max_size = self.data_map.max_size except KeyError as err: raise KeyError( "A KeyError occurred during data map creation. This could be due to the wrong setting of a key in the MCTSForest constructor. Scroll up for more info." ) from err - def _make_storage_branches(self, source, dest): + def _make_node_map(self, source, dest): + kwargs = {} + if self.max_size is not None: + kwargs["max_size"] = self.max_size self.node_map = TensorDictMap.from_tensordict_pair( source, dest, @@ -528,26 +781,59 @@ def _make_storage_branches(self, source, dest): storage_constructor=ListStorage, collate_fn=TensorDict.lazy_stack, write_fn=self._write_fn_stack, + **kwargs, ) + if self.max_size is None: + self.max_size = self.data_map.max_size - def extend(self, rollout): + def extend(self, rollout, *, return_node: bool = False): source, dest = ( rollout.exclude("next").copy(), rollout.select("next", *self.action_keys).copy(), ) + if self.excluded_keys is not None: + dest = dest.exclude(*self.excluded_keys, inplace=True) + dest.get("next").exclude(*self.excluded_keys, inplace=True) if self.data_map is None: - self._make_storage(source, dest) + self._make_data_map(source, dest) # We need to set the action somewhere to keep track of what action lead to what child # # Set the action in the 'next' # dest[1:] = source[:-1].exclude(*self.done_keys) + # Add ('observation', 'action') -> ('next, observation') self.data_map[source] = dest value = source if self.node_map is None: - self._make_storage_branches(source, dest) + self._make_node_map(source, dest) + # map ('observation',) -> ('indices',) self.node_map[source] = TensorDict.lazy_stack(value.unbind(0)) + if return_node: + return self.get_tree(rollout) + + def add(self, step, *, return_node: bool = False): + source, dest = ( + step.exclude("next").copy(), + step.select("next", *self.action_keys).copy(), + ) + + if self.data_map is None: + self._make_data_map(source, dest) + + # We need to set the action somewhere to keep track of what action lead to what child + # # Set the action in the 'next' + # dest[1:] = source[:-1].exclude(*self.done_keys) + + # Add ('observation', 'action') -> ('next, observation') + self.data_map[source] = dest + value = source + if self.node_map is None: + self._make_node_map(source, dest) + # map ('observation',) -> ('indices',) + self.node_map[source] = value + if return_node: + return self.get_tree(step) def get_child(self, root: TensorDictBase) -> TensorDictBase: return self.data_map[root] @@ -557,6 +843,7 @@ def _make_local_tree( root: TensorDictBase, index: torch.Tensor | None = None, compact: bool = True, + parent: Tree | None = None, ) -> Tuple[Tree, torch.Tensor | None, torch.Tensor | None]: root = root.select(*self.node_map.in_keys) node_meta = None @@ -582,6 +869,14 @@ def _make_local_tree( if not compact: break else: + # If the root is provided and not gathered from the storage, it could be that its + # device doesn't match the data_map storage device. + device = getattr(self.data_map.storage, "device", None) + if root.device != device: + if device is not None: + root = root.to(self.data_map.storage.device) + else: + root.clear_device_() index = None break rollout = None @@ -592,11 +887,14 @@ def _make_local_tree( return ( Tree( rollout=rollout, - count=node_meta["count"], + count=torch.zeros((), dtype=torch.int32), + wins=torch.zeros(()), node=root, index=index, hash=None, - subtree=None, + # We do this to avoid raising an exception as rollout and subtree must be provided together + subtree=LazyStackedTensorDict(), + _parent=weakref.ref(parent) if parent is not None else None, ), index, hash, @@ -638,7 +936,10 @@ def _make_tree_iter( subtree, subtree_indices, subtree_hash = memo.get(h, (None,) * 3) if subtree is None: subtree, subtree_indices, subtree_hash = self._make_local_tree( - tree.node, index=i, compact=compact + tree.node, + index=i, + compact=compact, + parent=tree, ) subtree.node_id = counter counter += 1 @@ -668,3 +969,15 @@ def valid_paths(cls, tree: Tree): def __len__(self): return len(self.data_map) + + +def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]: + if obj is None: + return obj + if isinstance(obj, (str, tuple)): + return [obj] + if not isinstance(obj, list): + raise ValueError( + f"{attr} must be a list of NestedKeys or a NestedKey, got {obj}." + ) + return [unravel_key(key) for key in obj] diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py index 570214f1cb2..d9588d79905 100644 --- a/torchrl/data/map/utils.py +++ b/torchrl/data/map/utils.py @@ -17,13 +17,13 @@ def _plot_plotly_tree( if make_labels is None: - def make_labels(tree): + def make_labels(tree, path, *args, **kwargs): return str((tree.node_id, tree.hash)) nr_vertices = tree.num_vertices() - vertices = tree.vertices() + vertices = tree.vertices(key_type="path") - v_label = [make_labels(subtree) for subtree in vertices.values()] + v_label = [make_labels(subtree, path) for path, subtree in vertices.items()] G = Graph(nr_vertices, tree.edges()) layout = G.layout_sugiyama(range(nr_vertices)) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 665cae254f5..ae0d97b7bab 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -246,8 +246,8 @@ def set( set_cursor: bool = True, ): if not isinstance(cursor, INT_CLASSES): - if (isinstance(cursor, torch.Tensor) and cursor.numel() <= 1) or ( - isinstance(cursor, np.ndarray) and cursor.size <= 1 + if (isinstance(cursor, torch.Tensor) and cursor.ndim == 0) or ( + isinstance(cursor, np.ndarray) and cursor.ndim == 0 ): self.set(int(cursor), data, set_cursor=set_cursor) return