[Feature] Avoid some recompiles of ReplayBuffer.extend/sample #2504

Merged · 9 commits · Oct 24, 2024

Changes from all commits
27 changes: 27 additions & 0 deletions test/_utils_internal.py
@@ -5,10 +5,12 @@
from __future__ import annotations

import contextlib
import logging
import os

import os.path
import time
import unittest
from functools import wraps

# Get relative file path
@@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
return deco_retry


# After calling this function, any log record whose name contains 'record_name'
# and is emitted from the logger that has qualified name 'logger_qname' is
# appended to the 'records' list.
# NOTE: This function is based on testing utilities for 'torch._logging'
def capture_log_records(records, logger_qname, record_name):
assert isinstance(records, list)
logger = logging.getLogger(logger_qname)

class EmitWrapper:
def __init__(self, old_emit):
self.old_emit = old_emit

def __call__(self, record):
nonlocal records
self.old_emit(record)
if record_name in record.name:
records.append(record)

for handler in logger.handlers:
new_emit = EmitWrapper(handler.emit)
contextlib.ExitStack().enter_context(
unittest.mock.patch.object(handler, "emit", new_emit)
)


@pytest.fixture
def dtype_fixture():
dtype = torch.get_default_dtype()
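For reference, the tests added below exercise this helper with the same three-step pattern: enable recompile logging, install the capture, then restore the default logging configuration. A minimal usage sketch (assuming it runs from the test directory so `_utils_internal` is importable; the compiled code under test is elided):

```python
import torch

from _utils_internal import capture_log_records

records = []
try:
    # Ask torch to emit a log record whenever Dynamo recompiles something.
    torch._logging.set_logs(recompiles=True)
    # Append every record from the 'torch._dynamo' loggers whose logger name
    # contains "recompiles" to `records`.
    capture_log_records(records, "torch._dynamo", "recompiles")

    # ... call the compiled function under test here ...

finally:
    # Restore the default logging configuration.
    torch._logging.set_logs()

# An empty list means no recompiles were triggered inside the try block.
assert len(records) == 0
```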
77 changes: 76 additions & 1 deletion test/test_rb.py
@@ -17,7 +17,12 @@
import pytest
import torch

from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)

from mocking_classes import CountingEnv
from packaging import version
@@ -111,6 +116,7 @@
_has_gym = importlib.util.find_spec("gym") is not None
_has_snapshot = importlib.util.find_spec("torchsnapshot") is not None
_os_is_windows = sys.platform == "win32"
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

torch_2_3 = version.parse(
".".join([str(s) for s in version.parse(str(torch.__version__)).release])
@@ -399,6 +405,75 @@ def data_iter():
) if cond else contextlib.nullcontext():
rb.extend(data2)

@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
)
# Compiling on Windows requires "cl" compiler to be installed.
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
# Our Windows CI jobs do not have "cl", so skip this test.
@pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
def test_extend_sample_recompile(
self, rb_type, sampler, writer, storage, size, datatype
):
if rb_type is not ReplayBuffer:
pytest.skip(
"Only replay buffer of type 'ReplayBuffer' is currently supported."
)
if sampler is not RandomSampler:
pytest.skip("Only sampler of type 'RandomSampler' is currently supported.")
if storage is not LazyTensorStorage:
pytest.skip(
"Only storage of type 'LazyTensorStorage' is currently supported."
)
if writer is not RoundRobinWriter:
pytest.skip(
"Only writer of type 'RoundRobinWriter' is currently supported."
)
if datatype == "tensordict":
pytest.skip("'tensordict' datatype is not currently supported.")

torch._dynamo.reset_code_caches()

storage_size = 10 * size
rb = self._get_rb(
rb_type=rb_type,
sampler=sampler,
writer=writer,
storage=storage,
size=storage_size,
)
data_size = size
data = self._get_data(datatype, size=data_size)

@torch.compile
def extend_and_sample(data):
rb.extend(data)
return rb.sample()

# Number of times to extend the replay buffer
num_extend = 30

# NOTE: The first two calls to 'extend' and 'sample' currently cause
# recompilations, so avoid capturing those for now.
num_extend_before_capture = 2

for _ in range(num_extend_before_capture):
extend_and_sample(data)

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")

for _ in range(num_extend - num_extend_before_capture):
extend_and_sample(data)

assert len(rb) == storage_size
assert len(records) == 0

finally:
torch._logging.set_logs()

def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
pytest.skip(
33 changes: 32 additions & 1 deletion test/test_utils.py
@@ -14,11 +14,14 @@

import torch

from _utils_internal import get_default_devices
from _utils_internal import capture_log_records, get_default_devices
from packaging import version
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for

from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.mark.parametrize("value", ["True", "1", "true"])
def test_get_binary_env_var_positive(value):
@@ -380,6 +383,34 @@ def test_rng_decorator(device):
torch.testing.assert_close(s0b, s1b)


# Check that 'capture_log_records' captures records emitted when torch
# recompiles a function.
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
)
def test_capture_log_records_recompile():
torch.compiler.reset()

# This function recompiles each time it is called with a different string
# input.
@torch.compile
def str_to_tensor(s):
return bytes(s, "utf8")

str_to_tensor("a")

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")
str_to_tensor("b")

finally:
torch._logging.set_logs()

assert len(records) == 1


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
19 changes: 18 additions & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -144,12 +144,29 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def _empty(self):
...

[Inline review note from the PR author on the comment below: "I added this note to summarize what I know and don't know. I agree that it makes sense to merge this, but I think I can figure out what's going on here if I spend a little more time on it, so I'll submit another PR when I do."]

# NOTE: This property is used to enable compiled Storages. Calling
# `len(self)` on a TensorStorage should normally cause a graph break since
# it uses a `mp.Value`, and it does cause a break when the `len(self)` call
# happens within a method of TensorStorage itself. However, when the
# `len(self)` call happens in the Storage base class, for an unknown reason
# the compiler doesn't seem to recognize that there should be a graph break,
# and the lack of a break causes a recompile each time `len(self)` is called
# in this context. Also for an unknown reason, we can force the graph break
# to happen if we wrap the `len(self)` call with a `property`-decorated
# function. For another unknown reason, if we change
# `TensorStorage._len_value` from `mp.Value` to int, it seems like there
# should no longer be any need to recompile, but recompiles happen anyway.
# Ideally, this should all be investigated and understood in the future.
@property
def len(self):
return len(self)

def _rand_given_ndim(self, batch_size):
# a method to return random indices given the storage ndim
if self.ndim == 1:
return torch.randint(
0,
len(self),
self.len,
(batch_size,),
generator=self._rng,
device=getattr(self, "device", None),
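To make the NOTE above concrete, here is a small, hypothetical sketch of the workaround pattern; the class name, constructor, and usage at the bottom are invented for illustration and are not torchrl code. The length is stored in an `mp.Value`, `len(self)` is wrapped in a `property`, and compiled call sites read `self.len`, mirroring the change to `_rand_given_ndim` above.

```python
import multiprocessing as mp

import torch


class _SketchStorage:  # hypothetical stand-in for Storage/TensorStorage
    def __init__(self, capacity: int):
        self.capacity = capacity
        # Shared length counter, analogous to TensorStorage._len_value.
        self._len_value = mp.Value("i", 0)

    def __len__(self):
        return self._len_value.value

    @property
    def len(self):
        # Wrapping the `len(self)` call in a property is what forces the
        # graph break described in the NOTE, instead of a recompile.
        return len(self)

    def _rand_given_ndim(self, batch_size: int) -> torch.Tensor:
        # Compiled code paths use `self.len` rather than `len(self)`.
        return torch.randint(0, self.len, (batch_size,))


# Example usage of the sketch:
storage = _SketchStorage(capacity=100)
with storage._len_value.get_lock():
    storage._len_value.value = 10  # pretend 10 items have been written
print(storage._rand_given_ndim(4))  # 4 random indices in [0, 10)
```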