[Feature] compatibility of consolidate with compile (quick version)
ghstack-source-id: 1bf3ca550dfe5499b58f878f72c4f1687b0f247e
Pull Request resolved: #1061
vmoens committed Nov 4, 2024
1 parent 752e6dc commit 3cf52a0
Showing 5 changed files with 223 additions and 38 deletions.
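The feature in a nutshell: `TensorDict.consolidate` can now run inside a `torch.compile`-d function. A minimal sketch of the usage this commit targets (the keys, shapes, and compile settings below are illustrative, not taken from this diff):

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {str(i): torch.randn(16, 16) for i in range(8)},
    batch_size=[16],
)

@torch.compile(fullgraph=True)
def consolidate(td):
    # Flatten all leaves into one contiguous storage so that a later
    # .to(device) needs a single host-to-device copy.
    return td.consolidate()

ctd = consolidate(td)
assert ctd.is_consolidated()
```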
209 changes: 184 additions & 25 deletions benchmarks/common/h2d_test.py
@@ -4,26 +4,40 @@
# LICENSE file in the root directory of this source tree.

import argparse
import time
from typing import Any

import pytest
import torch
from packaging import version

from tensordict import TensorDict
from tensordict import tensorclass, TensorDict
from tensordict.utils import logger as tensordict_logger

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.fixture
def td():
return TensorDict(
{
str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
for i in range(16)
},
batch_size=[16],
device="cpu",
)
@tensorclass
class NJT:
_values: torch.Tensor
_offsets: torch.Tensor
_lengths: torch.Tensor
njt_shape: Any = None

@classmethod
def from_njt(cls, njt_tensor):
return cls(
_values=njt_tensor._values,
_offsets=njt_tensor._offsets,
_lengths=njt_tensor._lengths,
njt_shape=njt_tensor.size(0),
).clone()


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch.compiler.reset()
yield


def _make_njt():
@@ -34,14 +34,29 @@ def _make_njt():
)


@pytest.fixture
def njt_td():
def _njt_td():
return TensorDict(
{str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
# {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
{str(i): _make_njt() for i in range(32)},
device="cpu",
)


@pytest.fixture
def njt_td():
return _njt_td()


@pytest.fixture
def td():
njtd = _njt_td()
for k0, v0 in njtd.items():
njtd[k0] = NJT.from_njt(v0)
# for k1, v1 in v0.items():
# njtd[k0, k1] = NJT.from_njt(v1)
return njtd


@pytest.fixture
def default_device():
if torch.cuda.is_available():
@@ -52,22 +52,152 @@ def default_device():
pytest.skip("CUDA/MPS is not available")


@pytest.mark.parametrize("consolidated", [False, True])
@pytest.mark.parametrize(
"compile_mode,num_threads",
[
[False, None],
# [False, 4],
# [False, 16],
["default", None],
["reduce-overhead", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5"
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestConsolidate:
def test_consolidate(
self, benchmark, td, compile_mode, num_threads, default_device
):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

# td = td.to(default_device)

def consolidate(td, num_threads):
return td.consolidate(num_threads=num_threads)

if compile_mode:
consolidate = torch.compile(
consolidate, mode=compile_mode, dynamic=False, fullgraph=True
)

t0 = time.time()
consolidate(td, num_threads=num_threads)
elapsed = time.time() - t0
tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")

for _ in range(3):
consolidate(td, num_threads=num_threads)

benchmark(consolidate, td, num_threads)

def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads):
tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")

def consolidate(td, num_threads):
return td.consolidate(num_threads=num_threads)

if compile_mode:
pytest.skip(
"Compiling NJTs consolidation currently triggers a RuntimeError."
)
# consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)

for _ in range(3):
consolidate(njt_td, num_threads=num_threads)

benchmark(consolidate, njt_td, num_threads)


@pytest.mark.parametrize(
"consolidated,compile_mode,num_threads",
[
[False, False, None],
[True, False, None],
["within", False, None],
# [True, False, 4],
# [True, False, 16],
[True, "default", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.2"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found")
class TestTo:
def test_to(self, benchmark, consolidated, td, default_device):
if consolidated:
td = td.consolidate()
benchmark(lambda: td.to(default_device))
def test_to(
self, benchmark, consolidated, td, default_device, compile_mode, num_threads
):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
td = td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

else:

def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
if consolidated:
njt_td = njt_td.consolidate()
benchmark(lambda: njt_td.to(default_device))
if compile_mode:
to = torch.compile(to, mode=compile_mode, dynamic=True)

for _ in range(3):
to(td, num_threads=num_threads)

benchmark(to, td, num_threads)

def test_to_njt(
self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
):
if compile_mode:
pytest.skip(
"Compiling NJTs consolidation currently triggers a RuntimeError."
)

tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
njt_td = njt_td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

else:

def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

if compile_mode:
to = torch.compile(to, mode=compile_mode, dynamic=True)

for _ in range(3):
to(njt_td, num_threads=num_threads)

benchmark(to, njt_td, num_threads)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
pytest.main(
[
__file__,
"--capture",
"no",
"--exitfirst",
"--benchmark-group-by",
"func",
"-vvv",
]
+ unknown
)
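
For context, the `TestTo` benchmarks above time three transfer strategies. A self-contained sketch of the three variants, assuming a CUDA device is available (the tensors here are placeholders):

```python
import torch
from tensordict import TensorDict

td = TensorDict({str(i): torch.randn(1024) for i in range(32)})
device = torch.device("cuda")

# consolidated=False: one host-to-device copy per leaf
moved = td.to(device)

# consolidated=True: consolidate ahead of time into pinned memory,
# then move the single shared storage in one copy
moved = td.consolidate(pin_memory=True).to(device)

# consolidated="within": consolidation happens inside the timed
# (and optionally compiled) function, as in the benchmark above
def to(td, num_threads):
    return td.consolidate(pin_memory=True).to(device, num_threads=num_threads)

moved = to(td, num_threads=None)
```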
14 changes: 12 additions & 2 deletions benchmarks/compile/tensordict_nn_test.py
@@ -21,7 +21,14 @@

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_default_device(DEVICE)

@pytest.fixture(scope="function", autouse=True)
def auto_device():
device = torch.get_default_device()
torch.set_default_device(DEVICE)
yield
torch.set_default_device(device)


compile = functools.partial(torch.compile, fullgraph=True)
compile_overhead = functools.partial(
@@ -32,7 +39,10 @@
@pytest.fixture(scope="function", autouse=True)
def reset_dynamo():
# Start a fresh compile for each parameter of the test case
torch._dynamo.reset()
try:
torch.compiler.reset()
except AttributeError:
torch._dynamo.reset()
gc.collect()
yield

12 changes: 9 additions & 3 deletions tensordict/base.py
@@ -3860,8 +3860,9 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
pad = 8 - pad
else:
pad = 0
flat_size.append(n + pad)
stop = start + flat_size[-1]
flat_size.append(sum([n, pad]))
# Using sum to tell dynamo to use sym_sum
stop = sum([start, flat_size[-1]])
if requires_metadata:
metadata_dict["leaves"][key] = (
_DTYPE2STRDTYPE[dtype],
@@ -4136,6 +4137,8 @@ def view_old_as_new(v, oldv):
return v[: oldv.numel()].view(oldv.shape)
return v.view(oldv.shape)

if num_threads is None:
num_threads = 0
if num_threads > 0:

def assign(
@@ -4241,7 +4244,10 @@ def _view_and_pad(tensor):
if v.device != storage.device:
v = v.to(storage.device, non_blocking=non_blocking)
stride = v.stride()
if (stride and stride[-1] != 1) or v.storage_offset():
if is_dynamo_compiling():
if not v.is_contiguous():
v = v.clone(memory_format=torch.contiguous_format)
elif (stride and stride[-1] != 1) or v.storage_offset():
v = v.clone(memory_format=torch.contiguous_format)
v, pad = _view_and_pad(v)
items.append(v)
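
Two of the hunks above are dynamo-friendliness tweaks: integer size arithmetic goes through `sum([...])` (the inline comment says this steers dynamo toward `sym_sum`), and the stride/storage-offset contiguity test is replaced by `Tensor.is_contiguous()` while compiling. A standalone illustration of the offset computation; the helper below is hypothetical, with the alignment behaviour inferred from the hunk, not part of tensordict's API:

```python
def running_offsets(sizes, align: int = 8):
    """Cumulative, 8-byte-aligned stop offsets, mirroring the hunk above."""
    offsets, start = [], 0
    for n in sizes:
        # Same padding rule as the diff: round n up to the next multiple of 8.
        pad = (-n) % align
        # sum([...]) instead of `+`, per the "tell dynamo to use sym_sum" comment.
        stop = sum([start, n, pad])
        offsets.append(stop)
        start = stop
    return offsets

print(running_offsets([3, 16, 5]))  # [8, 24, 32]
```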
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
@@ -137,6 +137,7 @@ def __subclasscheck__(self, subclass):
"_multithread_rebuild", # rebuild checks if self is a non tensor
"_propagate_lock",
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"data_ptr",
"dim",
25 changes: 17 additions & 8 deletions test/test_tensordict.py
@@ -6380,27 +6380,36 @@ def test_to_device_dtype_inplace(self, td_name, device):
td = getattr(self, td_name)(device)
if torch.cuda.is_available():
dest = torch.device("cuda:0")
elif torch.mps.is_available():
dest = torch.device("mps:0")
# elif torch.mps.is_available():
# dest = torch.device("mps:0")
else:
dest = torch.device("cpu")

if td_name in ("sub_td", "sub_td2"):
cm = pytest.raises(
cm_device = cm_dtype = pytest.raises(
TypeError,
match="Cannot send a _SubTensorDict instance to device/dtype inplace",
)
elif td_name in ("permute_td", "unsqueezed_td", "squeezed_td", "td_h5"):
cm = pytest.raises(TypeError, match="Cannot use inplace=True with")
cm_device = cm_dtype = pytest.raises(
TypeError, match="Cannot use inplace=True with"
)
elif td_name in ("memmap_td",) and dest.type == "cpu":
cm_device = contextlib.nullcontext()
cm_dtype = pytest.raises(
RuntimeError, match="Cannot modify locked TensorDict."
)
elif td.is_locked:
cm = pytest.raises(RuntimeError, match="Cannot modify locked TensorDict.")
cm_device = cm_dtype = pytest.raises(
RuntimeError, match="Cannot modify locked TensorDict."
)
else:
cm = contextlib.nullcontext()
with cm:
cm_device = cm_dtype = contextlib.nullcontext()
with cm_dtype:
td.to(torch.float32, inplace=True)
assert td.dtype == torch.float32, td

with cm:
with cm_device:
td.to(dest, inplace=True)
assert td.device == dest
for v in td.values(
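
The test rewrite above splits the single `cm` into `cm_device` and `cm_dtype` because a CPU memmap TensorDict accepts `to("cpu", inplace=True)` as a no-op while still rejecting the dtype cast on account of the lock. A condensed sketch of the pattern, with a hypothetical helper name and abbreviated branches:

```python
import contextlib
import pytest

def expectations(td_name, dest):
    # One expectation per operation: the same object can tolerate a
    # no-op device move while refusing an in-place dtype change.
    if td_name == "memmap_td" and dest.type == "cpu":
        cm_device = contextlib.nullcontext()
        cm_dtype = pytest.raises(RuntimeError, match="Cannot modify locked TensorDict.")
    else:
        cm_device = cm_dtype = contextlib.nullcontext()
    return cm_device, cm_dtype
```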

1 comment on commit 3cf52a0

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'GPU Benchmark Results': this commit's results are worse than the previous ones by more than the alert threshold of 2.

| Benchmark suite | Current: 3cf52a0 | Previous: 752e6dc | Ratio |
| --- | --- | --- | --- |
| benchmarks/common/common_ops_test.py::test_set_shared | 6754.679832337671 iter/sec (stddev: 0.0001455094557934899) | 20119.3762844694 iter/sec (stddev: 0.00002409671897002342) | 2.98 |
| benchmarks/tensorclass/test_torch_functions.py::test_full_like | 102.02800556542219 iter/sec (stddev: 0.00043603702511272273) | 1743.0510243651881 iter/sec (stddev: 0.0001254528707912698) | 17.08 |
| benchmarks/tensorclass/test_torch_functions.py::test_zeros_like | 140.00803010094938 iter/sec (stddev: 0.0018806135923055598) | 5047.532146099815 iter/sec (stddev: 0.000015791535862530366) | 36.05 |
| benchmarks/tensorclass/test_torch_functions.py::test_ones_like | 137.2047351135962 iter/sec (stddev: 0.0018634313948608457) | 5052.006547834159 iter/sec (stddev: 0.000004544463308845592) | 36.82 |
| benchmarks/tensorclass/test_torch_functions.py::test_clone | 153.9984954297178 iter/sec (stddev: 0.00012293193756781392) | 2409.575086904825 iter/sec (stddev: 0.000008300850743497264) | 15.65 |
| benchmarks/tensorclass/test_torch_functions.py::test_stack | 19.545933901933335 iter/sec (stddev: 0.0003603931567423208) | 1135.5566841624611 iter/sec (stddev: 0.00046878776739966196) | 58.10 |
| benchmarks/tensorclass/test_torch_functions.py::test_cat | 19.59550976180525 iter/sec (stddev: 0.0003737408907081909) | 812.18443768779 iter/sec (stddev: 0.00003448986539566331) | 41.45 |

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
