Skip to content

Commit

Permalink
[Feature] compatibility of consolidate with compile (quick version)
Browse files Browse the repository at this point in the history
ghstack-source-id: 9b4808042765e3e2a0db762fd84bcc11c3e2f1a2
Pull Request resolved: #1061
  • Loading branch information
vmoens committed Oct 31, 2024
1 parent 3963e51 commit 5948f25
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 29 deletions.
202 changes: 178 additions & 24 deletions benchmarks/common/h2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,40 @@
# LICENSE file in the root directory of this source tree.

import argparse
import time
from typing import Any

import pytest
import torch
from packaging import version

from tensordict import TensorDict
from tensordict import tensorclass, TensorDict
from tensordict.utils import logger as tensordict_logger

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.fixture
def td():
return TensorDict(
{
str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)}
for i in range(16)
},
batch_size=[16],
device="cpu",
)
@tensorclass
class NJT:
_values: torch.Tensor
_offsets: torch.Tensor
_lengths: torch.Tensor
njt_shape: Any = None

@classmethod
def from_njt(cls, njt_tensor):
return cls(
_values=njt_tensor._values,
_offsets=njt_tensor._offsets,
_lengths=njt_tensor._lengths,
njt_shape=njt_tensor.size(0),
).clone()


@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
torch.compiler.reset()
yield


def _make_njt():
Expand All @@ -34,14 +48,29 @@ def _make_njt():
)


@pytest.fixture
def njt_td():
def _njt_td():
return TensorDict(
{str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
# {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)},
{str(i): _make_njt() for i in range(32)},
device="cpu",
)


@pytest.fixture
def njt_td():
return _njt_td()


@pytest.fixture
def td():
njtd = _njt_td()
for k0, v0 in njtd.items():
njtd[k0] = NJT.from_njt(v0)
# for k1, v1 in v0.items():
# njtd[k0, k1] = NJT.from_njt(v1)
return njtd


@pytest.fixture
def default_device():
if torch.cuda.is_available():
Expand All @@ -52,22 +81,147 @@ def default_device():
pytest.skip("CUDA/MPS is not available")


@pytest.mark.parametrize("consolidated", [False, True])
@pytest.mark.parametrize(
"compile_mode,num_threads",
[
[False, None],
# [False, 4],
# [False, 16],
["default", None],
["reduce-overhead", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestConsolidate:
def test_consolidate(
self, benchmark, td, compile_mode, num_threads, default_device
):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

# td = td.to(default_device)

def consolidate(td, num_threads):
return td.consolidate(num_threads=num_threads)

if compile_mode:
consolidate = torch.compile(
consolidate, mode=compile_mode, dynamic=False, fullgraph=True
)

t0 = time.time()
consolidate(td, num_threads=num_threads)
elapsed = time.time() - t0
tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")

for _ in range(3):
consolidate(td, num_threads=num_threads)

benchmark(consolidate, td, num_threads)

def test_consolidate_njt(self, benchmark, njt_td, compile_mode, num_threads):
tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")

def consolidate(td, num_threads):
return td.consolidate(num_threads=num_threads)

if compile_mode:
pytest.skip(
"Compiling NJTs consolidation currently triggers a RuntimeError."
)
# consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)

for _ in range(3):
consolidate(njt_td, num_threads=num_threads)

benchmark(consolidate, njt_td, num_threads)


@pytest.mark.parametrize(
"consolidated,compile_mode,num_threads",
[
[False, False, None],
[True, False, None],
["within", False, None],
# [True, False, 4],
# [True, False, 16],
[True, "default", None],
],
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA device found")
class TestTo:
def test_to(self, benchmark, consolidated, td, default_device):
if consolidated:
td = td.consolidate()
benchmark(lambda: td.to(default_device))
def test_to(
self, benchmark, consolidated, td, default_device, compile_mode, num_threads
):
tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
td = td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

else:

def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
if consolidated:
njt_td = njt_td.consolidate()
benchmark(lambda: njt_td.to(default_device))
def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

if compile_mode:
to = torch.compile(to, mode=compile_mode, dynamic=True)

for _ in range(3):
to(td, num_threads=num_threads)

benchmark(to, td, num_threads)

def test_to_njt(
self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
):
tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb")
pin_mem = default_device.type == "cuda"
if consolidated is True:
njt_td = njt_td.consolidate(pin_memory=pin_mem)

if consolidated == "within":

def to(td, num_threads):
return td.consolidate(pin_memory=pin_mem).to(
default_device, num_threads=num_threads
)

else:

def to(td, num_threads):
return td.to(default_device, num_threads=num_threads)

if compile_mode:
to = torch.compile(to, mode=compile_mode, dynamic=True)

for _ in range(3):
to(njt_td, num_threads=num_threads)

benchmark(to, njt_td, num_threads)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
pytest.main(
[
__file__,
"--capture",
"no",
"--exitfirst",
"--benchmark-group-by",
"func",
"-vvv",
]
+ unknown
)
14 changes: 12 additions & 2 deletions benchmarks/compile/tensordict_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_default_device(DEVICE)

@pytest.fixture(scope="function", autouse=True)
def auto_device():
device = torch.get_default_device()
torch.set_default_device(DEVICE)
yield
torch.set_default_device(device)


compile = functools.partial(torch.compile, fullgraph=True)
compile_overhead = functools.partial(
Expand All @@ -32,7 +39,10 @@
@pytest.fixture(scope="function", autouse=True)
def reset_dynamo():
# Start a fresh compile for each parameter of the test case
torch._dynamo.reset()
try:
torch.compiler.reset()
except AttributeError:
torch._dynamo.reset()
gc.collect()
yield

Expand Down
12 changes: 9 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3860,8 +3860,9 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
pad = 8 - pad
else:
pad = 0
flat_size.append(n + pad)
stop = start + flat_size[-1]
flat_size.append(sum([n, pad]))
# Using sum to tell dynamo to use sym_sum
stop = sum([start, flat_size[-1]])
if requires_metadata:
metadata_dict["leaves"][key] = (
_DTYPE2STRDTYPE[dtype],
Expand Down Expand Up @@ -4136,6 +4137,8 @@ def view_old_as_new(v, oldv):
return v[: oldv.numel()].view(oldv.shape)
return v.view(oldv.shape)

if num_threads is None:
num_threads = 0
if num_threads > 0:

def assign(
Expand Down Expand Up @@ -4241,7 +4244,10 @@ def _view_and_pad(tensor):
if v.device != storage.device:
v = v.to(storage.device, non_blocking=non_blocking)
stride = v.stride()
if (stride and stride[-1] != 1) or v.storage_offset():
if is_dynamo_compiling():
if not v.is_contiguous():
v = v.clone(memory_format=torch.contiguous_format)
elif (stride and stride[-1] != 1) or v.storage_offset():
v = v.clone(memory_format=torch.contiguous_format)
v, pad = _view_and_pad(v)
items.append(v)
Expand Down
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __subclasscheck__(self, subclass):
"_multithread_rebuild", # rebuild checks if self is a non tensor
"_propagate_lock",
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"data_ptr",
"dim",
Expand Down

0 comments on commit 5948f25

Please sign in to comment.