Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 12, 2024
2 parents 21046ff + 19dfefc commit 17e8dfb
Show file tree
Hide file tree
Showing 7 changed files with 434 additions and 49 deletions.
18 changes: 14 additions & 4 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,23 @@ jobs:
with:
repository: pytorch/rl
upload-artifact: docs
runner: "linux.g5.4xlarge.nvidia.gpu"
docker-image: "nvidia/cudagl:11.4.0-base"
timeout: 120
script: |
set -e
set -v
apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
# apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
yum makecache
# yum install -y glfw glew mesa-libGL mesa-libGL-devel mesa-libOSMesa-devel egl-utils freeglut
# Install Mesa and OpenGL Libraries:
yum install -y glfw mesa-libGL mesa-libGL-devel egl-utils freeglut mesa-libGLU mesa-libEGL
# Install DRI Drivers:
yum install -y mesa-dri-drivers
# Install Xvfb for Headless Environments:
yum install -y xorg-x11-server-Xvfb
# xhost +local:docker
# Xvfb :1 -screen 0 1024x768x24 &
# export DISPLAY=:1
root_dir="$(pwd)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
Expand All @@ -51,7 +61,7 @@ jobs:
conda activate "${env_dir}"
# 2. upgrade pip, ninja and packaging
apt-get install python3-pip unzip -y -f
# apt-get install python3-pip unzip -y -f
python3 -m pip install --upgrade pip
python3 -m pip install setuptools ninja packaging cmake -U
Expand Down
89 changes: 89 additions & 0 deletions examples/replay-buffers/filter-imcomplete-trajs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Efficient Trajectory Sampling with CompletedTrajRepertoire
This example demonstrates how to design a custom transform that filters trajectories during sampling,
ensuring that only completed trajectories are present in sampled batches. This can be particularly useful
when dealing with environments where some trajectories might be corrupted or never reach a done state,
which could skew the learning process or lead to biased models. For instance, in robotics or autonomous
driving, a trajectory might be interrupted due to external factors such as hardware failures or human
intervention, resulting in incomplete or inconsistent data. By filtering out these incomplete trajectories,
we can improve the quality of the training data and increase the robustness of our models.
"""

import torch
from tensordict import TensorDictBase
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv, TrajCounter, Transform


class CompletedTrajectoryRepertoire(Transform):
    """
    A transform that keeps track of completed trajectories and filters out
    every sample that does not belong to one of them during sampling.

    On insertion (``_inv_call``) the ids found under ``("next", "traj_count")``
    at steps where ``("next", "terminated")`` is ``True`` are recorded.
    On sampling (``forward``) only the elements whose trajectory id has been
    recorded are returned, so a sampled batch may be smaller than requested.

    Note that only ``terminated`` is consulted: a trajectory that never sets
    this flag is never recorded as completed.
    """

    def __init__(self):
        super().__init__()
        # Ids of all trajectories seen with a terminated step so far.
        self.completed_trajectories = set()
        # Tensor mirror of the set above, used for the vectorized lookup in
        # ``forward``. It must start *empty* (shape ``(0,)``): a 0-dim tensor
        # holding ``0`` would make trajectory 0 look completed before any
        # terminated step has been recorded.
        self.repertoire_tensor = torch.zeros((0,), dtype=torch.int64)

    def _update_repertoire(self, tensordict: TensorDictBase) -> None:
        """Updates the repertoire of completed trajectories.

        Args:
            tensordict: data being written to the buffer; must contain the
                ``("next", "terminated")`` and ``("next", "traj_count")`` entries.
        """
        done = tensordict["next", "terminated"].squeeze(-1)
        traj = tensordict["next", "traj_count"][done].view(-1)
        if traj.numel():
            self.completed_trajectories.update(traj.tolist())
            # Rebuild the lookup tensor only when new ids may have been added.
            self.repertoire_tensor = torch.tensor(
                sorted(self.completed_trajectories), dtype=torch.int64
            )

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Updates the repertoire of completed trajectories during insertion.

        The data itself is returned unchanged.
        """
        self._update_repertoire(tensordict)
        return tensordict

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Filters out incomplete trajectories during sampling.

        Returns:
            The sub-tensordict made of the elements whose trajectory id is in
            the repertoire (possibly empty).
        """
        traj = tensordict["next", "traj_count"]
        # Extra trailing dim so that the comparison broadcasts every sampled id
        # against every recorded id.
        traj = traj.unsqueeze(-1)
        # The repertoire is built on CPU; align it with the sampled data.
        repertoire = self.repertoire_tensor.to(traj.device)
        has_traj = (traj == repertoire).any(-1)
        # assumes traj_count carries one trailing singleton dim on top of the
        # batch dims, so the mask can be viewed as the batch shape
        has_traj = has_traj.view(tensordict.shape)
        return tensordict[has_traj]


def main():
    """Fill a replay buffer with CartPole data and check that sampling never
    returns steps from the trailing, unfinished trajectory."""
    # CartPole environment whose steps are tagged with a trajectory counter
    env = GymEnv("CartPole-v1").append_transform(TrajCounter())

    # Replay buffer that filters samples through the completed-trajectory repertoire
    repertoire = CompletedTrajectoryRepertoire()
    buffer = ReplayBuffer(storage=LazyTensorStorage(1_000_000), transform=repertoire)

    # Collect 1000-step rollouts until one ends on a step that is not done,
    # i.e. until the last trajectory of the rollout is left unfinished
    rollout = env.rollout(1000, break_when_any_done=False)
    while rollout["next", "done"][-1].item():
        rollout = env.rollout(1000, break_when_any_done=False)

    # Write the rollout to the buffer
    buffer.extend(rollout)

    # Id of the trajectory that was cut short
    last_traj_count = rollout[-1]["next", "traj_count"].item()
    print(f"Incomplete trajectory: {last_traj_count}")

    # Draw 10 batches; none may contain the unfinished trajectory
    for _ in range(10):
        batch = buffer.sample(32)
        sample_traj_counts = batch["next", "traj_count"].unique()
        print(f"Sampled trajectories: {sample_traj_counts}")
        assert last_traj_count not in sample_traj_counts


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Dict, List, Optional

import torch
Expand Down
2 changes: 1 addition & 1 deletion test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ def make_env():
functools.partial(MultiSyncDataCollector, cat_results="stack"),
],
)
@pytest.mark.parametrize("init_random_frames", [50]) # 1226: faster execution
@pytest.mark.parametrize("init_random_frames", [0, 50]) # 1226: faster execution
@pytest.mark.parametrize(
"explicit_spec,split_trajs", [[True, True], [False, False]]
) # 1226: faster execution
Expand Down
104 changes: 101 additions & 3 deletions test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def _state0(self) -> TensorDict:
def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1)
reward = action.clone()
action = action + torch.arange(action.shape[-1]) / action.shape[-1]

return TensorDict(
{
Expand All @@ -326,7 +327,7 @@ def _make_forest(self) -> MCTSForest:
forest.extend(r4)
return forest

def _make_forest_intersect(self) -> MCTSForest:
def _make_forest_rebranching(self) -> MCTSForest:
"""
├── 0
│ ├── 16
Expand Down Expand Up @@ -449,7 +450,7 @@ def test_forest_check_ids(self):

