[Feature Request] Support compiling ReplayBuffer.extend/sample without recompile #2501
Comments
Related PR: #2426
@vmoens, you mentioned offline that you've seen recompiles every time you tried to call […]. As for the recompiles that I have seen, if I set […]. So there are three recompiles to look into. The first and second ones are caused by the use of the […]. The third recompile is caused by the fact that this branch of […]. Let me know what you think.

EDIT: Nevermind, I realized that if I also compile […].
Yes, I think this is the use case where I observed the many recompiles.
I decided to benchmark a compiled back-to-back […]. I stumbled onto the manual for the compiler here, and I'm looking through it to find out what we can do to improve the performance. I didn't know about TORCH_TRACE before, and that seems to be a really nice way to see all the compiler issues. The graph breaks from the […].
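For reference, a minimal sketch of how I understand these diagnostics can be enabled; the trace directory path is just a placeholder, and the assumption is that TORCH_TRACE only needs to be set before the compiled function first runs.

```python
import os

# Assumption: TORCH_TRACE only needs to be set before compilation runs; the
# structured logs it writes can then be inspected with the tlparse tool.
os.environ["TORCH_TRACE"] = "/tmp/rb_compile_trace"  # placeholder path

import torch

# Print the reason for every recompile and graph break as they happen.
torch._logging.set_logs(recompiles=True, graph_breaks=True)
```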
I was able to make some progress on this, and I have a branch where a compiled replay buffer is getting a significant speedup over eager in some cases and only a slight slowdown in other cases. The branch is pretty messy at the moment, so after I fix it up, I'll push a PR.
Motivation
Compiling a back-to-back call to ReplayBuffer.extend and ReplayBuffer.sample, and then calling it multiple times, causes the function to be recompiled each time. Running a script along the lines of the sketch below shows that the first 9 calls cause recompilations. Then it hits the cache limit, so the calls after that don't get compiled anymore, and it's just running the eager function at that point (per the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.compile.html).
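A minimal sketch of this kind of reproduction; the storage size, batch size, and data shape here are illustrative placeholders rather than the exact values from the original script.

```python
import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer

# Illustrative sizes only.
rb = ReplayBuffer(storage=LazyTensorStorage(1000), batch_size=32)

@torch.compile
def extend_and_sample(data):
    # Back-to-back extend + sample compiled as a single function.
    rb.extend(data)
    return rb.sample()

# Calling the compiled function repeatedly is where the recompiles show up;
# running with TORCH_LOGS="recompiles" prints the reason for each one.
for _ in range(12):
    extend_and_sample(torch.randn(64, 4))
```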
Solution
Compiling and calling ReplayBuffer.extend and ReplayBuffer.sample back-to-back should not cause recompilation. We need to support the base case of torchrl.data.ReplayBuffer(storage=torchrl.data.LazyTensorStorage(1000)), as well as cases where the storage is a LazyMemmapStorage and where the sampler is a SliceSampler.
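A rough sketch of the configurations that would need to compile cleanly; the constructor arguments are illustrative, not prescriptive.

```python
from torchrl.data import (
    LazyMemmapStorage,
    LazyTensorStorage,
    ReplayBuffer,
    SliceSampler,
)

# Base case: lazy in-memory storage with the default sampler.
rb_base = ReplayBuffer(storage=LazyTensorStorage(1000))

# Memory-mapped storage variant.
rb_memmap = ReplayBuffer(storage=LazyMemmapStorage(1000))

# Slice-sampling variant; num_slices is an illustrative value.
rb_slices = ReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SliceSampler(num_slices=4),
)
```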
Checklist