diff --git a/benchmarks/test_replaybuffer_benchmark.py b/benchmarks/test_replaybuffer_benchmark.py index 34116ff9703..6336e7d3461 100644 --- a/benchmarks/test_replaybuffer_benchmark.py +++ b/benchmarks/test_replaybuffer_benchmark.py @@ -173,23 +173,29 @@ def test_rb_populate(benchmark, rb, storage, sampler, size): ) -class create_tensor_rb: - def __init__(self, rb, storage, sampler, size=1_000_000, iters=100): +class create_compiled_tensor_rb: + def __init__( + self, rb, storage, sampler, storage_size, data_size, iters, compilable=False + ): self.storage = storage self.rb = rb self.sampler = sampler - self.size = size + self.storage_size = storage_size + self.data_size = data_size self.iters = iters + self.compilable = compilable def __call__(self): kwargs = {} if self.sampler is not None: kwargs["sampler"] = self.sampler() if self.storage is not None: - kwargs["storage"] = self.storage(10 * self.size) + kwargs["storage"] = self.storage( + self.storage_size, compilable=self.compilable + ) - rb = self.rb(batch_size=3, **kwargs) - data = torch.randn(self.size, 1) + rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs) + data = torch.randn(self.data_size, 1) return ((rb, data, self.iters), {}) @@ -210,21 +216,32 @@ def fn(td): @pytest.mark.parametrize( - "rb,storage,sampler,size,iters,compiled", + "rb,storage,sampler,storage_size,data_size,iters,compiled", [ - [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True], - [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False], + [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True], + [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False], + [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True], + [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False], + [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True], + [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False], ], ) -def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled): +def test_rb_extend_sample( + benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled +): + if compiled: + torch._dynamo.reset_code_caches() + benchmark.pedantic( extend_and_sample_compiled if compiled else extend_and_sample, - setup=create_tensor_rb( + setup=create_compiled_tensor_rb( rb=rb, storage=storage, sampler=sampler, - size=size, + storage_size=storage_size, + data_size=data_size, iters=iters, + compilable=compiled, ), iterations=1, warmup_rounds=10, diff --git a/test/test_rb.py b/test/test_rb.py index c14ccb64c04..6ceefef29d1 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -178,18 +178,24 @@ ) @pytest.mark.parametrize("size", [3, 5, 100]) class TestComposableBuffers: - def _get_rb(self, rb_type, size, sampler, writer, storage): + def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False): if storage is not None: - storage = storage(size) + storage = storage(size, compilable=compilable) sampler_args = {} if sampler is samplers.PrioritizedSampler: sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9} sampler = sampler(**sampler_args) - writer = writer() - rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3) + writer = writer(compilable=compilable) + rb = rb_type( + storage=storage, + sampler=sampler, + writer=writer, + batch_size=3, + compilable=compilable, + ) return rb def _get_datum(self, datatype): @@ -421,8 +427,9 @@ def data_iter(): # # Our 
Windows CI jobs do not have "cl", so skip this test. @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile") + @pytest.mark.parametrize("avoid_max_size", [False, True]) def test_extend_sample_recompile( - self, rb_type, sampler, writer, storage, size, datatype + self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size ): if rb_type is not ReplayBuffer: pytest.skip( @@ -443,15 +450,26 @@ def test_extend_sample_recompile( torch._dynamo.reset_code_caches() - storage_size = 10 * size + # Number of times to extend the replay buffer + num_extend = 10 + data_size = size + + # These two cases are separated because when the max storage size is + # reached, the code execution path changes, causing necessary + # recompiles. + if avoid_max_size: + storage_size = (num_extend + 1) * data_size + else: + storage_size = 2 * data_size + rb = self._get_rb( rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=storage_size, + compilable=True, ) - data_size = size data = self._get_data(datatype, size=data_size) @torch.compile @@ -459,12 +477,9 @@ def extend_and_sample(data): rb.extend(data) return rb.sample() - # Number of times to extend the replay buffer - num_extend = 30 - - # NOTE: The first two calls to 'extend' and 'sample' currently cause - # recompilations, so avoid capturing those for now. - num_extend_before_capture = 2 + # NOTE: The first three calls to 'extend' and 'sample' can currently + # cause recompilations, so avoid capturing those. + num_extend_before_capture = 3 for _ in range(num_extend_before_capture): extend_and_sample(data) @@ -477,12 +492,12 @@ def extend_and_sample(data): for _ in range(num_extend - num_extend_before_capture): extend_and_sample(data) - assert len(rb) == storage_size - assert len(records) == 0 - finally: torch._logging.set_logs() + assert len(rb) == min((num_extend * data_size), storage_size) + assert len(records) == 0 + def test_sample(self, rb_type, sampler, writer, storage, size, datatype): if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: pytest.skip( @@ -806,6 +821,52 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend): s = new_replay_buffer.sample() assert (s.exclude("index") == 1).all() + @pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" + ) + @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile") + # This test checks if the `torch._dynamo.disable` wrapper around + # `TensorStorage._rand_given_ndim` is still necessary. 
+ def test__rand_given_ndim_recompile(self): + torch._dynamo.reset_code_caches() + + # Number of times to extend the replay buffer + num_extend = 10 + data_size = 100 + storage_size = (num_extend + 1) * data_size + sample_size = 3 + + storage = LazyTensorStorage(storage_size, compilable=True) + sampler = RandomSampler() + + # Override to avoid the `torch._dynamo.disable` wrapper + storage._rand_given_ndim = storage._rand_given_ndim_impl + + @torch.compile + def extend_and_sample(data): + storage.set(torch.arange(data_size) + len(storage), data) + return sampler.sample(storage, sample_size) + + data = torch.randint(100, (data_size, 1)) + + try: + torch._logging.set_logs(recompiles=True) + records = [] + capture_log_records(records, "torch._dynamo", "recompiles") + + for _ in range(num_extend): + extend_and_sample(data) + + finally: + torch._logging.set_logs() + + assert len(storage) == num_extend * data_size + assert len(records) == 8, ( + "If this ever decreases, that's probably good news and the " + "`torch._dynamo.disable` wrapper around " + "`TensorStorage._rand_given_ndim` can be removed." + ) + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) def test_extend_lazystack(self, storage_type): diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 3af44ee0ed7..31e00614fd9 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -252,6 +252,11 @@ class implement_for: Keyword Args: class_method (bool, optional): if ``True``, the function will be written as a class method. Defaults to ``False``. + compilable (bool, optional): If ``False``, the module import happens + only on the first call to the wrapped function. If ``True``, the + module import happens when the wrapped function is initialized. This + allows the wrapped function to work well with ``torch.compile``. + Defaults to ``False``. Examples: >>> @implement_for("gym", "0.13", "0.14") @@ -290,11 +295,13 @@ def __init__( to_version: str = None, *, class_method: bool = False, + compilable: bool = False, ): self.module_name = module_name self.from_version = from_version self.to_version = to_version self.class_method = class_method + self._compilable = compilable implement_for._setters.append(self) @staticmethod @@ -386,18 +393,27 @@ def __call__(self, fn): self.fn = fn implement_for._lazy_impl[self.func_name].append(self._call) - @wraps(fn) - def _lazy_call_fn(*args, **kwargs): - # first time we call the function, we also do the replacement. - # This will cause the imports to occur only during the first call to fn + if self._compilable: + _call_fn = self._delazify(self.func_name) - result = self._delazify(self.func_name)(*args, **kwargs) - return result + if self.class_method: + return classmethod(_call_fn) - if self.class_method: - return classmethod(_lazy_call_fn) + return _call_fn + else: + + @wraps(fn) + def _lazy_call_fn(*args, **kwargs): + # first time we call the function, we also do the replacement. 
+ # This will cause the imports to occur only during the first call to fn + + result = self._delazify(self.func_name)(*args, **kwargs) + return result + + if self.class_method: + return classmethod(_lazy_call_fn) - return _lazy_call_fn + return _lazy_call_fn def _call(self): diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 2672c90092f..5e7b80d7bed 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -19,6 +19,11 @@ import torch +try: + from torch.compiler import is_dynamo_compiling +except ImportError: + from torch._dynamo import is_compiling as is_dynamo_compiling + from tensordict import ( is_tensor_collection, is_tensorclass, @@ -132,6 +137,9 @@ class ReplayBuffer: .. warning:: As of now, the generator has no effect on the transforms. shared (bool, optional): whether the buffer will be shared using multiprocessing or not. Defaults to ``False``. + compilable (bool, optional): whether the writer is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. Examples: >>> import torch @@ -217,11 +225,20 @@ def __init__( checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821 generator: torch.Generator | None = None, shared: bool = False, + compilable: bool = None, ) -> None: - self._storage = storage if storage is not None else ListStorage(max_size=1_000) + self._storage = ( + storage + if storage is not None + else ListStorage(max_size=1_000, compilable=compilable) + ) self._storage.attach(self) self._sampler = sampler if sampler is not None else RandomSampler() - self._writer = writer if writer is not None else RoundRobinWriter() + self._writer = ( + writer + if writer is not None + else RoundRobinWriter(compilable=bool(compilable)) + ) self._writer.register_storage(self._storage) self._get_collate_fn(collate_fn) @@ -600,7 +617,9 @@ def _add(self, data): return index def _extend(self, data: Sequence) -> torch.Tensor: - with self._replay_lock, self._write_lock: + is_compiling = is_dynamo_compiling() + nc = contextlib.nullcontext() + with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc: if self.dim_extend > 0: data = self._transpose(data) index = self._writer.extend(data) @@ -653,7 +672,7 @@ def update_priority( @pin_memory_output def _sample(self, batch_size: int) -> Tuple[Any, dict]: - with self._replay_lock: + with self._replay_lock if not is_dynamo_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index data = self._storage.get(index) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 21cbfce7b31..dc2ddff0ef9 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -61,10 +61,15 @@ class Storage: _rng: torch.Generator | None = None def __init__( - self, max_size: int, checkpointer: StorageCheckpointerBase | None = None + self, + max_size: int, + checkpointer: StorageCheckpointerBase | None = None, + compilable: bool = False, ) -> None: self.max_size = int(max_size) self.checkpointer = checkpointer + self._compilable = compilable + self._attached_entities_set = set() @property def checkpointer(self): @@ -84,11 +89,14 @@ def _is_full(self): def _attached_entities(self): # RBs that use a given instance of Storage should add # themselves to this set. 
-        _attached_entities = self.__dict__.get("_attached_entities_set", None)
-        if _attached_entities is None:
-            _attached_entities = set()
-            self.__dict__["_attached_entities_set"] = _attached_entities
-        return _attached_entities
+        _attached_entities_set = getattr(self, "_attached_entities_set", None)
+        if _attached_entities_set is None:
+            self._attached_entities_set = _attached_entities_set = set()
+        return _attached_entities_set
+
+    @torch._dynamo.assume_constant_result
+    def _attached_entities_iter(self):
+        return list(self._attached_entities)
 
     @abc.abstractmethod
     def set(self, cursor: int, data: Any, *, set_cursor: bool = True):
@@ -144,29 +152,12 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
     def _empty(self):
         ...
 
-    # NOTE: This property is used to enable compiled Storages. Calling
-    # `len(self)` on a TensorStorage should normally cause a graph break since
-    # it uses a `mp.Value`, and it does cause a break when the `len(self)` call
-    # happens within a method of TensorStorage itself. However, when the
-    # `len(self)` call happens in the Storage base class, for an unknown reason
-    # the compiler doesn't seem to recognize that there should be a graph break,
-    # and the lack of a break causes a recompile each time `len(self)` is called
-    # in this context. Also for an unknown reason, we can force the graph break
-    # to happen if we wrap the `len(self)` call with a `property`-decorated
-    # function. For another unknown reason, if we change
-    # `TensorStorage._len_value` from `mp.Value` to int, it seems like there
-    # should no longer be any need to recompile, but recompiles happen anyway.
-    # Ideally, this should all be investigated and understood in the future.
-    @property
-    def len(self):
-        return len(self)
-
     def _rand_given_ndim(self, batch_size):
         # a method to return random indices given the storage ndim
         if self.ndim == 1:
             return torch.randint(
                 0,
-                self.len,
+                len(self),
                 (batch_size,),
                 generator=self._rng,
                 device=getattr(self, "device", None),
@@ -241,10 +232,10 @@ class ListStorage(Storage):
 
     _default_checkpointer = ListStorageCheckpointer
 
-    def __init__(self, max_size: int | None = None):
+    def __init__(self, max_size: int | None = None, compilable: bool = False):
         if max_size is None:
             max_size = torch.iinfo(torch.int64).max
-        super().__init__(max_size)
+        super().__init__(max_size, compilable=compilable)
         self._storage = []
 
     def set(
@@ -381,6 +372,9 @@ class TensorStorage(Storage):
             measuring the storage size. For instance, a storage of shape ``[3, 4]``
             has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
             Defaults to ``1``.
+        compilable (bool, optional): whether the storage is compilable.
+            If ``True``, the storage cannot be shared between multiple processes.
+            Defaults to ``False``.
Examples: >>> data = TensorDict({ @@ -440,6 +434,7 @@ def __init__( *, device: torch.device = "cpu", ndim: int = 1, + compilable: bool = False, ): if not ((storage is None) ^ (max_size is None)): if storage is None: @@ -455,7 +450,7 @@ def __init__( else: max_size = tree_flatten(storage)[0][0].shape[0] self.ndim = ndim - super().__init__(max_size) + super().__init__(max_size, compilable=compilable) self.initialized = storage is not None if self.initialized: self._len = max_size @@ -474,16 +469,24 @@ def __init__( @property def _len(self): _len_value = self.__dict__.get("_len_value", None) - if _len_value is None: - _len_value = self._len_value = mp.Value("i", 0) - return _len_value.value + if not self._compilable: + if _len_value is None: + _len_value = self._len_value = mp.Value("i", 0) + return _len_value.value + else: + if _len_value is None: + _len_value = self._len_value = 0 + return _len_value @_len.setter def _len(self, value): - _len_value = self.__dict__.get("_len_value", None) - if _len_value is None: - _len_value = self._len_value = mp.Value("i", 0) - _len_value.value = value + if not self._compilable: + _len_value = self.__dict__.get("_len_value", None) + if _len_value is None: + _len_value = self._len_value = mp.Value("i", 0) + _len_value.value = value + else: + self._len_value = value @property def _total_shape(self): @@ -550,7 +553,16 @@ def shape(self): if _total_shape is not None: return torch.Size([self._len_along_dim0] + list(_total_shape[1:])) + # TODO: Without this disable, compiler recompiles for back-to-back calls. + # Figuring out a way to avoid this disable would give better performance. + @torch._dynamo.disable() def _rand_given_ndim(self, batch_size): + return self._rand_given_ndim_impl(batch_size) + + # At the moment, this is separated into its own function so that we can test + # it without the `torch._dynamo.disable` and detect if future updates to the + # compiler fix the recompile issue. 
+    def _rand_given_ndim_impl(self, batch_size):
         if self.ndim == 1:
             return super()._rand_given_ndim(batch_size)
         shape = self.shape
@@ -623,8 +635,11 @@ def assert_is_sharable(tensor):
     def __setstate__(self, state):
         len = state.pop("len__context", None)
         if len is not None:
-            _len_value = mp.Value("i", len)
-            state["_len_value"] = _len_value
+            if not state["_compilable"]:
+                _len_value = mp.Value("i", len)
+                state["_len_value"] = _len_value
+            else:
+                state["_len_value"] = len
         self.__dict__.update(state)
 
     def state_dict(self) -> Dict[str, Any]:
@@ -674,7 +689,7 @@ def load_state_dict(self, state_dict):
         self.initialized = state_dict["initialized"]
         self._len = state_dict["_len"]
 
-    @implement_for("torch", "2.3")
+    @implement_for("torch", "2.3", compilable=True)
     def _set_tree_map(self, cursor, data, storage):
         def set_tensor(datum, store):
             store[cursor] = datum
 
         # this won't be available until v2.3
         tree_map(set_tensor, data, storage)
 
-    @implement_for("torch", "2.0", "2.3")
+    @implement_for("torch", "2.0", "2.3", compilable=True)
     def _set_tree_map(self, cursor, data, storage):  # noqa: 534
         # flatten data and cursor
         data_flat = tree_flatten(data)[0]
@@ -700,7 +715,7 @@ def _get_new_len(self, data, cursor):
             numel = leaf.shape[:ndim].numel()
         self._len = min(self._len + numel, self.max_size)
 
-    @implement_for("torch", "2.0", None)
+    @implement_for("torch", "2.0", None, compilable=True)
     def set(
         self,
         cursor: Union[int, Sequence[int], slice],
@@ -742,7 +757,7 @@ def set(
         else:
            self._set_tree_map(cursor, data, self._storage)
 
-    @implement_for("torch", None, "2.0")
+    @implement_for("torch", None, "2.0", compilable=True)
     def set(  # noqa: F811
         self,
         cursor: Union[int, Sequence[int], slice],
@@ -893,6 +908,9 @@ class LazyTensorStorage(TensorStorage):
             measuring the storage size. For instance, a storage of shape ``[3, 4]``
             has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
             Defaults to ``1``.
+        compilable (bool, optional): whether the storage is compilable.
+            If ``True``, the storage cannot be shared between multiple processes.
+            Defaults to ``False``.
Examples: >>> data = TensorDict({ @@ -952,14 +970,24 @@ def __init__( *, device: torch.device = "cpu", ndim: int = 1, + compilable: bool = False, ): - super().__init__(storage=None, max_size=max_size, device=device, ndim=ndim) + super().__init__( + storage=None, + max_size=max_size, + device=device, + ndim=ndim, + compilable=compilable, + ) def _init( self, data: Union[TensorDictBase, torch.Tensor, "PyTree"], # noqa: F821 ) -> None: - torchrl_logger.debug("Creating a TensorStorage...") + if not self._compilable: + # TODO: Investigate why this seems to have a performance impact with + # the compiler + torchrl_logger.debug("Creating a TensorStorage...") if self.device == "auto": self.device = data.device @@ -1089,8 +1117,9 @@ def __init__( device: torch.device = "cpu", ndim: int = 1, existsok: bool = False, + compilable: bool = False, ): - super().__init__(max_size, ndim=ndim) + super().__init__(max_size, ndim=ndim, compilable=compilable) self.initialized = False self.scratch_dir = None self.existsok = existsok @@ -1264,10 +1293,6 @@ def _rng(self, value): for storage in self._storages: storage._rng = value - @property - def _attached_entities(self): - return set() - def extend(self, value): raise RuntimeError diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 3a95c3975cc..7fb865453d6 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -40,8 +40,9 @@ class Writer(ABC): _storage: Storage _rng: torch.Generator | None = None - def __init__(self) -> None: + def __init__(self, compilable: bool = False) -> None: self._storage = None + self._compilable = compilable def register_storage(self, storage: Storage) -> None: self._storage = storage @@ -138,10 +139,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class RoundRobinWriter(Writer): - """A RoundRobin Writer class for composable replay buffers.""" + """A RoundRobin Writer class for composable replay buffers. - def __init__(self, **kw) -> None: - super().__init__(**kw) + Args: + compilable (bool, optional): whether the writer is compilable. + If ``True``, the writer cannot be shared between multiple processes. + Defaults to ``False``. 
+ + """ + + def __init__(self, compilable: bool = False) -> None: + super().__init__(compilable=compilable) self._cursor = 0 def dumps(self, path): @@ -197,7 +205,7 @@ def extend(self, data: Sequence) -> torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -213,30 +221,46 @@ def _empty(self): @property def _cursor(self): _cursor_value = self.__dict__.get("_cursor_value", None) - if _cursor_value is None: - _cursor_value = self._cursor_value = mp.Value("i", 0) - return _cursor_value.value + if not self._compilable: + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + return _cursor_value.value + else: + if _cursor_value is None: + _cursor_value = self._cursor_value = 0 + return _cursor_value @_cursor.setter def _cursor(self, value): - _cursor_value = self.__dict__.get("_cursor_value", None) - if _cursor_value is None: - _cursor_value = self._cursor_value = mp.Value("i", 0) - _cursor_value.value = value + if not self._compilable: + _cursor_value = self.__dict__.get("_cursor_value", None) + if _cursor_value is None: + _cursor_value = self._cursor_value = mp.Value("i", 0) + _cursor_value.value = value + else: + self._cursor_value = value @property def _write_count(self): _write_count = self.__dict__.get("_write_count_value", None) - if _write_count is None: - _write_count = self._write_count_value = mp.Value("i", 0) - return _write_count.value + if not self._compilable: + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + return _write_count.value + else: + if _write_count is None: + _write_count = self._write_count_value = 0 + return _write_count @_write_count.setter def _write_count(self, value): - _write_count = self.__dict__.get("_write_count_value", None) - if _write_count is None: - _write_count = self._write_count_value = mp.Value("i", 0) - _write_count.value = value + if not self._compilable: + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + _write_count.value = value + else: + self._write_count_value = value def __getstate__(self): state = super().__getstate__() @@ -249,7 +273,10 @@ def __getstate__(self): def __setstate__(self, state): cursor = state.pop("cursor__context", None) if cursor is not None: - _cursor_value = mp.Value("i", cursor) + if not state["_compilable"]: + _cursor_value = mp.Value("i", cursor) + else: + _cursor_value = cursor state["_cursor_value"] = _cursor_value self.__dict__.update(state)
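
As a usage reference, the sketch below shows how the compilable path added in this patch is intended to be exercised end to end, mirroring the new `test_extend_sample_recompile` test and the `create_compiled_tensor_rb` benchmark above. The buffer size, data shape, and loop count are illustrative assumptions rather than part of the change, and (per the new tests) this path is only expected to work on torch >= 2.5 and outside Windows CI.

import torch

from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer
from torchrl.data.replay_buffers import RoundRobinWriter

# Build every component with compilable=True so that extend()/sample() bypass
# the mp.Value counters and locks that would otherwise break the compiled graph.
rb = ReplayBuffer(
    storage=LazyTensorStorage(10_000, compilable=True),
    sampler=RandomSampler(),
    writer=RoundRobinWriter(compilable=True),
    batch_size=3,
    compilable=True,
)


@torch.compile
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample()


data = torch.randn(100, 1)

# Per the NOTE in test_extend_sample_recompile, the first few calls may still
# recompile; subsequent calls are expected to reuse the same graph.
for _ in range(10):
    batch = extend_and_sample(data)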