def test_forest_intersect(self):
state0 = self._state0()
forest = self._make_forest_intersect()
forest = self._make_forest_rebranching()
tree = forest.get_tree(state0)
subtree = forest.get_tree(TensorDict(observation=19))

Expand All @@ -467,13 +468,110 @@ def test_forest_intersect(self):

def test_forest_intersect_vertices(self):
state0 = self._state0()
forest = self._make_forest_intersect()
forest = self._make_forest_rebranching()
tree = forest.get_tree(state0)
assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash"))
assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash"))
with pytest.raises(ValueError, match="key_type must be"):
tree.vertices(key_type="another key type")

@pytest.mark.skipif(not _has_gym, reason="requires gym")
def test_simple_tree(self):
    """A single 10-step Pendulum rollout must give a non-compact tree whose
    max length and every valid path length equal 9."""
    from torchrl.envs import GymEnv

    expected_length = 9

    rollout = GymEnv("Pendulum-v1").rollout(10)
    root_state = rollout[0]

    forest = MCTSForest()
    forest.extend(rollout)

    tree = forest.get_tree(root_state, compact=False)
    assert tree.max_length() == expected_length
    for path in tree.valid_paths():
        assert len(path) == expected_length

@pytest.mark.parametrize(
    "tree_type,compact",
    [
        ("simple", False),
        ("forest", False),
        # parents of rebranching trees are still buggy
        # ("rebranching", False),
        # ("rebranching", True),
    ],
)
def test_forest_parent(self, tree_type, compact):
    """Every node reached from the root must expose a ``parent`` that
    compares equal to the node it was reached from."""
    if tree_type == "simple":
        if not _has_gym:
            pytest.skip("requires gym")
        from torchrl.envs import GymEnv

        rollout = GymEnv("Pendulum-v1").rollout(10)
        state0 = rollout[0]
        forest = MCTSForest()
        forest.extend(rollout)
    else:
        state0 = self._state0()
        if tree_type == "forest":
            forest = self._make_forest()
        else:
            forest = self._make_forest_rebranching()
    tree = forest.get_tree(state0, compact=compact)

    # Check access: ``parent`` must be reachable three levels deep
    level = tree
    for _ in range(3):
        level = level.subtree
        _ = level.parent

    # Check presence of the weakref on the first child and grandchild
    first_child = tree.subtree[0]
    assert first_child._parent is not None
    assert first_child.subtree[0]._parent is not None

    # Check content: walk every valid path and compare each parent
    assert_close(tree.subtree.parent, tree)
    for path in tree.valid_paths():
        current = tree
        for branch in path:
            child = current.subtree[branch]
            assert_close(child.parent, current)
            current = child

