From a7cb2bb46019fbac352057a6224e4a10d22a2f83 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 23 Oct 2024 19:24:25 -0700 Subject: [PATCH] [Performance] Make _to_consolidated compatible with compile ghstack-source-id: 55de7c9301c0d22b39e22b44dff553d4fac5adfe Pull Request resolved: https://github.com/pytorch/tensordict/pull/1041 --- benchmarks/common/h2d_test.py | 194 ++++++- benchmarks/compile/compile_td_test.py | 6 + tensordict/_reductions.py | 2 +- tensordict/_td.py | 7 +- tensordict/base.py | 766 ++++++++++++-------------- tensordict/tensorclass.py | 67 +-- tensordict/utils.py | 44 ++ test/test_tensordict.py | 10 +- 8 files changed, 611 insertions(+), 485 deletions(-) diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index b08298dc1..0b159741a 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -4,26 +4,40 @@ # LICENSE file in the root directory of this source tree. import argparse +import time +from typing import Any import pytest import torch from packaging import version -from tensordict import TensorDict +from tensordict import tensorclass, TensorDict +from tensordict.utils import logger as tensordict_logger TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) -@pytest.fixture -def td(): - return TensorDict( - { - str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)} - for i in range(16) - }, - batch_size=[16], - device="cpu", - ) +@tensorclass +class NJT: + _values: torch.Tensor + _offsets: torch.Tensor + _lengths: torch.Tensor + njt_shape: Any = None + + @classmethod + def from_njt(cls, njt_tensor): + return cls( + _values=njt_tensor._values, + _offsets=njt_tensor._offsets, + _lengths=njt_tensor._lengths, + njt_shape=njt_tensor.size(0), + ).clone() + + +@pytest.fixture(autouse=True, scope="function") +def empty_compiler_cache(): + torch.compiler.reset() + yield def _make_njt(): @@ -34,14 +48,29 @@ def _make_njt(): ) -@pytest.fixture -def njt_td(): +def _njt_td(): return TensorDict( - {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)}, + # {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)}, + {str(i): _make_njt() for i in range(128)}, device="cpu", ) +@pytest.fixture +def njt_td(): + return _njt_td() + + +@pytest.fixture +def td(): + njtd = _njt_td() + for k0, v0 in njtd.items(): + njtd[k0] = NJT.from_njt(v0) + # for k1, v1 in v0.items(): + # njtd[k0, k1] = NJT.from_njt(v1) + return njtd + + @pytest.fixture def default_device(): if torch.cuda.is_available(): @@ -52,22 +81,139 @@ def default_device(): pytest.skip("CUDA/MPS is not available") -@pytest.mark.parametrize("consolidated", [False, True]) +@pytest.mark.parametrize( + "compile_mode,num_threads", + [ + [False, None], + # [False, 4], + # [False, 16], + ["default", None], + ["reduce-overhead", None], + ], +) +@pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" +) +class TestConsolidate: + def test_consolidate(self, benchmark, td, compile_mode, num_threads): + tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") + + def consolidate(td, num_threads): + return td.consolidate(num_threads=num_threads) + + if compile_mode: + consolidate = torch.compile( + consolidate, mode=compile_mode, dynamic=True, fullgraph=True + ) + + t0 = time.time() + consolidate(td, num_threads=num_threads) + elapsed = time.time() - t0 + tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec") + + for _ in range(3): + consolidate(td, 
num_threads=num_threads) + + benchmark(consolidate, td, num_threads) + + def test_to_njt(self, benchmark, njt_td, compile_mode, num_threads): + tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") + + def consolidate(td, num_threads): + return td.consolidate(num_threads=num_threads) + + if compile_mode: + consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True) + + for _ in range(3): + consolidate(njt_td, num_threads=num_threads) + + benchmark(consolidate, njt_td, num_threads) + + +@pytest.mark.parametrize( + "consolidated,compile_mode,num_threads", + [ + [False, False, None], + [True, False, None], + ["within", False, None], + # [True, False, 4], + # [True, False, 16], + [True, "default", None], + ], +) @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5" ) class TestTo: - def test_to(self, benchmark, consolidated, td, default_device): - if consolidated: - td = td.consolidate() - benchmark(lambda: td.to(default_device)) + def test_to( + self, benchmark, consolidated, td, default_device, compile_mode, num_threads + ): + tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") + pin_mem = default_device.type == "cuda" + if consolidated is True: + td = td.consolidate(pin_memory=pin_mem) + + if consolidated == "within": + + def to(td, num_threads): + return td.consolidate(pin_memory=pin_mem).to( + default_device, num_threads=num_threads + ) + + else: - def test_to_njt(self, benchmark, consolidated, njt_td, default_device): - if consolidated: - njt_td = njt_td.consolidate() - benchmark(lambda: njt_td.to(default_device)) + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) + + if compile_mode: + to = torch.compile(to, mode=compile_mode, dynamic=True) + + for _ in range(3): + to(td, num_threads=num_threads) + + benchmark(to, td, num_threads) + + def test_to_njt( + self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads + ): + tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") + pin_mem = default_device.type == "cuda" + if consolidated is True: + njt_td = njt_td.consolidate(pin_memory=pin_mem) + + if consolidated == "within": + + def to(td, num_threads): + return td.consolidate(pin_memory=pin_mem).to( + default_device, num_threads=num_threads + ) + + else: + + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) + + if compile_mode: + to = torch.compile(to, mode=compile_mode, dynamic=True) + + for _ in range(3): + to(njt_td, num_threads=num_threads) + + benchmark(to, njt_td, num_threads) if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main( + [ + __file__, + "--capture", + "no", + "--exitfirst", + "--benchmark-group-by", + "func", + "-vvv", + ] + + unknown + ) diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index 3a1ef0ee1..c07859490 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -23,6 +23,12 @@ class MyTensorClass: f: torch.Tensor +@pytest.fixture(autouse=True, scope="function") +def empty_compiler_cache(): + torch._dynamo.reset_code_caches() + yield + + # Functions def add_one(td): return td + 1 diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index be8aa42f1..fd4fe8be8 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -138,7 +138,7 @@ def 
_make_td(cls, state): def _reduce_td(data: TensorDict): consolidated = getattr(data, "_consolidated", None) - if consolidated and consolidated["metadata"] is not None: + if isinstance(consolidated, dict): storage = consolidated["storage"] storge_metadata = consolidated["metadata"] return ( diff --git a/tensordict/_td.py b/tensordict/_td.py index 4387839b5..1252621d1 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -4210,7 +4210,7 @@ def _iter(): if self.leaves_only: for key in self._keys(): target_class = self.tensordict.entry_class(key) - if _is_tensor_collection(target_class): + if self.is_leaf(target_class): continue yield key else: @@ -4239,9 +4239,10 @@ def _iter_helper( # For lazy stacks value = value[0] cls = type(value) - is_leaf = self.is_leaf(cls) - if self.include_nested and not is_leaf: + is_tc = _is_tensor_collection(cls) + if self.include_nested and is_tc: yield from self._iter_helper(value, prefix=full_key) + is_leaf = self.is_leaf(cls) if not self.leaves_only or is_leaf: yield full_key diff --git a/tensordict/base.py b/tensordict/base.py index 358cae1b1..17028d3b1 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -71,6 +71,7 @@ _shape, _split_tensordict, _td_fields, + _to_escape_compile, _unravel_key_to_tuple, _zip_strict, cache, @@ -92,6 +93,8 @@ TensorDictFuture, unravel_key, unravel_key_list, + view_and_pad, + view_old_as_new, ) from torch import distributed as dist, multiprocessing as mp, nn, Tensor from torch.nn.parameter import UninitializedTensorMixin @@ -3573,192 +3576,47 @@ def saved_path(self): ) # Generic method to get a class metadata - def _reduce_get_metadata(self): + def _reduce_get_metadata(self) -> dict: return { "device": str(self.device) if self.device is not None else None, - "names": self.names, + "names": self._maybe_names(), "batch_size": list(self.batch_size), "is_locked": self._is_locked, } - # @cache # noqa: B019 - def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): + def _reduce_vals_and_metadata(self, *, metadata): """Returns a nested dictionary of metadata, a flat Dict[NestedKey, Tensor] containing tensor data and a list of tensor sizes.""" - if dtype is NO_DEFAULT: - dtype = self.dtype - need_padding = dtype is None - # If the dtype is not unique (self.dtype is None) then we need the metadata - # because we need a custom unpickler - requires_metadata = requires_metadata | need_padding - - if requires_metadata: - # metadata is nested - metadata_dict = { - "cls": type(self).__name__, - "non_tensors": {}, - "leaves": {}, - "cls_metadata": self._reduce_get_metadata(), - } - else: - metadata_dict = None - - # flat_key_values is flat - flat_key_values = {} - - flat_size = [] - start = 0 - - def add_single_value(value, key, metadata_dict, dtype, shape, flat_size): - nonlocal start - n = value.element_size() * value.numel() - if need_padding: - pad = n % 8 - if pad != 0: - pad = 8 - pad - else: - pad = 0 - flat_size.append(n + pad) - stop = start + flat_size[-1] - if requires_metadata: - metadata_dict["leaves"][key] = ( - _DTYPE2STRDTYPE[dtype], - list(shape), - # _DEVICE2STRDEVICE[device], - start, - stop, - pad, - ) - start = stop + if not metadata: + return None, list(self.items(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)) + + metadata_dict = { + "cls": type(self).__name__, + "non_tensors": {}, + "leaves": {}, + "nodes": {}, + "cls_metadata": self._reduce_get_metadata(), + } - def assign( - key, - value, - track_key=(), - metadata_dict=metadata_dict, - flat_size=flat_size, - ): - total_key = key if 
isinstance(key, tuple) else (key,) - total_key = track_key + total_key - cls = type(value) - if issubclass(cls, torch.Tensor): - pass - elif _is_non_tensor(cls): - if requires_metadata: - metadata_dict["non_tensors"][key] = ( - value.data, - list(value.batch_size), - ) - return - elif _is_tensor_collection(cls): - metadata_dict_key = None - if requires_metadata: - metadata_dict_key = metadata_dict[key] = { - "cls": cls.__name__, - "non_tensors": {}, - "leaves": {}, - "cls_metadata": value._reduce_get_metadata(), - } - local_assign = partial( - assign, - track_key=total_key, - metadata_dict=metadata_dict_key, - flat_size=flat_size, + for k, it in self.items(True, False, is_leaf=_NESTED_TENSORS_AS_LISTS): + if _is_non_tensor(type(it)): + metadata_dict["non_tensors"][k] = ( + it.data, + list(it.batch_size), ) - value._fast_apply( - local_assign, - named=True, - nested_keys=True, - call_on_nested=True, - is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, - ) - return - # Tensors: DTensor, nested and then regular - if hasattr(value, "full_tensor"): - raise NotImplementedError("DTensor is not supported yet") - if getattr(value, "is_nested", False): - if value.layout is torch.jagged: - # Get the values - values = value._values - shape = [v if isinstance(v, int) else -1 for v in values.shape] - # Get the offsets - offsets = value._offsets - # Get the lengths - lengths = value._lengths - - # Now we're saving the two tensors - # We will rely on the fact that the writing order is preserved in python dict - # (since python 3.7). Later, we will read the NJT then the NJT offset in that order - # to do the allocation. - flat_key_values[_prefix_last_key(total_key, "")] = value - flat_size.append(0) - flat_key_values[_prefix_last_key(total_key, "")] = ( - values - ) - add_single_value( - values, - _prefix_last_key(key, ""), - metadata_dict, - values.dtype, - shape, - flat_size, - ) - # Lengths - if lengths is not None: - flat_key_values[ - _prefix_last_key(total_key, "") - ] = lengths - add_single_value( - lengths, - _prefix_last_key(key, ""), - metadata_dict, - lengths.dtype, - lengths.shape, - flat_size, - ) - # Offsets - flat_key_values[_prefix_last_key(total_key, "")] = ( - offsets - ) - add_single_value( - offsets, - _prefix_last_key(key, ""), - metadata_dict, - offsets.dtype, - offsets.shape, - flat_size, - ) - - else: - raise NotImplementedError( - "NST is not supported, please use layout=torch.jagged when building the nested tensor." - ) - return - flat_key_values[total_key] = value - add_single_value( - value, - key, - metadata_dict, - value.dtype, - value.shape, - # value.device, - flat_size, - ) - - self._fast_apply( - assign, - named=True, - call_on_nested=True, - nested_keys=True, - is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, - filter_empty=True, - ) - return metadata_dict, flat_key_values, flat_size, need_padding + elif _is_tensor_collection(type(it)): + metadata_dict["nodes"][k] = { + "cls": type(it).__name__, + "cls_metadata": it._reduce_get_metadata(), + } + else: + metadata_dict["leaves"][k] = it + return metadata_dict, None def consolidate( self, filename: Path | str | None = None, *, - num_threads=0, + num_threads: int | None = None, device: torch.device | None = None, non_blocking: bool = False, inplace: bool = False, @@ -3767,6 +3625,7 @@ def consolidate( share_memory: bool = False, pin_memory: bool = False, metadata: bool = False, + set_on_tensor: bool = False, ) -> None: """Consolidates the tensordict content in a single storage for fast serialization. 
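
A minimal standalone sketch of the flatten-to-uint8 scheme that the new view_and_pad/view_old_as_new helpers (see tensordict/utils.py below) implement, for context on the consolidate() hunks that follow: every leaf is reinterpreted as bytes, padded to a multiple of 8, concatenated into one flat buffer, then sliced and viewed back with its original dtype and shape. The `tensors` dict here is a hypothetical example, not part of the patch.

import torch

# Hypothetical leaves; the real code iterates over TensorDict items.
tensors = {"a": torch.randn(3, 4), "b": torch.arange(3, dtype=torch.int16)}

chunks, lengths = [], []
for t in tensors.values():
    flat = t.contiguous().view(-1).view(torch.uint8)  # reinterpret as bytes
    pad = (-flat.numel()) % 8                         # keep every chunk 8-byte aligned
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    chunks.append(flat)
    lengths.append(flat.numel())

storage = torch.cat(chunks)                           # single flat uint8 buffer

# Split the buffer and view each slice back as the original dtype/shape.
recovered = {
    name: chunk.view(old.dtype)[: old.numel()].view(old.shape)
    for (name, old), chunk in zip(tensors.items(), storage.split(lengths))
}
assert all(torch.equal(tensors[k], recovered[k]) for k in tensors)

Padding each chunk to 8 bytes guarantees that every slice's storage offset is divisible by any element size up to 8 bytes (float64/int64), which is what makes the dtype re-view of the split legal.
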
@@ -3828,15 +3687,83 @@ def consolidate( if self.is_consolidated(): return self - ( - metadata_dict, - flat_dict, - flat_size, - need_padding, - ) = self._reduce_vals_and_metadata( - requires_metadata=filename is not None or metadata, dtype=None - ) - filesize = sum(flat_size) + metadata = metadata or filename + metadata_dict, items = self._reduce_vals_and_metadata(metadata=metadata) + + start = 0 + lengths = [] + swaps = [] + origs = [] + + def view_and_pad(key, tensor: torch.Tensor, lengths=lengths) -> torch.Tensor: + nonlocal start + if hasattr(tensor, "full_tensor"): + raise NotImplementedError("DTensor is not supported yet") + if getattr(tensor, "is_nested", False): + if tensor.layout is torch.jagged: + # Get the values + values = tensor._values + shape = [v if isinstance(v, int) else -1 for v in values.shape] + # Get the offsets + offsets = tensor._offsets + # Get the lengths + lengths = tensor._lengths + + # Now we're saving the two tensors + # We will rely on the fact that the writing order is preserved in python dict + # (since python 3.7). Later, we will read the NJT then the NJT offset in that order + # to do the allocation. + origs.append(tensor) + swaps.append(None) + + view_and_pad(_prefix_last_key(key, ""), values) + # Lengths + if lengths is not None: + view_and_pad( + _prefix_last_key(key, ""), lengths + ) + # Offsets + view_and_pad(_prefix_last_key(key, ""), offsets) + else: + raise NotImplementedError( + "Strided nested-tensors are not supported yet." + ) + if is_dynamo_compiling(): + # We should maybe clone by default but that seems a bit too harsh? + tensor = tensor.clone(memory_format=torch.contiguous_format) + else: + stride = tensor.stride() + if (stride and stride[-1] != 1) or tensor.storage_offset(): + tensor = tensor.clone(memory_format=torch.contiguous_format) + + origs.append(tensor) + swap = tensor.view(-1).view(torch.uint8) + # result must always have a multiple of 8 elements + pad = swap.numel() % 8 + if pad != 0: + swap = torch.cat([swap, swap.new_zeros(8 - pad)]) + n = swap.numel() + if metadata: + info = ( + _DTYPE2STRDTYPE[tensor.dtype], + list(tensor.shape), + start, + pad, + n, + ) + metadata_dict["leaves"][key] = info + start = start + n + lengths.append(n) + swaps.append(swap) + + if metadata: + for key, val in metadata_dict: + view_and_pad(key, val) + else: + for key, val in items: + view_and_pad(key, val) + + filesize = start device = torch.device(device) if device is not None else None if filename is None: storage = torch.empty( @@ -3883,195 +3810,64 @@ def consolidate( total_storage[-8:] = len_metadata total_storage[-8 - metadata_dict_json.numel() : -8] = metadata_dict_json storage = total_storage[:-suffix] - # assert len(storage.untyped_storage()) == filesize - - offsets = torch.tensor([0] + flat_size).cumsum(0).tolist() - - def view_old_as_new(v, oldv): - v = v.view(oldv.dtype) - if v.numel() > oldv.numel(): - return v[: oldv.numel()].view(oldv.shape) - return v.view(oldv.shape) + if num_threads is None: + num_threads = 0 if num_threads > 0: - - def assign( - *, - k, - v, - start, - stop, - njts, - storage=storage, - non_blocking=non_blocking, - ): - """Reads a slice of the storage and assigns the resulting tensor in flat_dict.""" - # v may need padding - if k[-1].startswith(""): - njts[k] = v - return - v_pad = v.view(-1).view(torch.uint8) - exp_length = stop - start - pad = exp_length - v_pad.numel() - if pad: - v_pad = torch.cat([v_pad, v_pad.new_zeros(pad)]) - storage[start:stop].copy_(v_pad, non_blocking=non_blocking) - - storage_slice = 
storage[start:stop] - shape, dtype = v.shape, v.dtype - new_v = storage_slice.view(dtype) - if pad: - new_v = new_v[: v.numel()] - new_v = new_v.view(shape) - flat_dict[k] = new_v - - njts = {} - if num_threads > 1: - executor = ThreadPoolExecutor(num_threads) - r = [] - for i, (k, v) in enumerate(flat_dict.items()): - r.append( - executor.submit( - assign, - k=k, - v=v, - start=offsets[i], - stop=offsets[i + 1], - njts=njts, - ) - ) - if not return_early: - wait(r) - else: - # TODO: We'd need to merge the second half of this function to make this a thing - raise NotImplementedError( - "return_early is not implemented yet for `consolidate`." - ) - else: - for i, (k, v) in enumerate(flat_dict.items()): - assign( - k=k, - v=v, - start=offsets[i], - stop=offsets[i + 1], - njts=njts, - ) - for njt_key, njt in njts.items(): - newkey = njt_key[:-1] + (njt_key[-1].replace("", ""),) - njt_key_values = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - njt_key_offset = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - njt_key_lengths = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - val = _rebuild_njt_from_njt( - njt, - values=flat_dict.pop(njt_key_values), - offsets=flat_dict.pop(njt_key_offset), - lengths=flat_dict.pop(njt_key_lengths, None), - ) - del flat_dict[njt_key] - flat_dict[newkey] = val - - if non_blocking and device.type != "cuda": - # sync if needed - self._sync_all() + raise NotImplementedError else: - - def _view_and_pad(tensor): - result = tensor.view(-1).view(torch.uint8) - # result must always have a multiple of 8 elements - pad = 0 - if need_padding: - pad = result.numel() % 8 - if pad != 0: - result = torch.cat([result, result.new_zeros(8 - pad)]) - return result, pad - - items = [] - for v in flat_dict.values(): - if v.is_nested: - continue - if v.device != storage.device: - v = v.to(storage.device, non_blocking=non_blocking) - stride = v.stride() - if (stride and stride[-1] != 1) or v.storage_offset(): - v = v.clone(memory_format=torch.contiguous_format) - v, pad = _view_and_pad(v) - items.append(v) if non_blocking and device.type != "cuda": # sync if needed - self._sync_all() - torch.cat(items, out=storage) - for v, (k, oldv) in _zip_strict( - storage.split(flat_size), list(flat_dict.items()) - ): - if not k[-1].startswith("<"): - flat_dict[k] = view_old_as_new(v, oldv) - elif k[-1].startswith(""): - # NJT/NT always comes before offsets/shapes - nt = oldv - assert not v.numel() - nt_lengths = None - del flat_dict[k] - elif k[-1].startswith(""): - nt_vaues = view_old_as_new(v, oldv) - del flat_dict[k] - elif k[-1].startswith(""): - nt_lengths = view_old_as_new(v, oldv) - del flat_dict[k] - elif k[-1].startswith(""): - newk = k[:-1] + (k[-1].replace("", ""),) - nt_offsets = view_old_as_new(v, oldv) - del flat_dict[k] - - val = _rebuild_njt_from_njt( - nt, values=nt_vaues, offsets=nt_offsets, lengths=nt_lengths - ) - - flat_dict[newk] = val - - # delete the nested value to make sure that if there was an - # ordering mismatch we wouldn't be looking at the value key of - # another nested tensor. 
- del nt, nt_vaues, nt_offsets, nt_lengths - else: - flat_dict[k] = view_old_as_new(v, oldv) + td._sync_all() + torch.cat(swaps, out=storage) + swaps = storage.split(lengths) + + result = [ + view_old_as_new( + v, + oldv, + # set_on_tensor=set_on_tensor) + ) + for (v, oldv) in zip(swaps, origs, strict=True) + ] - def assign_val(key, val): - if isinstance(key, str): - key = (key,) - return flat_dict.get(key, val) + if set_on_tensor: + return self - if filename is None: - device = self.device - elif not inplace: - device = torch.device("cpu") - elif self.device is not None and self.device != torch.device("cpu"): - self.clear_device_() - device = None + if filename is None: + device = self.device + elif not inplace: + device = torch.device("cpu") + elif self.device is not None and self.device != torch.device("cpu"): + self.clear_device_() + device = None + else: + device = None + if inplace: + out = self + elif device in (self.device, None): + out = self.copy() + else: + out = self._fast_apply(lambda x: x, device=device) + if metadata: + keys = metadata_dict["leaves"].keys() + else: + keys, _ = zip(*items) + for k, v in _zip_strict(keys, result): + if isinstance(k, str): + k = (k,) + out._set_tuple(k, v, validated=True, inplace=False) + if metadata: + out._consolidated = {"storage": storage} + out._consolidated["metadata"] = metadata_dict_or_values else: - device = None - result = self._fast_apply( - assign_val, - named=True, - nested_keys=True, - is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, - out=self if inplace else None, - device=device, - ) - result._consolidated = {"storage": storage, "metadata": metadata_dict} + out._consolidated = True + if filename is not None: if use_buffer: with open(filename, "w+b") as f: f.write(total_storage._handler.buffer) - # with open(Path(filename).with_suffix(".json"), "wb") as f: - # metadata_dict["size"] = filesize - # f.write(json.dumps(metadata_dict)) - return result + return out @classmethod def from_consolidated(cls, filename): @@ -4096,7 +3892,20 @@ def from_consolidated(cls, filename): def is_consolidated(self): """Checks if a TensorDict has a consolidated storage.""" - return hasattr(self, "_consolidated") + return getattr(self, "_consolidated", False) + + def consolidated_storage(self): + consolidated_data = getattr(self, "_consolidated", False) + if isinstance(consolidated_data, dict): + return consolidated_data["storage"] + elif consolidated_data: + for k, t in self.items(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS): + break + storage = t.untyped_storage() + return torch.empty((), dtype=torch.uint8, device=self.device).set_( + storage, storage_offset=0, stride=(1,), size=(len(storage),) + ) + return None def memmap_( self, @@ -5573,53 +5382,38 @@ def items( if is_leaf is None: is_leaf = _default_is_leaf - def _items(): - if include_nested and leaves_only: - # check the conditions once only - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - ) - else: - yield k, val - elif include_nested: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - yield k, val - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - ) - elif leaves_only: - for k in self.keys(): - val = self._get_str(k, 
NO_DEFAULT) - if is_leaf(type(val)): - yield k, val - else: - for k in self.keys(): - yield k, self._get_str(k, NO_DEFAULT) - if sort: yield from sorted( - _items(), + self.items(include_nested, leaves_only, is_leaf), key=lambda item: ( item[0] if isinstance(item[0], str) else ".".join(item[0]) ), ) + + if include_nested: + # check the conditions once only + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + cls = type(val) + if not leaves_only or is_leaf(cls): + yield k, val + if _is_tensor_collection(cls): + yield from ( + (_unravel_key_to_tuple((k, _key)), _val) + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) + ) + elif leaves_only: + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + if is_leaf(type(val)): + yield k, val else: - yield from _items() + for k in self.keys(): + yield k, self._get_str(k, NO_DEFAULT) def non_tensor_items(self, include_nested: bool = False): """Returns all non-tensor leaves, maybe recursively.""" @@ -10527,6 +10321,7 @@ def to(self, *args, **kwargs) -> T: pin_memory=non_blocking_pin, num_threads=num_threads, non_blocking=non_blocking, + compilable=is_dynamo_compiling(), ) if non_blocking is None: @@ -10584,14 +10379,62 @@ def to_pinmem(tensor, _to=to): self._sync_all() return result - def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking): + def _to_consolidated( + self, *, device, pin_memory, num_threads, non_blocking, compilable + ): if num_threads is None: # unspecified num_threads should mean 0 num_threads = 0 + storage = self._consolidated["storage"] - if pin_memory: - storage = storage.pin_memory() - storage_cast = storage.to(device, non_blocking=True) + + storage_cast = _to_escape_compile(storage, device=device, pin_memory=pin_memory) + _consolidated = { + "storage": storage_cast, + } + if "metadata" in self._consolidated: + # faster than deepcopy + def copy_dict(d): + return { + k: v if not isinstance(v, dict) else copy_dict(v) + for k, v in d.items() + } + + _consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) + + if compilable: + result = self._to_consolidated_compile( + device=device, + num_threads=num_threads, + storage_cast=storage_cast, + _consolidated=_consolidated, + ) + else: + result = self._to_consolidated_eager( + device=device, + num_threads=num_threads, + storage_cast=storage_cast, + _consolidated=_consolidated, + ) + + if non_blocking in (False, None): + if device.type == "cuda" and non_blocking is False: + # sending to CUDA force sync + cuda_device = device + elif storage.device.type == "cuda": + # sending from cuda: need sync unless intentionally not asked for + cuda_device = storage.device.type + else: + cuda_device = None + if cuda_device is not None: + torch.cuda.current_stream(cuda_device).synchronize() + + return result + + def _to_consolidated_eager( + self, *, device, num_threads, storage_cast, _consolidated + ): + untyped_storage = storage_cast.untyped_storage() def set_(x): @@ -10650,28 +10493,105 @@ def set_(x): result = self._fast_apply( set_, device=torch.device(device), num_threads=num_threads ) - result._consolidated = {"storage": storage_cast} - if "metadata" in self._consolidated: - # faster than deepcopy - def copy_dict(d): - return { - k: v if not isinstance(v, dict) else copy_dict(v) - for k, v in d.items() - } + result._consolidated = _consolidated + return result - result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) - if non_blocking in (False, None): - if 
device.type == "cuda" and non_blocking is False: - # sending to CUDA force sync - cuda_device = device - elif storage.device.type == "cuda": - # sending from cuda: need sync unless intentionally not asked for - cuda_device = storage.device.type - else: - cuda_device = None - if cuda_device is not None: - torch.cuda.current_stream(cuda_device).synchronize() + def _to_consolidated_compile( + self, *, device, num_threads, storage_cast, _consolidated + ): + + def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()): + root = False + if lengths is None: + lengths = [] + pos = [] + keys = [] + root = True + for k, v in metadata["leaves"].items(): + lengths.append(v[-2]) + pos.append(v[-1]) + keys.append(prefix + (k,)) + for k, d in metadata.items(): + if "leaves" in d: + get_tensors_length( + d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,) + ) + if root: + # l = torch.empty(len(lengths), dtype=torch.long) + # l[torch.as_tensor(pos)] = torch.as_tensor(lengths) + out0 = [ + None, + ] * len(pos) + out1 = [ + None, + ] * len(pos) + for p, l, k in zip(pos, lengths, keys): + out0[p] = k + out1[p] = l + return out0, out1 + + def split_storage(consolidated): + keys, splits = get_tensors_length(consolidated["metadata"]) + return dict(zip(keys, consolidated["storage"].split(splits))) + if num_threads is None: + # unspecified num_threads should mean 0 + num_threads = 0 + + slice_map = split_storage(_consolidated) + + def view_as(src, dest): + return src.view(dest.dtype)[: dest.numel()].view(dest.shape) + + def set_(name, x): + if not isinstance(name, tuple): + name = (name,) + if x.is_nested: + if x.layout != torch.jagged: + raise RuntimeError( + "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " + "Please raise an issue on GitHub." 
+ ) + from torch.nested import nested_tensor_from_jagged + + values = x._values + lengths = x._lengths + offsets = x._offsets + storage_offsets = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] + offsets = view_as(storage_offsets, offsets) + if lengths is not None: + storage_lengths = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] + lengths = view_as(storage_lengths, lengths) + storage_values = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] + return nested_tensor_from_jagged( + view_as(storage_values, values), offsets=offsets, lengths=lengths + ) + + return view_as(slice_map[name], x) + + result = self._fast_apply( + set_, + device=torch.device(device), + num_threads=num_threads, + named=True, + nested_keys=True, + ) + result._consolidated = _consolidated return result def _sync_all(self): diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 8906eefd4..124503274 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -134,6 +134,7 @@ def __subclasscheck__(self, subclass): "_multithread_rebuild", # rebuild checks if self is a non tensor "_propagate_lock", "_propagate_unlock", + "_reduce_get_metadata", "_values_list", "data_ptr", "dim", @@ -569,7 +570,7 @@ def __torch_function__( setattr(cls, method_name, getattr(TensorDict, method_name)) for method_name in _FALLBACK_METHOD_FROM_TD: if not hasattr(cls, method_name): - setattr(cls, method_name, _wrap_td_method(method_name)) + setattr(cls, method_name, _wrap_td_method(method_name, force_wrap=True)) for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) @@ -857,7 +858,7 @@ def get_parent_locals(cls, localns=localns): cls._type_hints = None -def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 +def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True): # noqa: D417 """Tensor class wrapper to instantiate a new tensor class object. Args: @@ -865,12 +866,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 non_tensordict (dict): Dictionary with non-tensor and nested tensor class objects """ - if not isinstance(tensordict, TensorDictBase): + if safe and not isinstance(tensordict, TensorDictBase): raise RuntimeError( f"Expected a TensorDictBase instance but got {type(tensordict)}" ) # Validating keys of tensordict - # tensordict = tensordict.copy() tensor_keys = tensordict.keys() # TODO: compile doesn't like set() over an arbitrary object if is_dynamo_compiling(): @@ -890,10 +890,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 exp_keys = set(cls.__expected_keys__) if non_tensordict is not None: nontensor_keys = set(non_tensordict.keys()) + total_keys = tensor_keys.union(nontensor_keys) else: nontensor_keys = set() non_tensordict = {} - total_keys = tensor_keys.union(nontensor_keys) + total_keys = tensor_keys for key in nontensor_keys: if key not in tensor_keys: continue @@ -917,11 +918,12 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 # empty tensordict and writing values to it. 
we can skip this because we already # have a tensordict to use as the underlying tensordict tc = cls.__new__(cls) - tc.__dict__["_tensordict"] = tensordict - tc.__dict__["_non_tensordict"] = non_tensordict + tc.__dict__.update( + {"_tensordict": tensordict, "_non_tensordict": non_tensordict} + ) # since we aren't calling the dataclass init method, we need to manually check # whether a __post_init__ method has been defined and invoke it if so - if hasattr(tc, "__post_init__"): + if hasattr(cls, "__post_init__"): tc.__post_init__() return tc else: @@ -1142,7 +1144,28 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 return wrapper -def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): +def _wrap_td_method( + funcname, *, copy_non_tensor=False, no_wrap=False, force_wrap=False +): + def deliver_result(self, result, kwargs): + if result is None: + return + if (force_wrap or isinstance(result, TensorDictBase)) and kwargs.get( + "out" + ) is not result: + if not is_dynamo_compiling(): + non_tensordict = super(type(self), self).__getattribute__( + "_non_tensordict" + ) + else: + non_tensordict = self._non_tensordict + non_tensordict = dict(non_tensordict) + if copy_non_tensor and non_tensordict: + # use tree_map to copy + non_tensordict = tree_map(lambda x: x, non_tensordict) + return self._from_tensordict(result, non_tensordict, safe=False) + return result + def wrapped_func(self, *args, **kwargs): if not is_dynamo_compiling(): td = super(type(self), self).__getattribute__("_tensordict") @@ -1154,34 +1177,12 @@ def wrapped_func(self, *args, **kwargs): if no_wrap: return result - def check_out(kwargs, result): - out = kwargs.get("out") - if out is result: - # No need to transform output - return True - return False - if result is td: return self - def deliver_result(result): - if isinstance(result, TensorDictBase) and not check_out(kwargs, result): - if not is_dynamo_compiling(): - non_tensordict = super(type(self), self).__getattribute__( - "_non_tensordict" - ) - else: - non_tensordict = self._non_tensordict - non_tensordict = dict(non_tensordict) - if copy_non_tensor: - # use tree_map to copy - non_tensordict = tree_map(lambda x: x, non_tensordict) - return self._from_tensordict(result, non_tensordict) - return result - if isinstance(result, tuple): - return tuple(deliver_result(r) for r in result) - return deliver_result(result) + return tuple(deliver_result(self, r, kwargs) for r in result) + return deliver_result(self, result, kwargs) return wrapped_func diff --git a/tensordict/utils.py b/tensordict/utils.py index 280b224a0..e0ec809ae 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2694,3 +2694,47 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths): values, **kwargs, ) + + +@torch.library.custom_op("tensordict::_to_escape_compile", mutates_args=()) +def _to_escape_compile( + storage: torch.Tensor, device: torch.device, pin_memory: bool +) -> torch.Tensor: + if pin_memory: + storage = storage.pin_memory() + storage_cast = storage.to(device, non_blocking=True) + return storage_cast + + +@_to_escape_compile.register_fake +def _(storage: torch.Tensor, device: torch.device, pin_memory: bool) -> torch.Tensor: + return torch.empty_like(storage, device=device) + + +def view_and_pad(tensor: torch.Tensor, need_padding: bool) -> torch.Tensor: + result = tensor.view(-1).view(torch.uint8) + # result must always have a multiple of 8 elements + if need_padding: + pad = result.numel() % 8 + if pad != 0: + result = torch.cat([result, result.new_zeros(8 - 
pad)]) + return result + + +def view_old_as_new( + v: torch.Tensor, oldv: torch.Tensor, set_on_tensor=False +) -> torch.Tensor: + if set_on_tensor: + oldv.set_( + v.untyped_storage(), + storage_offset=v.storage_offset(), + stride=v.stride(), + size=oldv.size(), + ) + return oldv + if oldv is None: + return v + v = v.view(oldv.dtype) + if v.numel() > oldv.numel(): + return v[: oldv.numel()].view(oldv.shape) + return v.view(oldv.shape) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 0f1f65b5d..1dbcbcea1 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -430,7 +430,15 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty ), td_c.to_dict() assert td_c["d"] == "a string!" - storage = td_c._consolidated["storage"] + storage = td_c.consolidated_storage() + print( + storage.untyped_storage().data_ptr(), + td_c["b", "c"].untyped_storage().data_ptr(), + ) + print( + storage.untyped_storage().data_ptr(), td_c["a"].untyped_storage().data_ptr() + ) + assert isinstance(storage, torch.Tensor) storage *= 0 if not nested: assert (td.to(td_c.device) != td_c).any(), td_c.to_dict()
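

A usage sketch for the `tensordict::_to_escape_compile` custom op added in tensordict/utils.py above: the eager kernel performs the real pin_memory()/to() transfer, while the registered fake kernel only describes the output, so torch.compile can trace the consolidated-storage copy as one opaque op instead of graph-breaking. This assumes the patch is applied, torch>=2.5, and an available CUDA device; the `move` wrapper below is a hypothetical example, not part of the patch.

import torch
from tensordict.utils import _to_escape_compile  # custom op defined in this patch

def move(storage: torch.Tensor) -> torch.Tensor:
    # Inside a compiled region this call is traced through its fake kernel,
    # so the blocking pin_memory()/to() work stays out of the graph.
    return _to_escape_compile(storage, torch.device("cuda"), pin_memory=True)

compiled_move = torch.compile(move, fullgraph=True, dynamic=True)
buf = torch.zeros(1024, dtype=torch.uint8)  # stand-in for a consolidated storage
out = compiled_move(buf)
assert out.device.type == "cuda" and out.dtype == torch.uint8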