Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Avoid some recompiles of ReplayBuffer.extend\sample #2504

Merged
merged 9 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from __future__ import annotations

import contextlib
import logging
import os

import os.path
import time
import unittest
from functools import wraps

# Get relative file path
Expand Down Expand Up @@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
return deco_retry


# After calling this function, any log record whose name contains 'record_name'
# and is emitted from the logger that has qualified name 'logger_qname' is
# appended to the 'records' list.
# NOTE: This function is based on testing utilities for 'torch._logging'
def capture_log_records(records, logger_qname, record_name):
assert isinstance(records, list)
logger = logging.getLogger(logger_qname)

class EmitWrapper:
def __init__(self, old_emit):
self.old_emit = old_emit

def __call__(self, record):
nonlocal records
self.old_emit(record)
if record_name in record.name:
records.append(record)

for handler in logger.handlers:
new_emit = EmitWrapper(handler.emit)
contextlib.ExitStack().enter_context(
unittest.mock.patch.object(handler, "emit", new_emit)
)


@pytest.fixture
def dtype_fixture():
dtype = torch.get_default_dtype()
Expand Down
74 changes: 73 additions & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import pytest
import torch

from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)

from mocking_classes import CountingEnv
from packaging import version
Expand Down Expand Up @@ -399,6 +404,73 @@ def data_iter():
) if cond else contextlib.nullcontext():
rb.extend(data2)

def test_extend_sample_recompile(
self, rb_type, sampler, writer, storage, size, datatype
):
if _os_is_windows:
# Compiling on Windows requires "cl" compiler to be installed.
# <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
# Our Windows CI jobs do not have "cl", so skip this test.
pytest.skip("This test does not support Windows.")
if rb_type is not ReplayBuffer:
pytest.skip(
"Only replay buffer of type 'ReplayBuffer' is currently supported."
)
if sampler is not RandomSampler:
pytest.skip("Only sampler of type 'RandomSampler' is currently supported.")
if storage is not LazyTensorStorage:
pytest.skip(
"Only storage of type 'LazyTensorStorage' is currently supported."
)
if writer is not RoundRobinWriter:
pytest.skip(
"Only writer of type 'RoundRobinWriter' is currently supported."
)
if datatype == "tensordict":
pytest.skip("'tensordict' datatype is not currently supported.")

torch.compiler.reset()

storage_size = 10 * size
rb = self._get_rb(
rb_type=rb_type,
sampler=sampler,
writer=writer,
storage=storage,
size=storage_size,
)
data_size = size
data = self._get_data(datatype, size=data_size)

@torch.compile
def extend_and_sample(data):
rb.extend(data)
return rb.sample()

# Number of times to extend the replay buffer
num_extend = 30

# NOTE: The first two calls to 'extend' and 'sample' currently cause
# recompilations, so avoid capturing those for now.
num_extend_before_capture = 2

for _ in range(num_extend_before_capture):
extend_and_sample(data)

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")

for _ in range(num_extend - num_extend_before_capture):
extend_and_sample(data)

assert len(rb) == storage_size
assert len(records) == 0

finally:
torch._logging.set_logs()

def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
pytest.skip(
Expand Down
27 changes: 26 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import torch

from _utils_internal import get_default_devices
from _utils_internal import capture_log_records, get_default_devices
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for

from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
Expand Down Expand Up @@ -380,6 +380,31 @@ def test_rng_decorator(device):
torch.testing.assert_close(s0b, s1b)


# Check that 'capture_log_records' captures records emitted when torch
# recompiles a function.
def test_capture_log_records_recompile():
torch.compiler.reset()

# This function recompiles each time it is called with a different string
# input.
@torch.compile
def str_to_tensor(s):
return bytes(s, "utf8")

str_to_tensor("a")

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")
str_to_tensor("b")

finally:
torch._logging.set_logs()

assert len(records) == 1


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
6 changes: 5 additions & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,14 @@ def _empty(self):

def _rand_given_ndim(self, batch_size):
# a method to return random indices given the storage ndim
if isinstance(self, TensorStorage):
kurtamohler marked this conversation as resolved.
Show resolved Hide resolved
storage_len = self._len
else:
storage_len = len(self)
if self.ndim == 1:
return torch.randint(
0,
len(self),
storage_len,
(batch_size,),
generator=self._rng,
device=getattr(self, "device", None),
Expand Down
Loading