def test_forest_action_attr(self):
    """The root has neither a branching nor a previous action; further down
    the two attributes must differ in at least one entry."""
    root_state = self._state0()
    tree = self._make_forest().get_tree(root_state)

    # Root node: no action led here
    assert tree.branching_action is None
    assert tree.prev_action is None

    # First level below the root
    children = tree.subtree
    assert (children.branching_action != children.prev_action).any()

    # Second level, under the first child
    grandchildren = tree.subtree[0].subtree
    assert (grandchildren.branching_action != grandchildren.prev_action).any()

@pytest.mark.parametrize("intersect", [False, True])
def test_forest_check_obs_match(self, intersect):
    """Along every valid path, a node's observation must equal the next
    observation of the last step of the rollout stored on that node."""
    state0 = self._state0()
    forest = self._make_forest_rebranching() if intersect else self._make_forest()
    tree = forest.get_tree(state0)

    for path in tree.valid_paths():
        node = tree
        for branch in path:
            node = node.subtree[branch]
            last_next_obs = node.rollout[..., -1]["next", "observation"]
            # Both accessors must agree with the rollout's final observation
            assert (node.node_data["observation"] == last_next_obs).all()
            assert (node.node_observation == last_next_obs).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
4 changes: 2 additions & 2 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,10 +712,10 @@ def __init__(
)
self.reset_at_each_iter = reset_at_each_iter
self.init_random_frames = (
int(init_random_frames) if init_random_frames is not None else 0
int(init_random_frames) if init_random_frames not in (None, -1) else 0
)
if (
init_random_frames is not None
init_random_frames not in (-1, None, 0)
and init_random_frames % frames_per_batch != 0
and RL_WARNINGS
):
Expand Down
Loading

0 comments on commit 17e8dfb

Please sign in to comment.