[Performance] Improve performance of compiled ReplayBuffer #2529

Merged (1 commit) on Nov 4, 2024
41 changes: 29 additions & 12 deletions benchmarks/test_replaybuffer_benchmark.py
@@ -173,23 +173,29 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
)


class create_tensor_rb:
def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
class create_compiled_tensor_rb:
def __init__(
self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
):
self.storage = storage
self.rb = rb
self.sampler = sampler
self.size = size
self.storage_size = storage_size
self.data_size = data_size
self.iters = iters
self.compilable = compilable

def __call__(self):
kwargs = {}
if self.sampler is not None:
kwargs["sampler"] = self.sampler()
if self.storage is not None:
kwargs["storage"] = self.storage(10 * self.size)
kwargs["storage"] = self.storage(
self.storage_size, compilable=self.compilable
)

rb = self.rb(batch_size=3, **kwargs)
data = torch.randn(self.size, 1)
rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs)
data = torch.randn(self.data_size, 1)
return ((rb, data, self.iters), {})


@@ -210,21 +216,32 @@ def fn(td):


@pytest.mark.parametrize(
"rb,storage,sampler,size,iters,compiled",
"rb,storage,sampler,storage_size,data_size,iters,compiled",
[
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False],
],
)
def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
def test_rb_extend_sample(
benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
):
if compiled:
torch._dynamo.reset_code_caches()

benchmark.pedantic(
extend_and_sample_compiled if compiled else extend_and_sample,
setup=create_tensor_rb(
setup=create_compiled_tensor_rb(
rb=rb,
storage=storage,
sampler=sampler,
size=size,
storage_size=storage_size,
data_size=data_size,
iters=iters,
compilable=compiled,
),
iterations=1,
warmup_rounds=10,
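The benchmark passes extend_and_sample or extend_and_sample_compiled to benchmark.pedantic, and create_compiled_tensor_rb returns the (rb, data, iters) tuple those helpers receive. The helpers themselves sit outside the visible hunks; the following is a hedged sketch of what they plausibly look like, assuming each simply extends and samples iters times, with the compiled variant wrapping the work in torch.compile:

import torch


# Sketch only, not part of this diff: the setup callable returns
# ((rb, data, iters), {}), so the benchmarked function receives those
# three positional arguments.
def extend_and_sample(rb, data, iters):
    for _ in range(iters):
        rb.extend(data)  # write a batch of transitions into the buffer
        rb.sample()      # draw a batch back out


def extend_and_sample_compiled(rb, data, iters):
    @torch.compile
    def fn(td):
        rb.extend(td)
        return rb.sample()

    for _ in range(iters):
        fn(data)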
93 changes: 77 additions & 16 deletions test/test_rb.py
@@ -178,18 +178,24 @@
)
@pytest.mark.parametrize("size", [3, 5, 100])
class TestComposableBuffers:
def _get_rb(self, rb_type, size, sampler, writer, storage):
def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):

if storage is not None:
storage = storage(size)
storage = storage(size, compilable=compilable)

sampler_args = {}
if sampler is samplers.PrioritizedSampler:
sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}

sampler = sampler(**sampler_args)
writer = writer()
rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
writer = writer(compilable=compilable)
rb = rb_type(
storage=storage,
sampler=sampler,
writer=writer,
batch_size=3,
compilable=compilable,
)
return rb

def _get_datum(self, datatype):
@@ -421,8 +427,9 @@ def data_iter():
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
# Our Windows CI jobs do not have "cl", so skip this test.
@pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
@pytest.mark.parametrize("avoid_max_size", [False, True])
def test_extend_sample_recompile(
self, rb_type, sampler, writer, storage, size, datatype
self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
):
if rb_type is not ReplayBuffer:
pytest.skip(
@@ -443,28 +450,36 @@ def test_extend_sample_recompile(

torch._dynamo.reset_code_caches()

storage_size = 10 * size
# Number of times to extend the replay buffer
num_extend = 10
data_size = size

# These two cases are separated because when the max storage size is
# reached, the code execution path changes, causing necessary
# recompiles.
if avoid_max_size:
storage_size = (num_extend + 1) * data_size
else:
storage_size = 2 * data_size

rb = self._get_rb(
rb_type=rb_type,
sampler=sampler,
writer=writer,
storage=storage,
size=storage_size,
compilable=True,
)
data_size = size
data = self._get_data(datatype, size=data_size)

@torch.compile
def extend_and_sample(data):
rb.extend(data)
return rb.sample()

# Number of times to extend the replay buffer
num_extend = 30

# NOTE: The first two calls to 'extend' and 'sample' currently cause
# recompilations, so avoid capturing those for now.
num_extend_before_capture = 2
# NOTE: The first three calls to 'extend' and 'sample' can currently
# cause recompilations, so avoid capturing those.
num_extend_before_capture = 3

for _ in range(num_extend_before_capture):
extend_and_sample(data)
@@ -477,12 +492,12 @@ def extend_and_sample(data):
for _ in range(num_extend - num_extend_before_capture):
extend_and_sample(data)

assert len(rb) == storage_size
assert len(records) == 0

finally:
torch._logging.set_logs()

assert len(rb) == min((num_extend * data_size), storage_size)
assert len(records) == 0

def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
pytest.skip(
@@ -806,6 +821,52 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
s = new_replay_buffer.sample()
assert (s.exclude("index") == 1).all()

@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
)
@pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
# This test checks if the `torch._dynamo.disable` wrapper around
# `TensorStorage._rand_given_ndim` is still necessary.
def test__rand_given_ndim_recompile(self):
torch._dynamo.reset_code_caches()

# Number of times to extend the replay buffer
num_extend = 10
data_size = 100
storage_size = (num_extend + 1) * data_size
sample_size = 3

storage = LazyTensorStorage(storage_size, compilable=True)
sampler = RandomSampler()

# Override to avoid the `torch._dynamo.disable` wrapper
storage._rand_given_ndim = storage._rand_given_ndim_impl

@torch.compile
def extend_and_sample(data):
storage.set(torch.arange(data_size) + len(storage), data)
return sampler.sample(storage, sample_size)

data = torch.randint(100, (data_size, 1))

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")

for _ in range(num_extend):
extend_and_sample(data)

finally:
torch._logging.set_logs()

assert len(storage) == num_extend * data_size
assert len(records) == 8, (
"If this ever decreases, that's probably good news and the "
"`torch._dynamo.disable` wrapper around "
"`TensorStorage._rand_given_ndim` can be removed."
)

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
def test_extend_lazystack(self, storage_type):

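Both new tests rely on the same recompile-detection pattern: enable dynamo's "recompiles" logging artifact, capture the emitted records, run the compiled function in a loop, and assert on the record count. Condensed into a standalone sketch (assuming capture_log_records is the torchrl test helper living in torchrl._utils; the compiled callable below is a stand-in):

import torch

from torchrl._utils import capture_log_records  # assumed location of the helper

# Stand-in compiled callable; the tests above compile extend_and_sample instead.
compiled_fn = torch.compile(lambda x: x * 2)
data = torch.randn(8)

records = []
try:
    # Emit a log record each time dynamo recompiles a frame.
    torch._logging.set_logs(recompiles=True)
    capture_log_records(records, "torch._dynamo", "recompiles")

    for _ in range(10):
        compiled_fn(data)
finally:
    # Restore default logging so other tests are unaffected.
    torch._logging.set_logs()

# In steady state, the compiled path should trigger no new recompiles.
assert len(records) == 0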
34 changes: 25 additions & 9 deletions torchrl/_utils.py
@@ -252,6 +252,11 @@ class implement_for:
Keyword Args:
class_method (bool, optional): if ``True``, the function will be written as a class method.
Defaults to ``False``.
compilable (bool, optional): If ``False``, the module import happens
only on the first call to the wrapped function. If ``True``, the
module import happens when the wrapped function is initialized. This
allows the wrapped function to work well with ``torch.compile``.
Defaults to ``False``.

Examples:
>>> @implement_for("gym", "0.13", "0.14")
@@ -290,11 +295,13 @@ def __init__(
to_version: str = None,
*,
class_method: bool = False,
compilable: bool = False,
):
self.module_name = module_name
self.from_version = from_version
self.to_version = to_version
self.class_method = class_method
self._compilable = compilable
implement_for._setters.append(self)

@staticmethod
@@ -386,18 +393,27 @@ def __call__(self, fn):
self.fn = fn
implement_for._lazy_impl[self.func_name].append(self._call)

@wraps(fn)
def _lazy_call_fn(*args, **kwargs):
# first time we call the function, we also do the replacement.
# This will cause the imports to occur only during the first call to fn
if self._compilable:
_call_fn = self._delazify(self.func_name)

result = self._delazify(self.func_name)(*args, **kwargs)
return result
if self.class_method:
return classmethod(_call_fn)

if self.class_method:
return classmethod(_lazy_call_fn)
return _call_fn
else:

@wraps(fn)
def _lazy_call_fn(*args, **kwargs):
# first time we call the function, we also do the replacement.
# This will cause the imports to occur only during the first call to fn

result = self._delazify(self.func_name)(*args, **kwargs)
return result

if self.class_method:
return classmethod(_lazy_call_fn)

return _lazy_call_fn
return _lazy_call_fn

def _call(self):

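A hedged usage sketch of the new compilable flag (the decorated function below is hypothetical, chosen only for illustration): with compilable=True the version check and implementation selection run when the decorator is applied, so torch.compile traces the already-resolved function rather than the lazy _lazy_call_fn wrapper that performs the import on the first call.

import torch

from torchrl._utils import implement_for


# Hypothetical function; "torch" is used as the gated module so the example runs.
@implement_for("torch", "2.0", None, compilable=True)
def add_one(x):
    # Resolved eagerly at decoration time, so the compiled graph never sees
    # the lazy import/replacement machinery.
    return x + 1


@implement_for("torch", None, "2.0", compilable=True)  # noqa: F811
def add_one(x):  # noqa: F811
    return x + 1.0


print(add_one(torch.zeros(3)))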
27 changes: 23 additions & 4 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -19,6 +19,11 @@

import torch

try:
from torch.compiler import is_dynamo_compiling
except ImportError:
from torch._dynamo import is_compiling as is_dynamo_compiling

from tensordict import (
is_tensor_collection,
is_tensorclass,
@@ -132,6 +137,9 @@ class ReplayBuffer:
.. warning:: As of now, the generator has no effect on the transforms.
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
Defaults to ``False``.
compilable (bool, optional): whether the writer is compilable.
If ``True``, the writer cannot be shared between multiple processes.
Defaults to ``False``.

Examples:
>>> import torch
Expand Down Expand Up @@ -217,11 +225,20 @@ def __init__(
checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821
generator: torch.Generator | None = None,
shared: bool = False,
compilable: bool = None,
) -> None:
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
self._storage = (
storage
if storage is not None
else ListStorage(max_size=1_000, compilable=compilable)
)
self._storage.attach(self)
self._sampler = sampler if sampler is not None else RandomSampler()
self._writer = writer if writer is not None else RoundRobinWriter()
self._writer = (
writer
if writer is not None
else RoundRobinWriter(compilable=bool(compilable))
)
self._writer.register_storage(self._storage)

self._get_collate_fn(collate_fn)
@@ -600,7 +617,9 @@ def _add(self, data):
return index

def _extend(self, data: Sequence) -> torch.Tensor:
with self._replay_lock, self._write_lock:
is_compiling = is_dynamo_compiling()
nc = contextlib.nullcontext()
with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
if self.dim_extend > 0:
data = self._transpose(data)
index = self._writer.extend(data)
@@ -653,7 +672,7 @@ def update_priority(

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
with self._replay_lock:
with self._replay_lock if not is_dynamo_compiling() else contextlib.nullcontext():
index, info = self._sampler.sample(self._storage, batch_size)
info["index"] = index
data = self._storage.get(index)
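Putting the pieces together, a minimal end-to-end sketch of how the compilable flag is meant to be used, mirroring test_extend_sample_recompile above (sizes and batch_size are arbitrary): storage and buffer are both built with compilable=True, and the lock bypass in _extend/_sample lets the traced region skip the mutex context managers while compiling.

import torch

from torchrl.data import LazyTensorStorage, ReplayBuffer

# Sketch based on the tests in this PR; the default RoundRobinWriter also
# receives compilable=True via the ReplayBuffer constructor.
rb = ReplayBuffer(
    storage=LazyTensorStorage(10_000, compilable=True),
    batch_size=32,
    compilable=True,
)


@torch.compile
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample()


data = torch.randn(1_000, 1)
for _ in range(5):
    sample = extend_and_sample(data)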