
[Feature Request] Support compiling ReplayBuffer.extend/sample without recompile #2501

Open
kurtamohler opened this issue Oct 18, 2024 · 5 comments
Labels: enhancement (New feature or request)

kurtamohler commented Oct 18, 2024

Motivation

Compiling a function that calls ReplayBuffer.extend and ReplayBuffer.sample back-to-back, and then calling it multiple times, causes the function to be recompiled on each call.

import torch
import torchrl

# Log every recompilation and the guard failures that triggered it.
torch._logging.set_logs(recompiles=True)

rb = torchrl.data.ReplayBuffer(
    storage=torchrl.data.LazyTensorStorage(1000)
)

@torch.compile
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample(2)

# Each iteration extends the buffer with a different batch size, so the
# storage length changes on every call.
for idx in range(15):
    print('---------------------')
    print(f'iteration: {idx}')
    print(f'len: {len(rb.storage)}')
    data = torch.randn(idx + 1, 1)
    extend_and_sample(data)

Running the above script gives the following output, showing that each of the first 9 calls causes recompilations. At that point it hits the cache limit, so later calls are no longer compiled and simply run the eager function (per the pytorch docs: https://pytorch.org/docs/stable/generated/torch.compile.html).

---------------------
iteration: 0
len: 0
/home/endoplasm/miniconda/envs/torchrl-0/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:733: UserWarning: Graph break due to unsupported builtin None.SemLock.acquire. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
  torch._dynamo.utils.warn_once(msg)
V1021 16:37:15.101000 599972 site-packages/torch/_dynamo/guards.py:2842] [11/1] [__recompiles] Recompiling function _lazy_call_fn in /home/endoplasm/develop/torchrl-0/torchrl/_utils.py:389
V1021 16:37:15.101000 599972 site-packages/torch/_dynamo/guards.py:2842] [11/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:15.101000 599972 site-packages/torch/_dynamo/guards.py:2842] [11/1] [__recompiles]     - 11/0: L['self'].func_name == 'torchrl.data.replay_buffers.storages.TensorStorage.set'
V1021 16:37:15.108000 599972 site-packages/torch/_dynamo/guards.py:2842] [12/1] [__recompiles] Recompiling function torch_dynamo_resume_in__lazy_call_fn_at_394 in /home/endoplasm/develop/torchrl-0/torchrl/_utils.py:394
V1021 16:37:15.108000 599972 site-packages/torch/_dynamo/guards.py:2842] [12/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:15.108000 599972 site-packages/torch/_dynamo/guards.py:2842] [12/1] [__recompiles]     - 12/0: len(L['args']) == 3                                         
---------------------
iteration: 1
len: 1
V1021 16:37:16.300000 599972 site-packages/torch/_dynamo/guards.py:2842] [0/1] [__recompiles] Recompiling function extend_and_sample in /home/endoplasm/tmp/rb_compiled.py:10
V1021 16:37:16.300000 599972 site-packages/torch/_dynamo/guards.py:2842] [0/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.300000 599972 site-packages/torch/_dynamo/guards.py:2842] [0/1] [__recompiles]     - 0/0: tensor 'L['data']' size mismatch at index 0. expected 1, actual 2
V1021 16:37:16.327000 599972 site-packages/torch/_dynamo/guards.py:2842] [1/1] [__recompiles] Recompiling function extend in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/replay_buffers.py:610
V1021 16:37:16.327000 599972 site-packages/torch/_dynamo/guards.py:2842] [1/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.327000 599972 site-packages/torch/_dynamo/guards.py:2842] [1/1] [__recompiles]     - 1/0: tensor 'L['data']' size mismatch at index 0. expected 1, actual 2
V1021 16:37:16.340000 599972 site-packages/torch/_dynamo/guards.py:2842] [13/1] [__recompiles] Recompiling function set in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:686
V1021 16:37:16.340000 599972 site-packages/torch/_dynamo/guards.py:2842] [13/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.340000 599972 site-packages/torch/_dynamo/guards.py:2842] [13/1] [__recompiles]     - 13/0: tensor 'L['data']' size mismatch at index 0. expected 1, actual 2
V1021 16:37:16.357000 599972 site-packages/torch/_dynamo/guards.py:2842] [18/1] [__recompiles] Recompiling function torch_dynamo_resume_in_set_at_713 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:713
V1021 16:37:16.357000 599972 site-packages/torch/_dynamo/guards.py:2842] [18/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.357000 599972 site-packages/torch/_dynamo/guards.py:2842] [18/1] [__recompiles]     - 18/0: tensor 'L['data']' size mismatch at index 0. expected 1, actual 2
V1021 16:37:16.406000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/1] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.406000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.406000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/1] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
V1021 16:37:16.423000 599972 site-packages/torch/_dynamo/guards.py:2842] [31/1] [__recompiles] Recompiling function <lambda> in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:802
V1021 16:37:16.423000 599972 site-packages/torch/_dynamo/guards.py:2842] [31/1] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.423000 599972 site-packages/torch/_dynamo/guards.py:2842] [31/1] [__recompiles]     - 31/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 3
---------------------
iteration: 2
len: 3
V1021 16:37:16.452000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/2] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.452000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/2] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.452000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/2] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.452000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/2] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 3
len: 6
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.473000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/3] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 4
len: 10
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.490000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/4] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 5
len: 15
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/4: L['___stack2'] == 15  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.509000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/5] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 6
len: 21
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/5: L['___stack2'] == 21  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/4: L['___stack2'] == 15  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.527000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/6] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 7
len: 28
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/6: L['___stack2'] == 28  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/5: L['___stack2'] == 21  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/4: L['___stack2'] == 15  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.545000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/7] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
---------------------
iteration: 8
len: 36
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles] Recompiling function torch_dynamo_resume_in__rand_given_ndim_at_152 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     triggered by the following guard failure(s):
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/7: L['___stack2'] == 36  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/6: L['___stack2'] == 28  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/5: L['___stack2'] == 21  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/4: L['___stack2'] == 15  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/3: L['___stack2'] == 10  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/2: L['___stack2'] == 6  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/1: L['___stack2'] == 3  # return torch.randint(  # develop/torchrl-0/torchrl/data/replay_buffers/storages.py:150 in torch_dynamo_resume_in__rand_given_ndim_at_152 (_ops.py:723 in __call__)
V1021 16:37:16.564000 599972 site-packages/torch/_dynamo/guards.py:2842] [28/8] [__recompiles]     - 28/0: L['___stack2'] == 1                                         
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8] torch._dynamo hit config.cache_size_limit (8)
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8]    function: 'torch_dynamo_resume_in__rand_given_ndim_at_152' (/home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:152)
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8]    last reason: 28/0: L['___stack2'] == 1                                         
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1021 16:37:16.565000 599972 site-packages/torch/_dynamo/convert_frame.py:876] [28/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
---------------------
iteration: 9
len: 45
---------------------
iteration: 10
len: 55
---------------------
iteration: 11
len: 66
---------------------
iteration: 12
len: 78
---------------------
iteration: 13
len: 91
---------------------
iteration: 14
len: 105
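
(As an aside, the limit mentioned in the warning above is an ordinary Dynamo config knob. A minimal sketch of raising it, assuming one wants the later iterations to keep compiling instead of permanently falling back to eager:)

import torch._dynamo

# Raise Dynamo's per-code-object recompile limit (default 8, per the
# cache_size_limit warning in the log above).
torch._dynamo.config.cache_size_limit = 64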

Solution

Compiling and calling ReplayBuffer.extend and ReplayBuffer.sample back-to-back should not cause recompilation.

We need to support the base case of torchrl.data.ReplayBuffer(storage=torchrl.data.LazyTensorStorage(1000)), as well as cases where the storage is a LazyMemmapStorage and where the sampler is a SliceSampler (sketched below).
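
A rough sketch of the configurations named above (the SliceSampler argument is illustrative, not prescribed by this issue):

import torchrl

buffers = [
    # base case
    torchrl.data.ReplayBuffer(storage=torchrl.data.LazyTensorStorage(1000)),
    # memory-mapped storage
    torchrl.data.ReplayBuffer(storage=torchrl.data.LazyMemmapStorage(1000)),
    # slice sampling
    torchrl.data.ReplayBuffer(
        storage=torchrl.data.LazyTensorStorage(1000),
        sampler=torchrl.data.SliceSampler(num_slices=2),
    ),
]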

Checklist

  • I have checked that there is no similar issue in the repo (required)
kurtamohler (Collaborator, Author) commented:

Related PR: #2426


kurtamohler commented Oct 21, 2024

@vmoens, you mentioned offline that you've seen recompiles every time you tried to call extend on a ReplayBuffer(storage=LazyTensorStorage(1000)). Could you share with me a case where that happens? The test I added in #2504 only has recompiles on the first two calls.

As for the recompiles that I have seen, if I set num_extend_before_capture = 0 in the test, it no longer ignores those recompiles, and I get these recompile records:

$ python test/test_rb.py -k test_extend_recompile[100-ReplayBuffer-LazyTensorStorage-tensor-RoundRobinWriter-RandomSampler]
...
[11/1] [__recompiles] Recompiling function _lazy_call_fn in /home/endoplasm/develop/torchrl-0/torchrl/_utils.py:389
[11/1] [__recompiles]     triggered by the following guard failure(s):
[11/1] [__recompiles]     - 11/0: L['self'].func_name == 'torchrl.data.replay_buffers.storages.TensorStorage.set'
[12/1] [__recompiles] Recompiling function torch_dynamo_resume_in__lazy_call_fn_at_394 in /home/endoplasm/develop/torchrl-0/torchrl/_utils.py:394
[12/1] [__recompiles]     triggered by the following guard failure(s):
[12/1] [__recompiles]     - 12/0: len(L['args']) == 3                                         
[18/1] [__recompiles] Recompiling function torch_dynamo_resume_in_set_at_713 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:713
[18/1] [__recompiles]     triggered by the following guard failure(s):
[18/1] [__recompiles]     - 18/0: ___check_obj_id(L['self'].initialized, 8907584) 

So there are three recompiles to look into.

The first and second are caused by the use of the implement_for decorator here. This is the first time the _lazy_call_fn function within the implement_for.__call__ method has been called with the string self.func_name = "torchrl.data.replay_buffers.storages.TensorStorage.set", and since torch.compile has to recompile whenever a function is given a different string argument, I'm not sure there is much we can or should do about this one. It might be possible to change how implement_for works to avoid this recompile, but that may not be worth the trouble, since it only happens once.
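
(To illustrate the string-specialization behavior in isolation, a minimal standalone sketch, not TorchRL code: Dynamo treats string arguments as constants and guards on their value, so each new string compiles a fresh graph.)

import torch

torch._logging.set_logs(recompiles=True)

@torch.compile
def tagged_add(tag, x):
    # `tag` is baked into the graph as a constant and guarded on.
    return x + len(tag)

tagged_add("TensorStorage.set", torch.ones(2))
tagged_add("TensorStorage.get", torch.ones(2))  # guard failure -> recompile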

The third recompile is caused by the fact that this branch of TensorStorage.set() is only visited on the first call. I'd guess this one is also not worth the trouble to fix.
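
(The same one-time-branch pattern can be reproduced in isolation, again a standalone sketch rather than TorchRL code: the first call traces the initialization path and guards on the flag, so flipping it invalidates the guard exactly once.)

import torch

class ToyStorage:
    def __init__(self):
        self.initialized = False
        self.buf = None

    def set(self, x):
        if not self.initialized:
            # Only visited on the first call, like TensorStorage.set's
            # initialization branch.
            self.buf = torch.zeros_like(x)
            self.initialized = True
        self.buf.copy_(x)
        return self.buf

storage = ToyStorage()
compiled_set = torch.compile(storage.set)
compiled_set(torch.ones(3))  # traces the init branch, guards on `initialized`
compiled_set(torch.ones(3))  # flag flipped -> one ___check_obj_id recompile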

Let me know what you think.

EDIT: Never mind, I realized that if I also compile sample and call it between each extend call, I see recompiles for more than just the first two iterations, apparently due to the changing storage size. I'll update the issue description to include that. Still, if there are any other relevant cases, let me know.
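
(One untested avenue for the size-driven recompiles, sketched here as an assumption rather than anything this issue commits to, is asking Dynamo for dynamic shapes up front; whether it helps depends on where the sizes flow, e.g. into torch.randint's bound.)

import torch

# Hypothetical variant of the script above: mark shapes dynamic from the
# first compile so the batch size and storage length become symbolic
# instead of being specialized per call.
@torch.compile(dynamic=True)
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample(2)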


vmoens commented Oct 22, 2024

> EDIT: Never mind, I realized that if I also compile sample and call it between each extend call, I see recompiles for more than just the first two iterations, apparently due to the changing storage size. I'll update the issue description to include that. Still, if there are any other relevant cases, let me know.

Yes, I think this is the use case where I observed the many recompiles.
I guess one low-hanging fruit would be to disable multiprocessed replay buffers when they're compiled (the length and cursor trackers are mp.Value objects that are seen by each process, and I guess compile won't like that).
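
(For illustration, a standalone sketch of that pattern, not TorchRL code: reading an mp.Value acquires a SemLock, a C-extension call Dynamo cannot trace, which produces exactly the SemLock.acquire graph-break warning seen in the log above.)

import multiprocessing as mp

import torch

cursor = mp.Value("i", 0)

@torch.compile
def advance(x):
    # Dynamo graph-breaks around the SemLock acquire; the function still
    # runs, but split into pieces around the untraceable call.
    with cursor.get_lock():
        cursor.value += x.numel()
    return x * 2

advance(torch.ones(4))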


kurtamohler commented Oct 25, 2024

I decided to benchmark a compiled back-to-back extend-and-sample call in #2514. For the case in the benchmark, the compiled function is about 5x slower than the eager function on my machine. So #2504 didn't really help much: although it avoids recompiles, the resulting performance is so poor that compiling the buffer isn't useful.

I stumbled onto the compiler manual here and am looking through it to find out what we can do to improve the performance. I didn't know about TORCH_TRACE before; it seems to be a really nice way to see all the compiler issues. The graph breaks from the mp.Value accesses do seem to be a significant problem, so I'll look into doing something like #2426 to avoid the multiprocessing APIs when we need to compile.
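
(On tooling: besides TORCH_TRACE, Dynamo's explain() utility, standard PyTorch tooling rather than anything specific to this issue, can enumerate graph breaks such as the mp.Value ones. A sketch, assuming the undecorated function and the rb from the original script:)

import torch
import torch._dynamo

def extend_and_sample_eager(data):
    rb.extend(data)
    return rb.sample(2)

# explain() runs the function once under Dynamo and reports each graph
# break along with its reason (e.g. the SemLock.acquire break above).
explanation = torch._dynamo.explain(extend_and_sample_eager)(torch.randn(4, 1))
print(explanation.graph_break_count)
print(explanation.break_reasons)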

kurtamohler (Collaborator, Author) commented:

I was able to make some progress on this, and I have a branch where a compiled replay buffer gets a significant speedup over eager in some cases and only a slight slowdown in others. The branch is pretty messy at the moment, so I'll push a PR after I clean it up.
