Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] compatibility of consolidate with compile (quick version) #1061

Merged
merged 9 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 184 additions & 25 deletions benchmarks/common/h2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,40 @@
# LICENSE file in the root directory of this source tree.

import argparse
import time
from typing import Any

import pytest
import torch
from packaging import version

from tensordict import TensorDict
from tensordict import tensorclass, TensorDict
from tensordict.utils import logger as tensordict_logger

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.fixture
def td():
return TensorDict(
{
str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
for i in range(16)
},
batch_size=[16],
device="cpu",
)
@tensorclass
class NJT:
_values: torch.Tensor
_offsets: torch.Tensor
_lengths: torch.Tensor
njt_shape: Any = None

@classmethod
def from_njt(cls, njt_tensor):
return cls(
_values=njt_tensor._values,
_offsets=njt_tensor._offsets,
_lengths=njt_tensor._lengths,
njt_shape=njt_tensor.size(0),
).clone()


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch.compiler.reset()
yield


def _make_njt():
Expand All @@ -34,14 +48,29 @@ def _make_njt():
)


@pytest.fixture
def njt_td():
def _njt_td():
return TensorDict(
{str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
# {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
{str(i): _make_njt() for i in range(32)},
device="cpu",
)


@pytest.fixture
def njt_td():
return _njt_td()


@pytest.fixture
def td():
njtd = _njt_td()
for k0, v0 in njtd.items():
njtd[k0] = NJT.from_njt(v0)
# for k1, v1 in v0.items():
# njtd[k0, k1] = NJT.from_njt(v1)
return njtd


@pytest.fixture
def default_device():
if torch.cuda.is_available():
Expand All @@ -52,22 +81,152 @@ def default_device():
pytest.skip("CUDA/MPS is not available")


@pytest.mark.parametrize("consolidated", [False, True])
@pytest.mark.parametrize(
"compile_mode,num_threads",
[
[False, None],
# [False, 4],
# [False, 16],
["default", None],
["reduce-overhead", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5"
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestConsolidate:
def test_consolidate(
self, benchmark, td, compile_mode, num_threads, default_device
):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

# td = td.to(default_device)

def consolidate(td, num_threads):
return td.consolidate(num_threads=num_threads)

if compile_mode:
consolidate = torch.compile(
consolidate, mode=compile_mode, dynamic=False, fullgraph=True
)

t0 = time.time()
consolidate(td, num_threads=num_threads)
elapsed = time.time() - t0
tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")

for _ in range(3):
consolidate(td, num_threads=num_threads)

benchmark(consolidate, td, num_threads)

def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads):
tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")

def consolidate(td, num_threads):
return td.consolidate(num_threads=num_threads)

if compile_mode:
pytest.skip(
"Compiling NJTs consolidation currently triggers a RuntimeError."
)
# consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)

for _ in range(3):
consolidate(njt_td, num_threads=num_threads)

benchmark(consolidate, njt_td, num_threads)


@pytest.mark.parametrize(
"consolidated,compile_mode,num_threads",
[
[False, False, None],
[True, False, None],
["within", False, None],
# [True, False, 4],
# [True, False, 16],
[True, "default", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.2"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found")
class TestTo:
def test_to(self, benchmark, consolidated, td, default_device):
if consolidated:
td = td.consolidate()
benchmark(lambda: td.to(default_device))
def test_to(
self, benchmark, consolidated, td, default_device, compile_mode, num_threads
):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
td = td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

else:

def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
if consolidated:
njt_td = njt_td.consolidate()
benchmark(lambda: njt_td.to(default_device))
if compile_mode:
to = torch.compile(to, mode=compile_mode, dynamic=True)

for _ in range(3):
to(td, num_threads=num_threads)

benchmark(to, td, num_threads)

def test_to_njt(
self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
):
if compile_mode:
pytest.skip(
"Compiling NJTs consolidation currently triggers a RuntimeError."
)

tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
njt_td = njt_td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

else:

def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

if compile_mode:
to = torch.compile(to, mode=compile_mode, dynamic=True)

for _ in range(3):
to(njt_td, num_threads=num_threads)

benchmark(to, njt_td, num_threads)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
pytest.main(
[
__file__,
"--capture",
"no",
"--exitfirst",
"--benchmark-group-by",
"func",
"-vvv",
]
+ unknown
)
14 changes: 12 additions & 2 deletions benchmarks/compile/tensordict_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_default_device(DEVICE)

@pytest.fixture(scope="function", autouse=True)
def auto_device():
device = torch.get_default_device()
torch.set_default_device(DEVICE)
yield
torch.set_default_device(device)


compile = functools.partial(torch.compile, fullgraph=True)
compile_overhead = functools.partial(
Expand All @@ -32,7 +39,10 @@
@pytest.fixture(scope="function", autouse=True)
def reset_dynamo():
# Start a fresh compile for each parameter of the test case
torch._dynamo.reset()
try:
torch.compiler.reset()
except AttributeError:
torch._dynamo.reset()
gc.collect()
yield

Expand Down
12 changes: 9 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3860,8 +3860,9 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
pad = 8 - pad
else:
pad = 0
flat_size.append(n + pad)
stop = start + flat_size[-1]
flat_size.append(sum([n, pad]))
# Using sum to tell dynamo to use sym_sum
stop = sum([start, flat_size[-1]])
if requires_metadata:
metadata_dict["leaves"][key] = (
_DTYPE2STRDTYPE[dtype],
Expand Down Expand Up @@ -4136,6 +4137,8 @@ def view_old_as_new(v, oldv):
return v[: oldv.numel()].view(oldv.shape)
return v.view(oldv.shape)

if num_threads is None:
num_threads = 0
if num_threads > 0:

def assign(
Expand Down Expand Up @@ -4241,7 +4244,10 @@ def _view_and_pad(tensor):
if v.device != storage.device:
v = v.to(storage.device, non_blocking=non_blocking)
stride = v.stride()
if (stride and stride[-1] != 1) or v.storage_offset():
if is_dynamo_compiling():
if not v.is_contiguous():
v = v.clone(memory_format=torch.contiguous_format)
elif (stride and stride[-1] != 1) or v.storage_offset():
v = v.clone(memory_format=torch.contiguous_format)
v, pad = _view_and_pad(v)
items.append(v)
Expand Down
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __subclasscheck__(self, subclass):
"_multithread_rebuild", # rebuild checks if self is a non tensor
"_propagate_lock",
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"data_ptr",
"dim",
Expand Down
25 changes: 17 additions & 8 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6380,27 +6380,36 @@ def test_to_device_dtype_inplace(self, td_name, device):
td = getattr(self, td_name)(device)
if torch.cuda.is_available():
dest = torch.device("cuda:0")
elif torch.mps.is_available():
dest = torch.device("mps:0")
# elif torch.mps.is_available():
# dest = torch.device("mps:0")
else:
dest = torch.device("cpu")

if td_name in ("sub_td", "sub_td2"):
cm = pytest.raises(
cm_device = cm_dtype = pytest.raises(
TypeError,
match="Cannot send a _SubTensorDict instance to device/dtype inplace",
)
elif td_name in ("permute_td", "unsqueezed_td", "squeezed_td", "td_h5"):
cm = pytest.raises(TypeError, match="Cannot use inplace=True with")
cm_device = cm_dtype = pytest.raises(
TypeError, match="Cannot use inplace=True with"
)
elif td_name in ("memmap_td",) and dest.type == "cpu":
cm_device = contextlib.nullcontext()
cm_dtype = pytest.raises(
RuntimeError, match="Cannot modify locked TensorDict."
)
elif td.is_locked:
cm = pytest.raises(RuntimeError, match="Cannot modify locked TensorDict.")
cm_device = cm_dtype = pytest.raises(
RuntimeError, match="Cannot modify locked TensorDict."
)
else:
cm = contextlib.nullcontext()
with cm:
cm_device = cm_dtype = contextlib.nullcontext()
with cm_dtype:
td.to(torch.float32, inplace=True)
assert td.dtype == torch.float32, td

with cm:
with cm_device:
td.to(dest, inplace=True)
assert td.device == dest
for v in td.values(
Expand Down
Loading