diff --git a/benchmarks/common/h2d_test.py b/benchmarks/common/h2d_test.py index 0e20aae75..3751f2d24 100644 --- a/benchmarks/common/h2d_test.py +++ b/benchmarks/common/h2d_test.py @@ -4,26 +4,39 @@ # LICENSE file in the root directory of this source tree. import argparse +from typing import Any import pytest import torch from packaging import version -from tensordict import TensorDict +from tensordict import tensorclass, TensorDict +from tensordict.utils import logger as tensordict_logger TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) -@pytest.fixture -def td(): - return TensorDict( - { - str(i): {str(j): torch.randn(16, 16, device="cpu") for j in range(16)} - for i in range(16) - }, - batch_size=[16], - device="cpu", - ) +@tensorclass +class NJT: + _values: torch.Tensor + _offsets: torch.Tensor + _lengths: torch.Tensor + njt_shape: Any = None + + @classmethod + def from_njt(cls, njt_tensor): + return NJT( + _values=njt_tensor._values, + _offsets=njt_tensor._offsets, + _lengths=njt_tensor._lengths, + njt_shape=njt_tensor.size(0), + ) + + +@pytest.fixture(autouse=True, scope="function") +def empty_compiler_cache(): + torch._dynamo.reset_code_caches() + yield def _make_njt(): @@ -34,14 +47,27 @@ def _make_njt(): ) -@pytest.fixture -def njt_td(): +def _njt_td(): return TensorDict( {str(i): {str(j): _make_njt() for j in range(32)} for i in range(32)}, device="cpu", ) +@pytest.fixture +def njt_td(): + return _njt_td() + + +@pytest.fixture +def td(): + njtd = _njt_td() + for k0, v0 in njtd.items(): + for k1, v1 in v0.items(): + njtd[k0, k1] = NJT.from_njt(v1) + return njtd + + @pytest.fixture def default_device(): if torch.cuda.is_available(): @@ -52,22 +78,81 @@ def default_device(): pytest.skip("CUDA/MPS is not available") -@pytest.mark.parametrize("consolidated", [False, True]) +@pytest.mark.parametrize( + "consolidated,compile_mode,num_threads", + [ + [False, False, None], + [True, False, None], + ["within", False, None], + # [True, False, 4], + # [True, False, 16], + # [True, "default", None], + ], +) @pytest.mark.skipif( TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) class TestTo: - def test_to(self, benchmark, consolidated, td, default_device): - if consolidated: - td = td.consolidate() - benchmark(lambda: td.to(default_device)) + def test_to( + self, benchmark, consolidated, td, default_device, compile_mode, num_threads + ): + tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb") + pin_mem = default_device.type == "cuda" + if consolidated is True: + td = td.consolidate(pin_memory=pin_mem) + + if consolidated == "within": + + def to(td, num_threads): + return td.consolidate(pin_memory=pin_mem).to( + default_device, num_threads=num_threads + ) + + else: + + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) + + if compile_mode: + to = torch.compile(to, mode=compile_mode) + + for _ in range(3): + to(td, num_threads=num_threads) + + benchmark(to, td, num_threads) - def test_to_njt(self, benchmark, consolidated, njt_td, default_device): - if consolidated: - njt_td = njt_td.consolidate() - benchmark(lambda: njt_td.to(default_device)) + def test_to_njt( + self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads + ): + tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024 :.2f} Mb") + pin_mem = default_device.type == "cuda" + if consolidated is True: + njt_td = njt_td.consolidate(pin_memory=pin_mem) + + if consolidated == "within": + + def to(td, num_threads): + return td.consolidate(pin_memory=pin_mem).to( + default_device, num_threads=num_threads + ) + + else: + + def to(td, num_threads): + return td.to(default_device, num_threads=num_threads) + + if compile_mode: + to = torch.compile(to, mode=compile_mode) + + for _ in range(3): + to(njt_td, num_threads=num_threads) + + benchmark(to, njt_td, num_threads) if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main( + [__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + + unknown + ) diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index 3a1ef0ee1..c07859490 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -23,6 +23,12 @@ class MyTensorClass: f: torch.Tensor +@pytest.fixture(autouse=True, scope="function") +def empty_compiler_cache(): + torch._dynamo.reset_code_caches() + yield + + # Functions def add_one(td): return td + 1 diff --git a/tensordict/base.py b/tensordict/base.py index 358cae1b1..554697cd9 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -71,6 +71,7 @@ _shape, _split_tensordict, _td_fields, + _to_escape_compile, _unravel_key_to_tuple, _zip_strict, cache, @@ -92,6 +93,9 @@ TensorDictFuture, unravel_key, unravel_key_list, + view_and_pad, + view_cat_split, + view_old_as_new, ) from torch import distributed as dist, multiprocessing as mp, nn, Tensor from torch.nn.parameter import UninitializedTensorMixin @@ -3573,15 +3577,14 @@ def saved_path(self): ) # Generic method to get a class metadata - def _reduce_get_metadata(self): + def _reduce_get_metadata(self) -> dict: return { "device": str(self.device) if self.device is not None else None, - "names": self.names, + "names": self._maybe_names(), "batch_size": list(self.batch_size), "is_locked": self._is_locked, } - # @cache # noqa: B019 def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): """Returns a nested dictionary of metadata, a flat Dict[NestedKey, Tensor] containing tensor data and a list of tensor sizes.""" if dtype is NO_DEFAULT: @@ -3607,9 +3610,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata): flat_size = [] start = 0 + sorting_index = 0 def add_single_value(value, key, metadata_dict, dtype, shape, flat_size): - nonlocal start + nonlocal start, sorting_index n = value.element_size() * value.numel() if need_padding: pad = n % 8 @@ -3627,7 +3631,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size): start, stop, pad, + flat_size[-1], + sorting_index, ) + sorting_index = sorting_index + 1 start = stop def assign( @@ -3640,8 +3647,9 @@ def assign( total_key = key if isinstance(key, tuple) else (key,) total_key = track_key + total_key cls = type(value) - if issubclass(cls, torch.Tensor): + if cls is Tensor or issubclass(cls, Tensor): pass + # must go before is_tensor_collection elif _is_non_tensor(cls): if requires_metadata: metadata_dict["non_tensors"][key] = ( @@ -3658,8 +3666,9 @@ def assign( "leaves": {}, "cls_metadata": value._reduce_get_metadata(), } - local_assign = partial( - assign, + local_assign = lambda key, value: assign( + key, + value, track_key=total_key, metadata_dict=metadata_dict_key, flat_size=flat_size, @@ -3670,6 +3679,7 @@ def assign( nested_keys=True, call_on_nested=True, is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, + filter_empty=True, ) return # Tensors: DTensor, nested and then regular @@ -3767,6 +3777,7 @@ def consolidate( share_memory: bool = False, pin_memory: bool = False, metadata: bool = False, + set_on_tensor: bool = False, ) -> None: """Consolidates the tensordict content in a single storage for fast serialization. @@ -3887,12 +3898,6 @@ def consolidate( offsets = torch.tensor([0] + flat_size).cumsum(0).tolist() - def view_old_as_new(v, oldv): - v = v.view(oldv.dtype) - if v.numel() > oldv.numel(): - return v[: oldv.numel()].view(oldv.shape) - return v.view(oldv.shape) - if num_threads > 0: def assign( @@ -3915,14 +3920,22 @@ def assign( pad = exp_length - v_pad.numel() if pad: v_pad = torch.cat([v_pad, v_pad.new_zeros(pad)]) - storage[start:stop].copy_(v_pad, non_blocking=non_blocking) - storage_slice = storage[start:stop] + storage_slice.copy_(v_pad, non_blocking=non_blocking) + shape, dtype = v.shape, v.dtype new_v = storage_slice.view(dtype) if pad: new_v = new_v[: v.numel()] new_v = new_v.view(shape) + if set_on_tensor: + v.set_( + new_v.untyped_storage(), + storage_offset=new_v.storage_offset(), + stride=new_v.stride(), + size=new_v.size(), + ) + return flat_dict[k] = new_v njts = {} @@ -3956,76 +3969,80 @@ def assign( stop=offsets[i + 1], njts=njts, ) - for njt_key, njt in njts.items(): - newkey = njt_key[:-1] + (njt_key[-1].replace("", ""),) - njt_key_values = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - njt_key_offset = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - njt_key_lengths = njt_key[:-1] + ( - njt_key[-1].replace("", ""), - ) - val = _rebuild_njt_from_njt( - njt, - values=flat_dict.pop(njt_key_values), - offsets=flat_dict.pop(njt_key_offset), - lengths=flat_dict.pop(njt_key_lengths, None), - ) - del flat_dict[njt_key] - flat_dict[newkey] = val + if not set_on_tensor: + for njt_key, njt in njts.items(): + newkey = njt_key[:-1] + (njt_key[-1].replace("", ""),) + njt_key_values = njt_key[:-1] + ( + njt_key[-1].replace("", ""), + ) + njt_key_offset = njt_key[:-1] + ( + njt_key[-1].replace("", ""), + ) + njt_key_lengths = njt_key[:-1] + ( + njt_key[-1].replace("", ""), + ) + val = _rebuild_njt_from_njt( + njt, + values=flat_dict.pop(njt_key_values), + offsets=flat_dict.pop(njt_key_offset), + lengths=flat_dict.pop(njt_key_lengths, None), + ) + del flat_dict[njt_key] + flat_dict[newkey] = val if non_blocking and device.type != "cuda": # sync if needed self._sync_all() + if set_on_tensor: + return self else: - def _view_and_pad(tensor): - result = tensor.view(-1).view(torch.uint8) - # result must always have a multiple of 8 elements - pad = 0 - if need_padding: - pad = result.numel() % 8 - if pad != 0: - result = torch.cat([result, result.new_zeros(8 - pad)]) - return result, pad - items = [] for v in flat_dict.values(): if v.is_nested: + items.append(None) continue if v.device != storage.device: v = v.to(storage.device, non_blocking=non_blocking) - stride = v.stride() - if (stride and stride[-1] != 1) or v.storage_offset(): + if is_dynamo_compiling(): v = v.clone(memory_format=torch.contiguous_format) - v, pad = _view_and_pad(v) + else: + stride = v.stride() + if (stride and stride[-1] != 1) or v.storage_offset(): + v = v.clone(memory_format=torch.contiguous_format) items.append(v) - if non_blocking and device.type != "cuda": - # sync if needed - self._sync_all() - torch.cat(items, out=storage) - for v, (k, oldv) in _zip_strict( - storage.split(flat_size), list(flat_dict.items()) - ): + + items = view_cat_split( + self, + items, + storage, + need_padding, + non_blocking, + device, + flat_size, + set_on_tensor, + ) + if set_on_tensor: + return self + + for k, v in _zip_strict(list(flat_dict.keys()), items): if not k[-1].startswith("<"): - flat_dict[k] = view_old_as_new(v, oldv) + flat_dict[k] = v elif k[-1].startswith(""): # NJT/NT always comes before offsets/shapes - nt = oldv + nt = flat_dict[k] assert not v.numel() nt_lengths = None del flat_dict[k] elif k[-1].startswith(""): - nt_vaues = view_old_as_new(v, oldv) + nt_vaues = v del flat_dict[k] elif k[-1].startswith(""): - nt_lengths = view_old_as_new(v, oldv) + nt_lengths = v del flat_dict[k] elif k[-1].startswith(""): newk = k[:-1] + (k[-1].replace("", ""),) - nt_offsets = view_old_as_new(v, oldv) + nt_offsets = v del flat_dict[k] val = _rebuild_njt_from_njt( @@ -4039,7 +4056,7 @@ def _view_and_pad(tensor): # another nested tensor. del nt, nt_vaues, nt_offsets, nt_lengths else: - flat_dict[k] = view_old_as_new(v, oldv) + flat_dict[k] = v def assign_val(key, val): if isinstance(key, str): @@ -4055,12 +4072,16 @@ def assign_val(key, val): device = None else: device = None + if inplace: + result = self + else: + result = None result = self._fast_apply( assign_val, named=True, nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR, - out=self if inplace else None, + out=result, device=device, ) result._consolidated = {"storage": storage, "metadata": metadata_dict} @@ -10527,6 +10548,7 @@ def to(self, *args, **kwargs) -> T: pin_memory=non_blocking_pin, num_threads=num_threads, non_blocking=non_blocking, + compilable=is_dynamo_compiling(), ) if non_blocking is None: @@ -10584,14 +10606,54 @@ def to_pinmem(tensor, _to=to): self._sync_all() return result - def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking): + def _to_consolidated( + self, *, device, pin_memory, num_threads, non_blocking, compilable + ): if num_threads is None: # unspecified num_threads should mean 0 num_threads = 0 + storage = self._consolidated["storage"] - if pin_memory: - storage = storage.pin_memory() - storage_cast = storage.to(device, non_blocking=True) + + storage_cast = _to_escape_compile(storage, device=device, pin_memory=pin_memory) + _consolidated = { + "storage": storage_cast, + } + if "metadata" in self._consolidated: + # faster than deepcopy + def copy_dict(d): + return { + k: v if not isinstance(v, dict) else copy_dict(v) + for k, v in d.items() + } + + _consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) + + if compilable: + result = self._to_consolidated_compile( + device=device, num_threads=num_threads, storage_cast=storage_cast, _consolidated=_consolidated, + ) + else: + result = self._to_consolidated_eager( + device=device, num_threads=num_threads, storage_cast=storage_cast, _consolidated=_consolidated, + ) + + if non_blocking in (False, None): + if device.type == "cuda" and non_blocking is False: + # sending to CUDA force sync + cuda_device = device + elif storage.device.type == "cuda": + # sending from cuda: need sync unless intentionally not asked for + cuda_device = storage.device.type + else: + cuda_device = None + if cuda_device is not None: + torch.cuda.current_stream(cuda_device).synchronize() + + return result + + def _to_consolidated_eager(self, *, device, num_threads, storage_cast, _consolidated): + untyped_storage = storage_cast.untyped_storage() def set_(x): @@ -10650,28 +10712,103 @@ def set_(x): result = self._fast_apply( set_, device=torch.device(device), num_threads=num_threads ) - result._consolidated = {"storage": storage_cast} - if "metadata" in self._consolidated: - # faster than deepcopy - def copy_dict(d): - return { - k: v if not isinstance(v, dict) else copy_dict(v) - for k, v in d.items() - } + result._consolidated = _consolidated + return result - result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"]) - if non_blocking in (False, None): - if device.type == "cuda" and non_blocking is False: - # sending to CUDA force sync - cuda_device = device - elif storage.device.type == "cuda": - # sending from cuda: need sync unless intentionally not asked for - cuda_device = storage.device.type - else: - cuda_device = None - if cuda_device is not None: - torch.cuda.current_stream(cuda_device).synchronize() + def _to_consolidated_compile(self, *, device, num_threads, storage_cast, _consolidated): + + def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()): + root = False + if lengths is None: + lengths = [] + pos = [] + keys = [] + root = True + for k, v in metadata["leaves"].items(): + lengths.append(v[-2]) + pos.append(v[-1]) + keys.append(prefix + (k,)) + for k, d in metadata.items(): + if "leaves" in d: + get_tensors_length( + d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,) + ) + if root: + # l = torch.empty(len(lengths), dtype=torch.long) + # l[torch.as_tensor(pos)] = torch.as_tensor(lengths) + out0 = [ + None, + ] * len(pos) + out1 = [ + None, + ] * len(pos) + for p, l, k in zip(pos, lengths, keys): + out0[p] = k + out1[p] = l + return out0, out1 + + def split_storage(consolidated): + keys, splits = get_tensors_length(consolidated["metadata"]) + return dict(zip(keys, consolidated["storage"].split(splits))) + + if num_threads is None: + # unspecified num_threads should mean 0 + num_threads = 0 + + slice_map = split_storage(_consolidated) + + def view_as(src, dest): + return src.view(dest.dtype)[: dest.numel()].view(dest.shape) + + def set_(name, x): + if not isinstance(name, tuple): + name = (name,) + if x.is_nested: + if x.layout != torch.jagged: + raise RuntimeError( + "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " + "Please raise an issue on GitHub." + ) + from torch.nested import nested_tensor_from_jagged + + values = x._values + lengths = x._lengths + offsets = x._offsets + storage_offsets = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] + offsets = view_as(storage_offsets, offsets) + if lengths is not None: + storage_lengths = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] + lengths = view_as(storage_lengths, lengths) + storage_values = slice_map[ + ( + *name[:-1], + "" + name[-1], + ) + ] + return nested_tensor_from_jagged( + view_as(storage_values, values), offsets=offsets, lengths=lengths + ) + + return view_as(slice_map[name], x) + result = self._fast_apply( + set_, + device=torch.device(device), + num_threads=num_threads, + named=True, + nested_keys=True, + ) + result._consolidated = _consolidated return result def _sync_all(self): diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 8906eefd4..f59b239f4 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -134,6 +134,7 @@ def __subclasscheck__(self, subclass): "_multithread_rebuild", # rebuild checks if self is a non tensor "_propagate_lock", "_propagate_unlock", + "_reduce_get_metadata", "_values_list", "data_ptr", "dim", @@ -857,7 +858,7 @@ def get_parent_locals(cls, localns=localns): cls._type_hints = None -def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 +def _from_tensordict(cls, tensordict, non_tensordict=None, safe=True): # noqa: D417 """Tensor class wrapper to instantiate a new tensor class object. Args: @@ -865,12 +866,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 non_tensordict (dict): Dictionary with non-tensor and nested tensor class objects """ - if not isinstance(tensordict, TensorDictBase): + if safe and not isinstance(tensordict, TensorDictBase): raise RuntimeError( f"Expected a TensorDictBase instance but got {type(tensordict)}" ) # Validating keys of tensordict - # tensordict = tensordict.copy() tensor_keys = tensordict.keys() # TODO: compile doesn't like set() over an arbitrary object if is_dynamo_compiling(): @@ -890,10 +890,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 exp_keys = set(cls.__expected_keys__) if non_tensordict is not None: nontensor_keys = set(non_tensordict.keys()) + total_keys = tensor_keys.union(nontensor_keys) else: nontensor_keys = set() non_tensordict = {} - total_keys = tensor_keys.union(nontensor_keys) + total_keys = tensor_keys for key in nontensor_keys: if key not in tensor_keys: continue @@ -917,11 +918,11 @@ def _from_tensordict(cls, tensordict, non_tensordict=None): # noqa: D417 # empty tensordict and writing values to it. we can skip this because we already # have a tensordict to use as the underlying tensordict tc = cls.__new__(cls) - tc.__dict__["_tensordict"] = tensordict - tc.__dict__["_non_tensordict"] = non_tensordict + tc.__dict__.update({"_tensordict": tensordict, + "_non_tensordict": non_tensordict}) # since we aren't calling the dataclass init method, we need to manually check # whether a __post_init__ method has been defined and invoke it if so - if hasattr(tc, "__post_init__"): + if hasattr(cls, "__post_init__"): tc.__post_init__() return tc else: @@ -1143,6 +1144,27 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): + def check_out(kwargs, result): + # No need to transform output if True + return kwargs.get("out") is result + + def deliver_result(self, result, kwargs): + if result is None: + return + if isinstance(result, TensorDictBase) and not check_out(kwargs, result): + if not is_dynamo_compiling(): + non_tensordict = super(type(self), self).__getattribute__( + "_non_tensordict" + ) + else: + non_tensordict = self._non_tensordict + non_tensordict = dict(non_tensordict) + if copy_non_tensor and non_tensordict: + # use tree_map to copy + non_tensordict = tree_map(lambda x: x, non_tensordict) + return self._from_tensordict(result, non_tensordict, safe=False) + return result + def wrapped_func(self, *args, **kwargs): if not is_dynamo_compiling(): td = super(type(self), self).__getattribute__("_tensordict") @@ -1154,34 +1176,12 @@ def wrapped_func(self, *args, **kwargs): if no_wrap: return result - def check_out(kwargs, result): - out = kwargs.get("out") - if out is result: - # No need to transform output - return True - return False - if result is td: return self - def deliver_result(result): - if isinstance(result, TensorDictBase) and not check_out(kwargs, result): - if not is_dynamo_compiling(): - non_tensordict = super(type(self), self).__getattribute__( - "_non_tensordict" - ) - else: - non_tensordict = self._non_tensordict - non_tensordict = dict(non_tensordict) - if copy_non_tensor: - # use tree_map to copy - non_tensordict = tree_map(lambda x: x, non_tensordict) - return self._from_tensordict(result, non_tensordict) - return result - if isinstance(result, tuple): - return tuple(deliver_result(r) for r in result) - return deliver_result(result) + return tuple(deliver_result(self, r, kwargs) for r in result) + return deliver_result(self, result, kwargs) return wrapped_func diff --git a/tensordict/utils.py b/tensordict/utils.py index 280b224a0..1cee55a2f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2694,3 +2694,65 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths): values, **kwargs, ) + + +@torch.library.custom_op("tensordict::_to_escape_compile", mutates_args=()) +def _to_escape_compile( + storage: torch.Tensor, device: torch.device, pin_memory: bool +) -> torch.Tensor: + if pin_memory: + storage = storage.pin_memory() + storage_cast = storage.to(device, non_blocking=True) + return storage_cast + + +@_to_escape_compile.register_fake +def _(storage: torch.Tensor, device: torch.device, pin_memory: bool) -> torch.Tensor: + return torch.empty_like(storage, device=device) + + +def view_and_pad(tensor: torch.Tensor, need_padding: bool) -> torch.Tensor: + result = tensor.view(-1).view(torch.uint8) + # result must always have a multiple of 8 elements + if need_padding: + pad = result.numel() % 8 + if pad != 0: + result = torch.cat([result, result.new_zeros(8 - pad)]) + return result + + +def view_old_as_new(v: torch.Tensor, oldv: torch.Tensor) -> torch.Tensor: + if oldv is None: + return v + v = v.view(oldv.dtype) + if v.numel() > oldv.numel(): + return v[: oldv.numel()].view(oldv.shape) + return v.view(oldv.shape) + + +@torch.compiler.disable() +def view_cat_split( + td, items, storage, need_padding, non_blocking, device, flat_size, set_on_tensor +): + items_flat = [view_and_pad(v, need_padding) for v in items if v is not None] + if non_blocking and device.type != "cuda": + # sync if needed + td._sync_all() + torch.cat(items_flat, out=storage) + # TODO: breaks with NJT + result = [ + view_old_as_new(v, oldv) + for (v, oldv) in zip(storage.split(flat_size), items, strict=True) + ] + if set_on_tensor: + for t_dest, t_src in zip(result, items): + if t_src is None: + # njt is decomposed + continue + t_src.set_( + t_dest.untyped_storage(), + storage_offset=t_dest.storage_offset(), + stride=t_dest.stride(), + size=t_dest.size(), + ) + return result