diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 10ea80c1dcd..77abee7d4fc 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -28,13 +28,23 @@ jobs:
     with:
       repository: pytorch/rl
       upload-artifact: docs
-      runner: "linux.g5.4xlarge.nvidia.gpu"
-      docker-image: "nvidia/cudagl:11.4.0-base"
       timeout: 120
       script: |
        set -e
        set -v
-        apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
+        # apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
+        yum makecache
+        # yum install -y glfw glew mesa-libGL mesa-libGL-devel mesa-libOSMesa-devel egl-utils freeglut
+        # Install Mesa and OpenGL Libraries:
+        yum install -y glfw mesa-libGL mesa-libGL-devel egl-utils freeglut mesa-libGLU mesa-libEGL
+        # Install DRI Drivers:
+        yum install -y mesa-dri-drivers
+        # Install Xvfb for Headless Environments:
+        yum install -y xorg-x11-server-Xvfb
+        # xhost +local:docker
+        # Xvfb :1 -screen 0 1024x768x24 &
+        # export DISPLAY=:1
+
        root_dir="$(pwd)"
        conda_dir="${root_dir}/conda"
        env_dir="${root_dir}/env"
@@ -51,7 +61,7 @@ jobs:
        conda activate "${env_dir}"

        # 2. upgrade pip, ninja and packaging
-        apt-get install python3-pip unzip -y -f
+        # apt-get install python3-pip unzip -y -f
        python3 -m pip install --upgrade pip
        python3 -m pip install setuptools ninja packaging cmake -U
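The yum-based package set above (Mesa GL/EGL, DRI drivers, Xvfb) gives the CI container a software OpenGL stack so the docs build can render environments without a physical display. As a rough, hedged illustration of what that enables — not part of this patch, and the `MUJOCO_GL`/`PYOPENGL_PLATFORM` variables are assumptions about the rendering backend in use:

```python
import os

# Pick a headless GL backend *before* any library creates a GL context.
# Assumption: rendering goes through MuJoCo and/or PyOpenGL, which honor these variables.
os.environ.setdefault("MUJOCO_GL", "egl")
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")

from torchrl.envs import GymEnv  # noqa: E402

# Pixel-based rendering can now run without an X server (or under Xvfb).
env = GymEnv("CartPole-v1", from_pixels=True)
print(env.reset()["pixels"].shape)
```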
+ """ + + def __init__(self): + super().__init__() + self.completed_trajectories = set() + self.repertoire_tensor = torch.zeros((), dtype=torch.int64) + + def _update_repertoire(self, tensordict: TensorDictBase) -> None: + """Updates the repertoire of completed trajectories.""" + done = tensordict["next", "terminated"].squeeze(-1) + traj = tensordict["next", "traj_count"][done].view(-1) + if traj.numel(): + self.completed_trajectories = self.completed_trajectories.union( + traj.tolist() + ) + self.repertoire_tensor = torch.tensor( + list(self.completed_trajectories), dtype=torch.int64 + ) + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + """Updates the repertoire of completed trajectories during insertion.""" + self._update_repertoire(tensordict) + return tensordict + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Filters out incomplete trajectories during sampling.""" + traj = tensordict["next", "traj_count"] + traj = traj.unsqueeze(-1) + has_traj = (traj == self.repertoire_tensor).any(-1) + has_traj = has_traj.view(tensordict.shape) + return tensordict[has_traj] + + +def main(): + # Create a CartPole environment with trajectory counting + env = GymEnv("CartPole-v1").append_transform(TrajCounter()) + + # Create a replay buffer with the completed trajectory repertoire transform + buffer = ReplayBuffer( + storage=LazyTensorStorage(1_000_000), transform=CompletedTrajectoryRepertoire() + ) + + # Roll out the environment for 1000 steps + while True: + rollout = env.rollout(1000, break_when_any_done=False) + if not rollout["next", "done"][-1].item(): + break + + # Extend the replay buffer with the rollout + buffer.extend(rollout) + + # Get the last trajectory count + last_traj_count = rollout[-1]["next", "traj_count"].item() + print(f"Incomplete trajectory: {last_traj_count}") + + # Sample from the replay buffer 10 times + for _ in range(10): + sample_traj_counts = buffer.sample(32)["next", "traj_count"].unique() + print(f"Sampled trajectories: {sample_traj_counts}") + assert last_traj_count not in sample_traj_counts + + +if __name__ == "__main__": + main() diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 4e943e03cfc..b6f4ac7069b 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 4e943e03cfc..b6f4ac7069b 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 from typing import Dict, List, Optional

 import torch
diff --git a/test/test_collector.py b/test/test_collector.py
index 38191a46eaa..5c91cb83633 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -1345,7 +1345,7 @@ def make_env():
         functools.partial(MultiSyncDataCollector, cat_results="stack"),
     ],
 )
-@pytest.mark.parametrize("init_random_frames", [50])  # 1226: faster execution
+@pytest.mark.parametrize("init_random_frames", [0, 50])  # 1226: faster execution
 @pytest.mark.parametrize(
     "explicit_spec,split_trajs", [[True, True], [False, False]]
 )  # 1226: faster execution
diff --git a/test/test_storage_map.py b/test/test_storage_map.py
index 9ff4431fb50..db2d0bc2c49 100644
--- a/test/test_storage_map.py
+++ b/test/test_storage_map.py
@@ -301,6 +301,7 @@ def _state0(self) -> TensorDict:
     def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
         done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1)
         reward = action.clone()
+        action = action + torch.arange(action.shape[-1]) / action.shape[-1]

         return TensorDict(
             {
@@ -326,7 +327,7 @@ def _make_forest(self) -> MCTSForest:
         forest.extend(r4)
         return forest

-    def _make_forest_intersect(self) -> MCTSForest:
+    def _make_forest_rebranching(self) -> MCTSForest:
         """
         ├── 0
         │   ├── 16
@@ -449,7 +450,7 @@ def test_forest_check_ids(self):

     def test_forest_intersect(self):
         state0 = self._state0()
-        forest = self._make_forest_intersect()
+        forest = self._make_forest_rebranching()
         tree = forest.get_tree(state0)
         subtree = forest.get_tree(TensorDict(observation=19))

@@ -467,13 +468,110 @@ def test_forest_intersect(self):

     def test_forest_intersect_vertices(self):
         state0 = self._state0()
-        forest = self._make_forest_intersect()
+        forest = self._make_forest_rebranching()
         tree = forest.get_tree(state0)
         assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash"))
         assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash"))
         with pytest.raises(ValueError, match="key_type must be"):
             tree.vertices(key_type="another key type")

+    @pytest.mark.skipif(not _has_gym, reason="requires gym")
+    def test_simple_tree(self):
+        from torchrl.envs import GymEnv
+
+        env = GymEnv("Pendulum-v1")
+        r = env.rollout(10)
+        state0 = r[0]
+        forest = MCTSForest()
+        forest.extend(r)
+        # forest = self._make_forest_intersect()
+        tree = forest.get_tree(state0, compact=False)
+        assert tree.max_length() == 9
+        for p in tree.valid_paths():
+            assert len(p) == 9
+
+    @pytest.mark.parametrize(
+        "tree_type,compact",
+        [
+            ["simple", False],
+            ["forest", False],
+            # parents of rebranching trees are still buggy
+            # ["rebranching", False],
+            # ["rebranching", True],
+        ],
+    )
+    def test_forest_parent(self, tree_type, compact):
+        if tree_type == "simple":
+            if not _has_gym:
+                pytest.skip("requires gym")
+            from torchrl.envs import GymEnv
+
+            env = GymEnv("Pendulum-v1")
+            r = env.rollout(10)
+            state0 = r[0]
+            forest = MCTSForest()
+            forest.extend(r)
+            tree = forest.get_tree(state0, compact=compact)
+        elif tree_type == "forest":
+            state0 = self._state0()
+            forest = self._make_forest()
+            tree = forest.get_tree(state0, compact=compact)
+        else:
+            state0 = self._state0()
+            forest = self._make_forest_rebranching()
+            tree = forest.get_tree(state0, compact=compact)
+        # Check access
+        tree.subtree.parent
+        tree.subtree.subtree.parent
+        tree.subtree.subtree.subtree.parent
+
+        # Check presence of weakrefs
+        assert tree.subtree[0]._parent is not None
+        assert tree.subtree[0].subtree[0]._parent is not None
+
+        # Check content
+        assert_close(tree.subtree.parent, tree)
+        for p in tree.valid_paths():
+            root = tree
+            for it in p:
+                node = root.subtree[it]
+                assert_close(node.parent, root)
+                root = node
+
+    def test_forest_action_attr(self):
+        state0 = self._state0()
+        forest = self._make_forest()
+        tree = forest.get_tree(state0)
+        assert tree.branching_action is None
+        assert (tree.subtree.branching_action != tree.subtree.prev_action).any()
+        assert (
+            tree.subtree[0].subtree.branching_action
+            != tree.subtree[0].subtree.prev_action
+        ).any()
+        assert tree.prev_action is None
+
+    @pytest.mark.parametrize("intersect", [False, True])
+    def test_forest_check_obs_match(self, intersect):
+        state0 = self._state0()
+        if intersect:
+            forest = self._make_forest_rebranching()
+        else:
+            forest = self._make_forest()
+        tree = forest.get_tree(state0)
+        for path in tree.valid_paths():
+            prev_tree = tree
+            for p in path:
+                subtree = prev_tree.subtree[p]
+                assert (
+                    subtree.node_data["observation"]
+                    == subtree.rollout[..., -1]["next", "observation"]
+                ).all()
+                assert (
+                    subtree.node_observation
+                    == subtree.rollout[..., -1]["next", "observation"]
+                ).all()
+                prev_tree = subtree
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 16eb5904b84..14fbc7d5f22 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -712,10 +712,10 @@ def __init__(
         )
         self.reset_at_each_iter = reset_at_each_iter
         self.init_random_frames = (
-            int(init_random_frames) if init_random_frames is not None else 0
+            int(init_random_frames) if init_random_frames not in (None, -1) else 0
         )
         if (
-            init_random_frames is not None
+            init_random_frames not in (-1, None, 0)
             and init_random_frames % frames_per_batch != 0
             and RL_WARNINGS
         ):
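With this change, the sentinel values `None` and `-1` for `init_random_frames` both normalize to `0` (no random warm-up phase), and the divisibility warning is only raised for a strictly positive value that is not a multiple of `frames_per_batch`. A small sketch of the resulting semantics — the function and the warning text are illustrative, not the collector's actual code, and the `RL_WARNINGS` gate is omitted:

```python
import warnings


def normalized_init_random_frames(init_random_frames, frames_per_batch):
    # None and -1 are both read as "no random warm-up".
    frames = int(init_random_frames) if init_random_frames not in (None, -1) else 0
    # Only a positive value that is not a multiple of frames_per_batch warns.
    if init_random_frames not in (-1, None, 0) and init_random_frames % frames_per_batch != 0:
        warnings.warn(
            f"init_random_frames={init_random_frames} is not a multiple of "
            f"frames_per_batch={frames_per_batch}; collection proceeds in "
            f"frames_per_batch increments."
        )
    return frames


assert normalized_init_random_frames(None, 200) == 0
assert normalized_init_random_frames(-1, 200) == 0
assert normalized_init_random_frames(50, 200) == 50  # also warns
```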
diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py
index f30162c9e65..513a7b94e58 100644
--- a/torchrl/data/map/tree.py
+++ b/torchrl/data/map/tree.py
@@ -16,13 +16,13 @@
     TensorClass,
     TensorDict,
     TensorDictBase,
-    unravel_key,
     LazyStackedTensorDict,
+    unravel_key,
 )

 from torchrl.data.map.tdstorage import TensorDictMap
 from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree
 from torchrl.data.replay_buffers.storages import ListStorage
+from torchrl.data.tensor_specs import Composite

-from torchrl.data.tensor_specs import TensorSpec
 from torchrl.envs.common import EnvBase
@@ -84,28 +84,37 @@ class Tree(TensorClass["nocast"]):

     # rollout following the observation encoded in node, in a TorchRL (TED) format
     rollout: TensorDict | None = None

-    # The data specifying the node
-    node: TensorDict | None = None
+    # The data specifying the node (typically an observation or a set of observations)
+    node_data: TensorDict | None = None

     # Stack of subtrees. A subtree is produced when an action is taken.
     subtree: "Tree" = None

-    _parent: weakref.ref | None = None
+    # weakrefs to the parent(s) of the node
+    _parent: weakref.ref | List[weakref.ref] | None = None
+
+    # Specs: contains information such as action or observation keys and spaces.
+    # If present, they should be structured like env specs are:
+    # Composite(input_spec=Composite(full_state_spec=..., full_action_spec=...),
+    #           output_spec=Composite(full_observation_spec=..., full_reward_spec=..., full_done_spec=...))
+    # where every leaf component is optional.
+    specs: Composite | None = None

     @classmethod
     def make_node(
         cls,
-        data,
+        data: TensorDictBase,
         *,
-        parent=None,
         device: torch.device | None = None,
         batch_size: torch.Size | None = None,
+        specs: Composite | None = None,
     ) -> Tree:
+        """Creates a new node given some data."""
         if "next" in data.keys():
             rollout = data
             if not rollout.ndim:
                 rollout = rollout.unsqueeze(0)
-            subtree = TensorDict.lazy_stack([cls.make_node(data["next"][..., 0])])
+            subtree = TensorDict.lazy_stack([cls.make_node(data["next"][..., -1])])
         else:
             rollout = None
             subtree = None
@@ -116,44 +125,207 @@ def make_node(
             wins=torch.zeros(()),
             node=data.exclude("action", "next"),
             rollout=rollout,
-            _parent=parent,
             subtree=subtree,
             device=device,
             batch_size=batch_size,
         )

-    def __post_init__(self):
-        if (self.subtree is None) ^ (self.rollout is None):
-            raise ValueError("A node was created with only a subtree or a rollout but not both.")
+    # Specs
+    @property
+    def full_observation_spec(self):
+        """The observation spec of the tree.
+
+        This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`."""
+        return self.specs["output_spec", "full_observation_spec"]
+
+    @property
+    def full_reward_spec(self):
+        """The reward spec of the tree.
+
+        This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`."""
+        return self.specs["output_spec", "full_reward_spec"]
+
+    @property
+    def full_done_spec(self):
+        """The done spec of the tree.
+
+        This is an alias for `Tree.specs['output_spec', 'full_done_spec']`."""
+        return self.specs["output_spec", "full_done_spec"]
+
+    @property
+    def full_state_spec(self):
+        """The state spec of the tree.
+
+        This is an alias for `Tree.specs['input_spec', 'full_state_spec']`."""
+        return self.specs["input_spec", "full_state_spec"]
+
+    @property
+    def full_action_spec(self):
+        """The action spec of the tree.
+
+        This is an alias for `Tree.specs['input_spec', 'full_action_spec']`."""
+        return self.specs["input_spec", "full_action_spec"]
+
+    @property
+    def selected_actions(self) -> torch.Tensor | TensorDictBase | None:
+        """Returns a tensor containing all the selected actions branching out from this node."""
+        if self.subtree is None:
+            return None
+        return self.subtree.rollout[..., 0]["action"]
+
+    @property
+    def prev_action(self) -> torch.Tensor | TensorDictBase | None:
+        """The action undertaken just before this node's observation was generated.
+
+        Returns:
+            a tensor, tensordict or ``None`` if the node has no parent.
+
+        .. seealso:: This will be equal to :attr:`~torchrl.data.Tree.branching_action` whenever the rollout data contains a single step.
+
+        .. seealso:: :attr:`All actions associated with a given node (or observation) in the tree <torchrl.data.Tree.selected_actions>`.
+
+        """
+        if self.rollout is None:
+            return None
+        return self.rollout[..., -1]["action"]
+
+    @property
+    def branching_action(self) -> torch.Tensor | TensorDictBase | None:
+        """Returns the action that branched out to this particular node.
+
+        Returns:
+            a tensor, tensordict or ``None`` if the node has no parent.
+
+        .. seealso:: This will be equal to :attr:`~torchrl.data.Tree.prev_action` whenever the rollout data contains a single step.
+
+        .. seealso:: :attr:`All actions associated with a given node (or observation) in the tree <torchrl.data.Tree.selected_actions>`.
+
+        """
+        if self.rollout is None:
+            return None
+        return self.rollout[..., 0]["action"]
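On a compact tree a node's rollout may span several steps, so the two properties differ: `branching_action` is the first action of the rollout segment (the one that branched off the parent node), while `prev_action` is the last one (the action that produced this node's observation); they coincide for single-step rollouts. A toy illustration of the two indexing conventions, using a hand-built stand-in for a TED rollout rather than torchrl code:

```python
import torch
from tensordict import TensorDict

# A fake 4-step rollout segment; the batch dimension plays the role of time.
rollout = TensorDict({"action": torch.arange(4).unsqueeze(-1)}, batch_size=[4])

branching_action = rollout[..., 0]["action"]  # first action of the segment
prev_action = rollout[..., -1]["action"]      # last action of the segment
print(branching_action.item(), prev_action.item())  # 0 3
```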
+    @property
+    def node_observation(self) -> torch.Tensor | TensorDictBase:
+        """Returns the observation associated with this particular node.
+
+        This is the observation (or bag of observations) that defines the node before a branching occurs.
+        If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the
+        observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``.
+
+        If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance
+        is returned instead.

-    # @property
-    # def children(self) -> Tree:
-    #     return self.subtree
+        For a more consistent representation, see :attr:`~.node_observations`.
+
+        """
+        # TODO: implement specs
+        return self.node_data["observation"]
+
+    @property
+    def node_observations(self) -> torch.Tensor | TensorDictBase:
+        """Returns the observations associated with this particular node in a TensorDict format.
+
+        This is the observation (or bag of observations) that defines the node before a branching occurs.
+        If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the
+        observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``.
+
+        Unlike :attr:`~.node_observation`, the result is always wrapped in a :class:`~tensordict.TensorDict`
+        instance, regardless of the number of observation keys associated with the tree specs.
+
+        """
+        # TODO: implement specs
+        return self.node_data.select("observation")

     @property
     def visits(self) -> int | torch.Tensor:
+        """Returns the number of visits associated with this particular node.
+
+        This is an alias for the :attr:`~.count` attribute.
+
+        """
         return self.count

+    @visits.setter
     def visits(self, count):
         self.count = count

-    def fully_expanded(self, *, action_spec: TensorSpec | None = None) -> bool:
-        ...
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == "subtree" and value is not None:
+            wr = weakref.ref(self._tensordict)
+            if value._parent is None:
+                value._parent = wr
+            elif isinstance(value._parent, list):
+                value._parent.append(wr)
+            else:
+                value._parent = [value._parent, wr]
+        return super().__setattr__(name, value)

     @property
     def parent(self) -> Tree | None:
+        """The parent of the node.
+
+        If the node has a parent and this object is still present in the Python workspace, it will be returned by this
+        property.
+
+        For re-branching trees, this property may return a stack of trees where every index of the stack corresponds to
+        a different parent.
+
+        .. note:: the ``parent`` attribute will match in content but not in identity: the tensorclass object is reconstructed
+          using the same tensors (i.e., tensors that point to the same memory locations).
+
+        Returns:
+            A ``Tree`` containing the parent data or ``None`` if the parent data is out of scope or the node is the root.
+        """
         parent = self._parent
         if parent is not None:
             # Check that all parents match
-            if isinstance(parent, list):
-                parent = [p() for p in parent]
-                for p in parent[1:]:
-                    if p is not parent[0]:
-                        raise ValueError(
-                            "All parents of a given node level must match."
-                        )
-                return parent[0]
-            return parent()
+            queue = [parent]
+
+            def maybe_flatten_list(maybe_nested_list):
+                if isinstance(maybe_nested_list, list):
+                    for p in maybe_nested_list:
+                        if isinstance(p, list):
+                            queue.append(p)
+                        else:
+                            yield p()
+                else:
+                    yield maybe_nested_list()
+
+            parent_result = None
+            while len(queue):
+                local_result = None
+                for r in maybe_flatten_list(queue.pop()):
+                    if local_result is None:
+                        local_result = r
+                    elif r is not None and r is not local_result:
+                        if isinstance(local_result, list):
+                            local_result.append(r)
+                        else:
+                            local_result = [local_result, r]
+                if local_result is None:
+                    continue
+                # replicate logic at macro level
+                if parent_result is None:
+                    parent_result = local_result
+                else:
+                    if isinstance(local_result, list):
+                        local_result = [
+                            r for r in local_result if r not in parent_result
+                        ]
+                    else:
+                        local_result = [local_result]
+                    if isinstance(parent_result, list):
+                        parent_result.extend(local_result)
+                    else:
+                        parent_result = [parent_result, *local_result]
+            if isinstance(parent_result, list):
+                return TensorDict.lazy_stack(
+                    [self._from_tensordict(r) for r in parent_result]
+                )
+            return self._from_tensordict(parent_result)
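The `_parent` field holds `weakref.ref` objects pointing at the parent's underlying tensordict rather than strong references, so a subtree never keeps its ancestors alive on its own; `parent` dereferences them and rebuilds a `Tree` that matches in content but not in identity. A minimal reminder of the weakref behavior this relies on (plain Python, independent of torchrl):

```python
import weakref


class Node:
    pass


parent = Node()
ref = weakref.ref(parent)  # a weak reference does not keep `parent` alive
assert ref() is parent     # dereferencing yields the object while it exists

del parent                 # drop the last strong reference
assert ref() is None       # the weakref now dereferences to None
```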

     @property
     def num_children(self) -> int:
@@ -168,13 +340,13 @@ def is_terminal(self) -> bool | torch.Tensor:
         """Returns True if the tree has no children nodes."""
         if self.rollout is not None:
             return self.rollout[..., -1]["next", "done"].squeeze(-1)
-        # Here we have two options: (1) subtree is None, in which case this is a node that
-        # has no child yet. Therefore, it can be explored further.
-        # (2) subtree is there, in which case it's not terminal.
-        return True
+        # If there is no rollout, there is no preceding data - either this is a root or it's a floating node.
+        # In either case, we assume that the node is not terminal.
+        return False

     def fully_expanded(self, env: EnvBase) -> bool:
-        cardinality = env.cardinality(self.node)
+        """Returns True if the number of children is equal to the environment cardinality."""
+        cardinality = env.cardinality(self.node_data)
         num_actions = self.num_children
         return cardinality == num_actions
@@ -843,7 +1015,6 @@ def _make_local_tree(
         root: TensorDictBase,
         index: torch.Tensor | None = None,
         compact: bool = True,
-        parent: Tree | None = None,
     ) -> Tuple[Tree, torch.Tensor | None, torch.Tensor | None]:
         root = root.select(*self.node_map.in_keys)
         node_meta = None
@@ -860,6 +1031,8 @@ def _make_local_tree(
         while index.numel() <= 1:
             index = index.squeeze()
             d = self.data_map.storage[index]
+
+            # Rebuild rollout step
             steps.append(merge_tensordicts(d, root, callback_exist=lambda *x: None))
             d = d["next"]
             if d in self.node_map:
@@ -871,6 +1044,7 @@ def _make_local_tree(
         else:
             # If the root is provided and not gathered from the storage, it could be that its
             # device doesn't match the data_map storage device.
+            root = steps[-1]["next"].select(*self.node_map.in_keys)
             device = getattr(self.data_map.storage, "device", None)
             if root.device != device:
                 if device is not None:
@@ -889,12 +1063,11 @@ def _make_local_tree(
                 rollout=rollout,
                 count=torch.zeros((), dtype=torch.int32),
                 wins=torch.zeros(()),
-                node=root,
+                node_data=root,
                 index=index,
                 hash=None,
                 # We do this to avoid raising an exception as rollout and subtree must be provided together
-                subtree=LazyStackedTensorDict(),
-                _parent=weakref.ref(parent) if parent is not None else None,
+                subtree=None,
             ),
             index,
             hash,
@@ -916,7 +1089,7 @@ def _make_tree_iter(
     ):
         q = deque()
         memo = {}
-        tree, indices, hash = self._make_local_tree(root, index=index)
+        tree, indices, hash = self._make_local_tree(root, index=index, compact=compact)
         tree.node_id = 0

         result = tree
@@ -924,7 +1097,6 @@ def _make_tree_iter(
         counter = 1
         if indices is not None:
             q.append((tree, indices, hash, depth))
-        del tree, indices

         while len(q):
             tree, indices, hash, depth = q.popleft()
@@ -936,15 +1108,29 @@ def _make_tree_iter(
                 subtree, subtree_indices, subtree_hash = memo.get(h, (None,) * 3)
                 if subtree is None:
                     subtree, subtree_indices, subtree_hash = self._make_local_tree(
-                        tree.node,
+                        tree.node_data,
                         index=i,
                         compact=compact,
-                        parent=tree,
                     )
                     subtree.node_id = counter
                     counter += 1
                     subtree.hash = h
                     memo[h] = (subtree, subtree_indices, subtree_hash)
+                else:
+                    # We just need to save the two (or more) rollouts
+                    subtree_bis, _, _ = self._make_local_tree(
+                        tree.node_data,
+                        index=i,
+                        compact=compact,
+                    )
+                    if subtree.rollout.ndim == subtree_bis.rollout.ndim:
+                        subtree.rollout = TensorDict.stack(
+                            [subtree.rollout, subtree_bis.rollout]
+                        )
+                    else:
+                        subtree.rollout = TensorDict.stack(
+                            [*subtree.rollout, subtree_bis.rollout]
+                        )

                 subtrees.append(subtree)
                 if extend and subtree_indices is not None:
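When a node's hash is already in `memo`, the node has been reached through more than one path (a re-branching tree). Instead of duplicating the node, the `else` branch above stacks the alternative rollout onto the existing one: on the first collision both rollouts have the same `ndim` and are stacked into a batch of two; on later collisions the existing stack is unbound and re-stacked with the newcomer. The shape logic in isolation, with illustrative stand-ins for TED rollouts:

```python
import torch
from tensordict import TensorDict


def merge_rollouts(existing: TensorDict, new: TensorDict) -> TensorDict:
    if existing.ndim == new.ndim:
        # First collision: two plain rollouts -> a stack of two.
        return TensorDict.stack([existing, new])
    # Later collisions: `existing` is already a stack -> unbind and re-stack.
    return TensorDict.stack([*existing, new])


def make_rollout() -> TensorDict:
    return TensorDict({"action": torch.zeros(5, 1)}, batch_size=[5])


stacked = merge_rollouts(make_rollout(), make_rollout())  # batch_size [2, 5]
stacked = merge_rollouts(stacked, make_rollout())         # batch_size [3, 5]
print(stacked.batch_size)
```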