Commit

[Feature] Improve performance of compiled ReplayBuffer

kurtamohler committed Oct 31, 2024
1 parent edbf3de commit 0b03ad8
Showing 6 changed files with 240 additions and 107 deletions.
41 changes: 29 additions & 12 deletions benchmarks/test_replaybuffer_benchmark.py
@@ -173,23 +173,29 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
)


class create_tensor_rb:
def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
class create_compiled_tensor_rb:
def __init__(
self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
):
self.storage = storage
self.rb = rb
self.sampler = sampler
self.size = size
self.storage_size = storage_size
self.data_size = data_size
self.iters = iters
self.compilable = compilable

def __call__(self):
kwargs = {}
if self.sampler is not None:
kwargs["sampler"] = self.sampler()
if self.storage is not None:
kwargs["storage"] = self.storage(10 * self.size)
kwargs["storage"] = self.storage(
self.storage_size, compilable=self.compilable
)

rb = self.rb(batch_size=3, **kwargs)
data = torch.randn(self.size, 1)
rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs)
data = torch.randn(self.data_size, 1)
return ((rb, data, self.iters), {})


@@ -210,21 +216,32 @@ def fn(td):


@pytest.mark.parametrize(
"rb,storage,sampler,size,iters,compiled",
"rb,storage,sampler,storage_size,data_size,iters,compiled",
[
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True],
[ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False],
],
)
def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
def test_rb_extend_sample(
benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
):
if compiled:
torch._dynamo.reset_code_caches()

benchmark.pedantic(
extend_and_sample_compiled if compiled else extend_and_sample,
setup=create_tensor_rb(
setup=create_compiled_tensor_rb(
rb=rb,
storage=storage,
sampler=sampler,
size=size,
storage_size=storage_size,
data_size=data_size,
iters=iters,
compilable=compiled,
),
iterations=1,
warmup_rounds=10,
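
For reference, a rough sketch (not part of the commit) of how the reworked setup factory is meant to be driven. It assumes ReplayBuffer, LazyTensorStorage, and RandomSampler are importable from torchrl.data and that create_compiled_tensor_rb is in scope as defined above:

# Illustrative only: drive the benchmark setup by hand instead of via
# benchmark.pedantic.
from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer

setup = create_compiled_tensor_rb(
    rb=ReplayBuffer,
    storage=LazyTensorStorage,
    sampler=RandomSampler,
    storage_size=100_000,
    data_size=10_000,
    iters=100,
    compilable=True,
)
(rb, data, iters), _ = setup()  # a compilable buffer and a (10_000, 1) tensor
for _ in range(iters):          # roughly what the extend_and_sample helpers do
    rb.extend(data)
    rb.sample()
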
47 changes: 31 additions & 16 deletions test/test_rb.py
@@ -178,18 +178,24 @@
)
@pytest.mark.parametrize("size", [3, 5, 100])
class TestComposableBuffers:
def _get_rb(self, rb_type, size, sampler, writer, storage):
def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):

if storage is not None:
storage = storage(size)
storage = storage(size, compilable=compilable)

sampler_args = {}
if sampler is samplers.PrioritizedSampler:
sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}

sampler = sampler(**sampler_args)
writer = writer()
rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
writer = writer(compilable=compilable)
rb = rb_type(
storage=storage,
sampler=sampler,
writer=writer,
batch_size=3,
compilable=compilable,
)
return rb

def _get_datum(self, datatype):
@@ -421,8 +427,9 @@ def data_iter():
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
# Our Windows CI jobs do not have "cl", so skip this test.
@pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
@pytest.mark.parametrize("avoid_max_size", [False, True])
def test_extend_sample_recompile(
self, rb_type, sampler, writer, storage, size, datatype
self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
):
if rb_type is not ReplayBuffer:
pytest.skip(
@@ -443,28 +450,36 @@ def test_extend_sample_recompile(

torch._dynamo.reset_code_caches()

storage_size = 10 * size
# Number of times to extend the replay buffer
num_extend = 10
data_size = size

# These two cases are separated because when the max storage size is
# reached, the code execution path changes, causing necessary
# recompiles.
if avoid_max_size:
storage_size = (num_extend + 1) * data_size
else:
storage_size = 2 * data_size

rb = self._get_rb(
rb_type=rb_type,
sampler=sampler,
writer=writer,
storage=storage,
size=storage_size,
compilable=True,
)
data_size = size
data = self._get_data(datatype, size=data_size)

@torch.compile
def extend_and_sample(data):
rb.extend(data)
return rb.sample()

# Number of times to extend the replay buffer
num_extend = 30

# NOTE: The first two calls to 'extend' and 'sample' currently cause
# recompilations, so avoid capturing those for now.
num_extend_before_capture = 2
# NOTE: The first three calls to 'extend' and 'sample' can currently
# cause recompilations, so avoid capturing those.
num_extend_before_capture = 3

for _ in range(num_extend_before_capture):
extend_and_sample(data)
@@ -477,12 +492,12 @@ def extend_and_sample(data):
for _ in range(num_extend - num_extend_before_capture):
extend_and_sample(data)

assert len(rb) == storage_size
assert len(records) == 0

finally:
torch._logging.set_logs()

assert len(rb) == min((num_extend * data_size), storage_size)
assert len(records) == 0

def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
pytest.skip(
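
The heart of the new recompile test, condensed into a standalone sketch. This is an illustration rather than the test itself: it assumes ReplayBuffer, LazyTensorStorage, and RandomSampler can be imported from torchrl.data, and that recompile log records propagate up to the "torch._dynamo" logger once torch._logging.set_logs(recompiles=True) is enabled:

import logging

import torch
from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer

records = []

class _Capture(logging.Handler):
    def emit(self, record):
        records.append(record)

data_size, num_extend = 100, 10
rb = ReplayBuffer(
    # Sized so the buffer never reaches max capacity (the avoid_max_size case).
    storage=LazyTensorStorage((num_extend + 1) * data_size, compilable=True),
    sampler=RandomSampler(),
    batch_size=3,
    compilable=True,
)

@torch.compile
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample()

data = torch.randn(data_size, 1)
for _ in range(3):  # the first few calls are allowed to (re)compile
    extend_and_sample(data)

handler = _Capture()
logging.getLogger("torch._dynamo").addHandler(handler)
torch._logging.set_logs(recompiles=True)
try:
    for _ in range(num_extend - 3):
        extend_and_sample(data)
    assert len(rb) == num_extend * data_size
    assert not records  # no recompiles after warm-up
finally:
    torch._logging.set_logs()
    logging.getLogger("torch._dynamo").removeHandler(handler)
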
27 changes: 23 additions & 4 deletions torchrl/_utils.py
@@ -295,7 +295,7 @@ def __init__(
self.from_version = from_version
self.to_version = to_version
self.class_method = class_method
implement_for._setters.append(self)
self._is_supported = None

@staticmethod
def check_version(version: str, from_version: str | None, to_version: str | None):
@@ -304,6 +304,20 @@ def check_version(version: str, from_version: str | None, to_version: str | None
to_version is None or version < parse(to_version)
)

# If `implement_for` is used as a decorator, `torch.compile` adds guards
# around it. So instead, `implement_for` can be instantiated without
# decorating the function, and `implement_for.is_supported` can be called to
# explicitly switch between different implementation functions.
# TODO: Fix the decorator to avoid compiler guards.
@torch._dynamo.assume_constant_result
def is_supported(self):
if self._is_supported is None:
version = self.import_module(self.module_name)
self._is_supported = self.check_version(
version, self.from_version, self.to_version
)
return self._is_supported

@staticmethod
def get_class_that_defined_method(f):
"""Returns the class of a method, if it is defined, and None otherwise."""
@@ -381,6 +395,8 @@ def _delazify(self, func_name):
return out

def __call__(self, fn):
implement_for._setters.append(self)

# function names are unique
self.func_name = self.get_func_name(fn)
self.fn = fn
@@ -399,6 +415,11 @@ def _lazy_call_fn(*args, **kwargs):

return _lazy_call_fn

def unsupported(self, func_name):
raise ModuleNotFoundError(
f"Supported version of '{func_name}' has not been found."
)

def _call(self):

# If the module is missing replace the function with the mock.
@@ -408,9 +429,7 @@ def _call(self):

@wraps(fn)
def unsupported(*args, **kwargs):
raise ModuleNotFoundError(
f"Supported version of '{func_name}' has not been found."
)
self.unsupported(func_name)

self.do_set = False
# Return fitting implementation if it was encountered before.
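
A short illustration of the usage the comment above describes: instantiate implement_for without decorating anything and branch on is_supported(). The helper names and version bounds below are hypothetical:

import torch

from torchrl._utils import implement_for

# Instantiated, not applied as a decorator, so torch.compile sees no extra
# guards; is_supported() is wrapped in torch._dynamo.assume_constant_result.
_torch_2_4_plus = implement_for("torch", "2.4.0", None)

def _zeros_new(shape):
    return torch.zeros(shape, dtype=torch.float32)

def _zeros_old(shape):
    return torch.zeros(shape)

def make_zeros(shape):
    # Explicitly switch implementations instead of relying on the decorator.
    if _torch_2_4_plus.is_supported():
        return _zeros_new(shape)
    return _zeros_old(shape)
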
22 changes: 18 additions & 4 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -132,6 +132,9 @@ class ReplayBuffer:
.. warning:: As of now, the generator has no effect on the transforms.
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
Defaults to ``False``.
compilable (bool, optional): whether the replay buffer is compilable.
If ``True``, it cannot be shared between multiple processes.
Defaults to ``False``.
Examples:
>>> import torch
@@ -217,11 +220,20 @@ def __init__(
checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821
generator: torch.Generator | None = None,
shared: bool = False,
compilable: bool = None,
) -> None:
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
self._storage = (
storage
if storage is not None
else ListStorage(max_size=1_000, compilable=compilable)
)
self._storage.attach(self)
self._sampler = sampler if sampler is not None else RandomSampler()
self._writer = writer if writer is not None else RoundRobinWriter()
self._writer = (
writer
if writer is not None
else RoundRobinWriter(compilable=bool(compilable))
)
self._writer.register_storage(self._storage)

self._get_collate_fn(collate_fn)
@@ -600,7 +612,9 @@ def _add(self, data):
return index

def _extend(self, data: Sequence) -> torch.Tensor:
with self._replay_lock, self._write_lock:
is_compiling = torch.compiler.is_dynamo_compiling()
nc = contextlib.nullcontext()
with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
if self.dim_extend > 0:
data = self._transpose(data)
index = self._writer.extend(data)
@@ -653,7 +667,7 @@ def update_priority(

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
with self._replay_lock:
with self._replay_lock if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext():
index, info = self._sampler.sample(self._storage, batch_size)
info["index"] = index
data = self._storage.get(index)
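
The lock handling in _extend and _sample above follows one small pattern: while Dynamo is tracing, the threading locks are swapped for contextlib.nullcontext() so the critical section can be compiled. A minimal standalone sketch of that pattern, with an illustrative class that is not TorchRL API:

import contextlib
import threading

import torch

class _LockUnlessCompiling:
    def __init__(self):
        self._lock = threading.Lock()

    def _maybe_lock(self):
        # Dynamo cannot trace acquiring a real lock, so hand back a no-op
        # context manager while compiling.
        if torch.compiler.is_dynamo_compiling():
            return contextlib.nullcontext()
        return self._lock

    def bump(self, x):
        with self._maybe_lock():
            return x + 1
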