From e7062a1d68caccf5b8a9f8ad35aef366f98cd46f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 4 Dec 2024 11:34:04 +0000 Subject: [PATCH 1/4] [BugFix] Fix typing for python 3.9 ghstack-source-id: 663da84096214611804a726e2d38d27a6f21c958 Pull Request resolved: https://github.com/pytorch/rl/pull/2631 --- test/mocking_classes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 4e943e03cfc..b6f4ac7069b 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from typing import Dict, List, Optional import torch From 2511c04a533e191d8200f75a60951385438e8e1e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 4 Dec 2024 12:52:56 +0000 Subject: [PATCH 2/4] [CI] Change doc image ghstack-source-id: eceab242294ec55135d79f29e848345a5d5d455e Pull Request resolved: https://github.com/pytorch/rl/pull/2632 --- .github/workflows/docs.yml | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 10ea80c1dcd..77abee7d4fc 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -28,13 +28,23 @@ jobs: with: repository: pytorch/rl upload-artifact: docs - runner: "linux.g5.4xlarge.nvidia.gpu" - docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | set -e set -v - apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils + # apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils + yum makecache + # yum install -y glfw glew mesa-libGL mesa-libGL-devel mesa-libOSMesa-devel egl-utils freeglut + # Install Mesa and OpenGL Libraries: + yum install -y glfw mesa-libGL mesa-libGL-devel egl-utils freeglut mesa-libGLU mesa-libEGL + # Install DRI Drivers: + yum install -y mesa-dri-drivers + # 
Install Xvfb for Headless Environments: + yum install -y xorg-x11-server-Xvfb + # xhost +local:docker + # Xvfb :1 -screen 0 1024x768x24 & + # export DISPLAY=:1 + root_dir="$(pwd)" conda_dir="${root_dir}/conda" env_dir="${root_dir}/env" @@ -51,7 +61,7 @@ jobs: conda activate "${env_dir}" # 2. upgrade pip, ninja and packaging - apt-get install python3-pip unzip -y -f + # apt-get install python3-pip unzip -y -f python3 -m pip install --upgrade pip python3 -m pip install setuptools ninja packaging cmake -U From b840a772c4ed7446cbba3241f1065f18539c0149 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 10 Dec 2024 10:54:55 -0800 Subject: [PATCH 3/4] [Example] Efficient Trajectory Sampling with CompletedTrajRepertoire ghstack-source-id: 4d5c587c69230aa8f3a1b9b6fe19f52fa683d703 Pull Request resolved: https://github.com/pytorch/rl/pull/2642 --- .../replay-buffers/filter-imcomplete-trajs.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 examples/replay-buffers/filter-imcomplete-trajs.py diff --git a/examples/replay-buffers/filter-imcomplete-trajs.py b/examples/replay-buffers/filter-imcomplete-trajs.py new file mode 100644 index 00000000000..271c7c00831 --- /dev/null +++ b/examples/replay-buffers/filter-imcomplete-trajs.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Efficient Trajectory Sampling with CompletedTrajRepertoire + +This example demonstrates how to design a custom transform that filters trajectories during sampling, +ensuring that only completed trajectories are present in sampled batches. This can be particularly useful +when dealing with environments where some trajectories might be corrupted or never reach a done state, +which could skew the learning process or lead to biased models. 
For instance, in robotics or autonomous +driving, a trajectory might be interrupted due to external factors such as hardware failures or human +intervention, resulting in incomplete or inconsistent data. By filtering out these incomplete trajectories, +we can improve the quality of the training data and increase the robustness of our models. +""" + +import torch +from tensordict import TensorDictBase +from torchrl.data import LazyTensorStorage, ReplayBuffer +from torchrl.envs import GymEnv, TrajCounter, Transform + + +class CompletedTrajectoryRepertoire(Transform): + """ + A transform that keeps track of completed trajectories and filters out incomplete ones during sampling. + """ + + def __init__(self): + super().__init__() + self.completed_trajectories = set() + self.repertoire_tensor = torch.zeros((), dtype=torch.int64) + + def _update_repertoire(self, tensordict: TensorDictBase) -> None: + """Updates the repertoire of completed trajectories.""" + done = tensordict["next", "terminated"].squeeze(-1) + traj = tensordict["next", "traj_count"][done].view(-1) + if traj.numel(): + self.completed_trajectories = self.completed_trajectories.union( + traj.tolist() + ) + self.repertoire_tensor = torch.tensor( + list(self.completed_trajectories), dtype=torch.int64 + ) + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + """Updates the repertoire of completed trajectories during insertion.""" + self._update_repertoire(tensordict) + return tensordict + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Filters out incomplete trajectories during sampling.""" + traj = tensordict["next", "traj_count"] + traj = traj.unsqueeze(-1) + has_traj = (traj == self.repertoire_tensor).any(-1) + has_traj = has_traj.view(tensordict.shape) + return tensordict[has_traj] + + +def main(): + # Create a CartPole environment with trajectory counting + env = GymEnv("CartPole-v1").append_transform(TrajCounter()) + + # Create a replay buffer with the completed trajectory 
repertoire transform + buffer = ReplayBuffer( + storage=LazyTensorStorage(1_000_000), transform=CompletedTrajectoryRepertoire() + ) + + # Roll out the environment for 1000 steps + while True: + rollout = env.rollout(1000, break_when_any_done=False) + if not rollout["next", "done"][-1].item(): + break + + # Extend the replay buffer with the rollout + buffer.extend(rollout) + + # Get the last trajectory count + last_traj_count = rollout[-1]["next", "traj_count"].item() + print(f"Incomplete trajectory: {last_traj_count}") + + # Sample from the replay buffer 10 times + for _ in range(10): + sample_traj_counts = buffer.sample(32)["next", "traj_count"].unique() + print(f"Sampled trajectories: {sample_traj_counts}") + assert last_traj_count not in sample_traj_counts + + +if __name__ == "__main__": + main() From 19dfefc84ec9e8998b7ef6e97578fe186372d48f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 11 Dec 2024 09:15:52 -0800 Subject: [PATCH 4/4] [BugFix] Fix init_random_frames=0 ghstack-source-id: 38a544ea15631f9affb4c385c09e7c4df94af55d Pull Request resolved: https://github.com/pytorch/rl/pull/2645 --- test/test_collector.py | 2 +- torchrl/collectors/collectors.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 38191a46eaa..5c91cb83633 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1345,7 +1345,7 @@ def make_env(): functools.partial(MultiSyncDataCollector, cat_results="stack"), ], ) -@pytest.mark.parametrize("init_random_frames", [50]) # 1226: faster execution +@pytest.mark.parametrize("init_random_frames", [0, 50]) # 1226: faster execution @pytest.mark.parametrize( "explicit_spec,split_trajs", [[True, True], [False, False]] ) # 1226: faster execution diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 16eb5904b84..14fbc7d5f22 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -712,10 +712,10 @@ 
def __init__( ) self.reset_at_each_iter = reset_at_each_iter self.init_random_frames = ( - int(init_random_frames) if init_random_frames is not None else 0 + int(init_random_frames) if init_random_frames not in (None, -1) else 0 ) if ( - init_random_frames is not None + init_random_frames not in (-1, None, 0) and init_random_frames % frames_per_batch != 0 and RL_WARNINGS ):