Consider using torch.compile(model, fullgraph=True, mode="reduce-overhead") #6
Comments
For performance reasons, you might also want to avoid synchronising after every loop iteration, and instead do it only every 10 or 20 iterations and average out the result. That being said, I understand this would affect QoL for the script, so fair enough.
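For reference, a minimal sketch of that measurement pattern (the model, optimizer and step count below are placeholders for illustration, not the repo's actual training script):

```python
import time
import torch
import torch.nn as nn

# Sketch: time a CUDA training loop while synchronizing only every N steps
# and averaging, instead of calling torch.cuda.synchronize() each iteration.
# The tiny model and random data are placeholders, not the repo's GPT model.
device = "cuda"
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

sync_every = 20
num_steps = 100
t0 = time.time()
for step in range(1, num_steps + 1):
    x = torch.randn(16, 1024, device=device)
    loss = model(x).square().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    if step % sync_every == 0:
        torch.cuda.synchronize()                  # one sync per block of steps
        dt = (time.time() - t0) / sync_every
        print(f"avg step time over last {sync_every} steps: {dt * 1000:.1f} ms")
        t0 = time.time()
```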
On RTX 4090 / PyTorch nightly this reduces the throughput slightly (from 130k tok/s to 127k tok/s, or equivalently from 4030ms dt to 4116ms dt; using B=16 to make sure training fits into 24 GB VRAM). This is specifically attributed to …
Perhaps some tweaks are needed to make them run. You can see whether they were enabled or not by running your program with …
Yes, the logs indicate that mode=reduce-overhead uses cuda graphs and by default they are not used. I assume there are some restrictions on kernel compilation/fusion when cuda graphs are enabled and these outweigh the CPU overhead savings in this case, as an individual step is fairly expensive anyway.
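For anyone reproducing this, a rough sketch of how to surface those diagnostics (the artifacts shown are the documented graph-break/recompile ones; the exact wording of cudagraph-related messages differs across PyTorch versions):

```python
import torch
import torch._logging

# Sketch (not from the repo): enable torch.compile diagnostics so the logs
# show whether graph breaks happened and whether recompilation is triggered.
# Roughly equivalent to running the script with TORCH_LOGS="graph_breaks,recompiles".
# Cudagraph-related messages show up in the inductor output when
# mode="reduce-overhead" is active; the exact wording varies across nightlies.
torch._logging.set_logs(graph_breaks=True, recompiles=True)

model = torch.nn.Linear(1024, 1024).cuda()        # placeholder model
compiled = torch.compile(model, fullgraph=True, mode="reduce-overhead")

x = torch.randn(8, 1024, device="cuda")
for _ in range(3):                                # warm-up runs trigger compilation
    compiled(x)
torch.cuda.synchronize()
```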
Within PyTorch there are no heuristics on whether to use cudagraphs or not. If …
Sure - my point is that whatever else reduce-overhead changes in the compilation process, it’s more detrimental to overall performance on this workload on the 4090 than cuda graphs are beneficial.
Someone has to try this out on an A100; it probably boosts the performance quite a lot. There are also other flags that are worth trying.
Tried on an H100. Goes down from ~277k tok/sec to ~269k tok/sec on nightly 2.5.0.dev20240616+cu124 🤷‍♂️
A few points:
If you find the culprit of why it didn't run in the first place, feel free to tag @eellison in that PR. He's the maintainer of CUDA graphs within PyTorch.
@marib00 very nice! Can you try "max-autotune" as well maybe? It is documented in https://pytorch.org/docs/stable/generated/torch.compile.html and might be even faster. Anyway, someone should create a PR.
@JohannesVod I did try "max-autotune" already and no change; it was compiling (i.e. autotuning) forever though.
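For anyone else repeating that experiment, a minimal sketch (the model is a placeholder MLP, and resetting the dynamo cache between modes is my own suggestion, not something prescribed in this thread):

```python
import torch
import torch._dynamo

# Sketch: when trying several compile modes in the same process, reset the
# dynamo cache between runs so the next mode actually recompiles instead of
# reusing the previous artifacts. Placeholder model, not the repo's GPT.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
).cuda()
x = torch.randn(16, 1024, device="cuda")

for mode in ("reduce-overhead", "max-autotune"):
    torch._dynamo.reset()
    compiled = torch.compile(model, fullgraph=True, mode=mode)
    compiled(x)            # first call compiles; "max-autotune" can take a long time here
    torch.cuda.synchronize()
```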
`fullgraph=True` will make sure that there are no graph breaks (this may already be the case). `mode="reduce-overhead"` will use CUDA graphs if possible. See in [these benchmarks] that going from regular `torch.compile` to `reduce-overhead` gives a good 70-100% speed-up on top of regular `torch.compile`.
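A minimal sketch of what the suggestion amounts to in code (the model below is a placeholder MLP, not this repo's GPT):

```python
import torch
import torch.nn as nn

# Sketch of the suggestion: fullgraph=True turns any graph break into an
# error instead of a silent split, and mode="reduce-overhead" lets inductor
# wrap the compiled graph in CUDA graphs when it can.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
model = torch.compile(model, fullgraph=True, mode="reduce-overhead")

x = torch.randn(16, 1024, device="cuda")
for _ in range(3):          # a few warm-up calls let the CUDA graph be recorded and replayed
    y = model(x)
torch.cuda.synchronize()
print(y.shape)
```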