Skip to content

Commit

Permalink
[Feature] Add test for recompiles of ReplayBuffer.extend
Browse files Browse the repository at this point in the history
ghstack-source-id: f50d4ecbc9c2fe0f5334e3cabc0e51cd4f930b80
Pull Request resolved: #2504
  • Loading branch information
kurtamohler committed Oct 18, 2024
1 parent a27514c commit 35ce37d
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 2 deletions.
27 changes: 27 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from __future__ import annotations

import contextlib
import logging
import os

import os.path
import time
import unittest
from functools import wraps

# Get relative file path
Expand Down Expand Up @@ -204,6 +206,31 @@ def f_retry(*args, **kwargs):
return deco_retry


# After calling this function, any log record whose name contains 'record_name'
# and is emitted from the logger that has qualified name 'logger_qname' is
# appended to the 'records' list.
# NOTE: This function is based on testing utilities for 'torch._logging'
def capture_log_records(records, logger_qname, record_name):
assert isinstance(records, list)
logger = logging.getLogger(logger_qname)

class EmitWrapper:
def __init__(self, old_emit):
self.old_emit = old_emit

def __call__(self, record):
nonlocal records
self.old_emit(record)
if record_name in record.name:
records.append(record)

for handler in logger.handlers:
new_emit = EmitWrapper(handler.emit)
contextlib.ExitStack().enter_context(
unittest.mock.patch.object(handler, "emit", new_emit)
)


@pytest.fixture
def dtype_fixture():
dtype = torch.get_default_dtype()
Expand Down
64 changes: 63 additions & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import pytest
import torch

from _utils_internal import CARTPOLE_VERSIONED, get_default_devices, make_tc
from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)

from mocking_classes import CountingEnv
from packaging import version
Expand Down Expand Up @@ -399,6 +404,63 @@ def data_iter():
) if cond else contextlib.nullcontext():
rb.extend(data2)

def test_extend_recompile(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is not ReplayBuffer:
pytest.skip(
"Only replay buffer of type 'ReplayBuffer' is currently supported."
)
if sampler in (PrioritizedSampler,):
pytest.skip(f"Sampler of type '{sampler.__name__}' is not yet supported.")
if storage is not LazyTensorStorage:
pytest.skip(
"Only storage of type 'LazyTensorStorage' is currently supported."
)
if writer is not RoundRobinWriter:
pytest.skip(
"Only writer of type 'RoundRobinWriter' is currently supported."
)

torch.compiler.reset()

storage_size = 10 * size
rb = self._get_rb(
rb_type=rb_type,
sampler=sampler,
writer=writer,
storage=storage,
size=storage_size,
)
data_size = size
data = self._get_data(datatype, size=data_size)

@torch.compile
def extend(data):
rb.extend(data)

# Number of times to extend the replay buffer
num_extend = 30

# NOTE: The first two calls to 'extend' currently cause recompilations,
# so avoid capturing those for now.
num_extend_before_capture = 2

for _ in range(num_extend_before_capture):
extend(data)

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")

for _ in range(num_extend - num_extend_before_capture):
extend(data)

assert len(records) == 0
assert len(rb) == storage_size

finally:
torch._logging.set_logs()

def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
pytest.skip(
Expand Down
27 changes: 26 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import torch

from _utils_internal import get_default_devices
from _utils_internal import capture_log_records, get_default_devices
from torchrl._utils import _rng_decorator, get_binary_env_var, implement_for

from torchrl.envs.libs.gym import gym_backend, GymWrapper, set_gym_backend
Expand Down Expand Up @@ -380,6 +380,31 @@ def test_rng_decorator(device):
torch.testing.assert_close(s0b, s1b)


# Check that 'capture_log_records' captures records emitted when torch
# recompiles a function.
def test_capture_log_records_recompile():
torch.compiler.reset()

# This function recompiles each time it is called with a different string
# input.
@torch.compile
def str_to_tensor(s):
return bytes(s, "utf8")

str_to_tensor("a")

try:
torch._logging.set_logs(recompiles=True)
records = []
capture_log_records(records, "torch._dynamo", "recompiles")
str_to_tensor("b")

finally:
torch._logging.set_logs()

assert len(records) == 1


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 comments on commit 35ce37d

Please sign in to comment.