From 923f26e633f6af7cb557355a8fb2e9070bef76b8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 08:04:52 +0100 Subject: [PATCH] [Performance] Make _to_consolidated compatible with compile ghstack-source-id: b924c0d94db3e1b59f48b9fa22b98f4cfe89d6b9 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1041 --- benchmarks/common/h2d_test.py | 35 ++++- benchmarks/compile/compile_td_test.py | 7 + tensordict/base.py | 192 +++++++++++++++++++++++--- 3 files changed, 213 insertions(+), 21 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 0e20aae75..9db87297d 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -14,6 +14,13 @@ TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) +@pytest.fixture(autouse=True, scope="function") +def empty_compiler_cache(): + torch._dynamo.reset_code_caches() + print("Emptying cache") + yield + + @pytest.fixture def td(): return TensorDict( @@ -52,20 +59,38 @@ def default_device(): pytest.skip("CUDA/MPS is not available") -@pytest.mark.parametrize("consolidated", [False, True]) +@pytest.mark.parametrize( + "consolidated,compiled", [[False, False], [True, False], [True, True]] +) @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) class TestTo: - def test_to(self, benchmark, consolidated, td, default_device): + def test_to(self, benchmark, consolidated, td, default_device, compiled): if consolidated: td = td.consolidate() - benchmark(lambda: td.to(default_device)) - def test_to_njt(self, benchmark, consolidated, njt_td, default_device): + def to(td): + return td.to(default_device) + + if compiled: + to = torch.compile(to) + for _ in range(3): + to(td) + benchmark(to, td) + + def test_to_njt(self, benchmark, consolidated, njt_td, default_device, compiled): if consolidated: njt_td = njt_td.consolidate() - benchmark(lambda: njt_td.to(default_device)) + + def to(td): + return td.to(default_device) + + if compiled: + to = torch.compile(to) + for _ in range(3): + to(njt_td) + benchmark(to, njt_td) if __name__ == "__main__": diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index 3a1ef0ee1..b87df1918 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -23,6 +23,13 @@ class MyTensorClass: f: torch.Tensor +@pytest.fixture(autouse=True, scope="function") +def empty_compiler_cache(): + torch._dynamo.reset_code_caches() + print("Emptying cache") + yield + + # Functions def add_one(td): return td + 1 diff --git a/tensordict/base.py b/tensordict/base.py index 48b306520..a98428caf 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): flat_size = [] start = 0 + sorting_index = 0 def add_single_value(value, key, metadata_dict, dtype, shape, flat_size): - nonlocal start + nonlocal start, sorting_index n = value.element_size() * value.numel() if need_padding: pad = n % 8 @@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size): start, stop, pad, + flat_size[-1], + sorting_index, ) + sorting_index = sorting_index + 1 start = stop def assign( @@ -10441,6 +10445,7 @@ def to(self, *args, **kwargs) -> T: pin_memory=non_blocking_pin, num_threads=num_threads, non_blocking=non_blocking, + compilable=is_dynamo_compiling(), ) if non_blocking is None: @@ -10498,14 +10503,49 @@ def to_pinmem(tensor, _to=to): 
self._sync_all() return result - def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking): + def _to_consolidated( + self, *, device, pin_memory, num_threads, non_blocking, compilable + ): if num_threads is None: # unspecified num_threads should mean 0 num_threads = 0 + storage = self._consolidated["storage"] - if pin_memory: - storage = storage.pin_memory() - storage_cast = storage.to(device, non_blocking=True) + + @torch.compiler.disable() + def to(storage): + if pin_memory: + storage = storage.pin_memory() + storage_cast = storage.to(device, non_blocking=True) + return storage_cast + + storage_cast = to(storage) + + if compilable: + result = self._to_consolidated_compile( + device=device, num_threads=num_threads, storage_cast=storage_cast + ) + else: + result = self._to_consolidated_eager( + device=device, num_threads=num_threads, storage_cast=storage_cast + ) + + if non_blocking in (False, None): + if device.type == "cuda" and non_blocking is False: + # sending to CUDA force sync + cuda_device = device + elif storage.device.type == "cuda": + # sending from cuda: need sync unless intentionally not asked for + cuda_device = storage.device.type + else: + cuda_device = None + if cuda_device is not None: + torch.cuda.current_stream(cuda_device).synchronize() + + return result + + def _to_consolidated_eager(self, *, device, num_threads, storage_cast): + untyped_storage = storage_cast.untyped_storage() def set_(x): @@ -10574,18 +10614,138 @@ def copy_dict(d): } result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) - if non_blocking in (False, None): - if device.type == "cuda" and non_blocking is False: - # sending to CUDA force sync - cuda_device = device - elif storage.device.type == "cuda": - # sending from cuda: need sync unless intentionally not asked for - cuda_device = storage.device.type - else: - cuda_device = None - if cuda_device is not None: - torch.cuda.current_stream(cuda_device).synchronize() + return result + + def _to_consolidated_compile(self, *, device, num_threads, storage_cast): + + def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()): + root = False + if lengths is None: + lengths = [] + pos = [] + keys = [] + root = True + for k, v in metadata["leaves"].items(): + lengths.append(v[-2]) + pos.append(v[-1]) + keys.append(prefix + (k,)) + for k, d in metadata.items(): + if "leaves" in d: + get_tensors_length( + d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,) + ) + if root: + # l = torch.empty(len(lengths), dtype=torch.long) + # l[torch.as_tensor(pos)] = torch.as_tensor(lengths) + out0 = [ + None, + ] * len(pos) + out1 = [ + None, + ] * len(pos) + for p, l, k in zip(pos, lengths, keys): + out0[p] = k + out1[p] = l + return out0, out1 + + def split_storage(consolidated): + keys, splits = get_tensors_length(consolidated["metadata"]) + return dict(zip(keys, consolidated["storage"].split(splits))) + + if num_threads is None: + # unspecified num_threads should mean 0 + num_threads = 0 + + _consolidated = {"storage": storage_cast} + if "metadata" in self._consolidated: + # faster than deepcopy + def copy_dict(d): + return { + k: v if not isinstance(v, dict) else copy_dict(v) + for k, v in d.items() + } + + _consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) + + slice_map = split_storage(_consolidated) + + def view_as(src, dest): + return src.view(dest.dtype)[: dest.numel()].view(dest.shape) + def set_(name, x): + if not isinstance(name, tuple): + name = (name,) + if x.is_nested: + 
from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                storage_offsets = slice_map[
+                    (
+                        *name[:-1],
+                        "<NJT_OFFSETS>" + name[-1],
+                    )
+                ]
+                kwargs["offsets"] = view_as(storage_offsets, offsets)
+                if lengths is not None:
+                    storage_lengths = slice_map[
+                        (
+                            *name[:-1],
+                            "<NJT_LENGTHS>" + name[-1],
+                        )
+                    ]
+                    kwargs["lengths"] = view_as(storage_lengths, lengths)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                storage_values = slice_map[
+                    (
+                        *name[:-1],
+                        "<NJT_VALUES>" + name[-1],
+                    )
+                ]
+                return NestedTensor(
+                    view_as(storage_values, values),
+                    **kwargs,
+                )
+            return view_as(slice_map[name], x)
+
+        result = self._fast_apply(
+            set_,
+            device=torch.device(device),
+            num_threads=num_threads,
+            named=True,
+            nested_keys=True,
+        )
+        result._consolidated = _consolidated
         return result
 
     def _sync_all(self):
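
Usage sketch (not part of the patch): the updated benchmarks exercise the new path by consolidating a TensorDict and sending it to a device from inside a torch.compile'd function, which makes to() take the _to_consolidated(..., compilable=True) branch. A minimal sketch, assuming a CUDA device is available and using an arbitrary made-up TensorDict layout:

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"a": torch.zeros(3, 4), "b": {"c": torch.ones(3, 4, 5)}},
        batch_size=[3],
    ).consolidate()  # pack all leaves into a single contiguous storage

    @torch.compile
    def send(td):
        # under dynamo, to() passes compilable=True down to _to_consolidated
        return td.to("cuda")

    for _ in range(3):  # warm-up runs, mirroring the benchmark fixtures
        send(td)
    td_cuda = send(td)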