
[Feature] Split-trajectories and represent as nested tensor #2043

Merged · 10 commits · Jun 28, 2024
Changes from 8 commits

68 changes: 68 additions & 0 deletions test/test_postprocs.py
@@ -9,6 +9,8 @@
import torch
from _utils_internal import get_default_devices
from tensordict import assert_allclose_td, TensorDict

from torchrl._utils import _ends_with
from torchrl.collectors.utils import split_trajectories
from torchrl.data.postprocs.postprocs import MultiStep

@@ -310,6 +312,72 @@ def test_splits(self, num_workers, traj_len, constr):
            == split_trajs.get(("collector", "traj_ids")).max() + 1
        )

@pytest.mark.parametrize("num_workers", range(3, 34, 3))
@pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
@pytest.mark.parametrize(
"constr",
[
functools.partial(split_trajectories, prefix="collector", as_nested=True),
functools.partial(split_trajectories, as_nested=True),
functools.partial(
split_trajectories,
trajectory_key=("collector", "traj_ids"),
as_nested=True,
),
],
)
def test_split_traj_nested(self, num_workers, traj_len, constr):
trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
assert trajs.shape[0] == num_workers
assert trajs.shape[1] == traj_len
split_trajs = constr(trajs)
assert split_trajs.shape[-1] == -1
assert split_trajs["next", "done"].is_nested

@pytest.mark.parametrize("num_workers", range(3, 34, 3))
@pytest.mark.parametrize("traj_len", [10, 17, 50, 97])
@pytest.mark.parametrize(
"constr0,constr1",
[
[
functools.partial(
split_trajectories, prefix="collector", as_nested=True
),
functools.partial(
split_trajectories, prefix="collector", as_nested=False
),
],
[
functools.partial(split_trajectories, as_nested=True),
functools.partial(split_trajectories, as_nested=False),
],
[
functools.partial(
split_trajectories,
trajectory_key=("collector", "traj_ids"),
as_nested=True,
),
functools.partial(
split_trajectories,
trajectory_key=("collector", "traj_ids"),
as_nested=False,
),
],
],
)
def test_split_traj_nested_equiv(self, num_workers, traj_len, constr0, constr1):
trajs = TestSplits.create_fake_trajs(num_workers, traj_len)
assert trajs.shape[0] == num_workers
assert trajs.shape[1] == traj_len
split_trajs1 = constr1(trajs)
mask_key = None
for key in split_trajs1.keys(True, True):
if _ends_with(key, "mask"):
mask_key = key
break
split_trajs0 = constr0(trajs).to_padded_tensor(mask_key=mask_key)
assert (split_trajs0 == split_trajs1).all()


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
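For intuition, the round trip exercised by test_split_traj_nested_equiv can be reproduced with public PyTorch APIs alone. A minimal standalone sketch, assuming nothing beyond torch.nested (the trajectory lengths and padding value are illustrative, not taken from the PR):

import torch

# Two ragged trajectories of lengths 2 and 3, stacked as a nested tensor.
t0 = torch.arange(2.0)
t1 = torch.arange(3.0)
nt = torch.nested.nested_tensor([t0, t1])
assert nt.is_nested

# Padding recovers the dense layout that as_nested=False would produce;
# a boolean mask records which padded entries hold valid data.
padded = torch.nested.to_padded_tensor(nt, padding=0.0)  # shape [2, 3]
mask = torch.tensor([[True, True, False], [True, True, True]])
assert (padded[mask] == torch.cat([t0, t1])).all()
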
42 changes: 41 additions & 1 deletion torchrl/collectors/utils.py
@@ -34,6 +34,7 @@ def split_trajectories(
    prefix=None,
    trajectory_key: NestedKey | None = None,
    done_key: NestedKey | None = None,
    as_nested: bool = False,
) -> TensorDictBase:
    """A utility function for trajectory separation.

@@ -44,6 +45,8 @@
    Args:
        rollout_tensordict (TensorDictBase): a rollout with adjacent trajectories
            along the last dimension.

    Keyword Args:
        prefix (NestedKey, optional): the prefix used to read and write meta-data,
            such as ``"traj_ids"`` (the optional integer id of each trajectory)
            and the ``"mask"`` entry indicating which data are valid and which
@@ -56,6 +59,13 @@
            to ``(prefix, "traj_ids")``.
        done_key (NestedKey, optional): the key pointing to the ``"done"`` signal,
            if the trajectory could not be directly recovered. Defaults to ``"done"``.
        as_nested (bool, optional): whether to return the results as nested
            tensors. Defaults to ``False``.

            .. note:: Using ``split_trajectories(tensordict, as_nested=True).to_padded_tensor(mask=mask_key)``
              should produce exactly the same result as ``as_nested=False``. Since this is an
              experimental feature that relies on nested tensors, whose API may change in the
              future, we made it an optional feature. The runtime should be faster with
              ``as_nested=True``.
    Returns:
        A new tensordict with a leading dimension corresponding to the trajectory.
@@ -171,7 +181,37 @@ def split_trajectories(
        rollout_tensordict = rollout_tensordict.unsqueeze(0)
        return rollout_tensordict

-    out_splits = rollout_tensordict.reshape(-1).split(splits, 0)
+    out_splits = rollout_tensordict.reshape(-1)

    if as_nested:
        if hasattr(torch, "_nested_compute_contiguous_strides_offsets"):

            def nest(x, splits=splits):
                # Convert splits into per-trajectory shapes: [length, *feature_dims].
                shape = torch.tensor([[int(split), *x.shape[1:]] for split in splits])
                # Build a nested view over the flat buffer without copying data.
                return torch._nested_view_from_buffer(
                    x.reshape(-1),
                    shape,
                    *torch._nested_compute_contiguous_strides_offsets(shape),
                )

            return out_splits._fast_apply(
                nest,
                batch_size=[len(splits), -1],
            )
        else:
            # Fallback for older PyTorch versions: materialize the splits and
            # stack them into a nested tensor (this copies the data).
            out_splits = out_splits.split(splits, 0)

            def nest(*x):
                return torch.nested.nested_tensor(list(x))

            return out_splits[0]._fast_apply(
                nest,
                *out_splits[1:],
                batch_size=[len(out_splits), *out_splits[0].batch_size[:-1], -1],
            )

    out_splits = out_splits.split(splits, 0)

    for out_split in out_splits:
        out_split.set(
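For reference, a minimal usage sketch of the new as_nested keyword. The rollout layout below is hypothetical, inferred from the tests above (a flat batch carrying ("collector", "traj_ids") and ("next", "done") entries); it is not taken verbatim from the PR:

import torch
from tensordict import TensorDict
from torchrl.collectors.utils import split_trajectories

# A toy rollout of 5 adjacent steps covering two trajectories (lengths 2 and 3).
rollout = TensorDict(
    {
        "obs": torch.arange(5.0).unsqueeze(-1),
        ("collector", "traj_ids"): torch.tensor([0, 0, 1, 1, 1]),
        ("next", "done"): torch.tensor([False, True, False, False, True]).unsqueeze(-1),
    },
    batch_size=[5],
)

# Dense layout: shape [num_trajs, max_len], padded, with a boolean mask entry.
padded = split_trajectories(rollout, trajectory_key=("collector", "traj_ids"))

# Ragged layout: the trailing dimension is ragged (-1) and leaves are nested.
nested = split_trajectories(
    rollout, trajectory_key=("collector", "traj_ids"), as_nested=True
)
assert nested["obs"].is_nested

Per the docstring note, calling .to_padded_tensor(mask_key=...) on the nested result should recover the dense variant exactly.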