Description

I'm benchmarking naive FlashAttention in JAX vs. the Pallas version of FA3 vs. the new `dot_product_attention` interface with the `cudnn` backend.
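For context, a minimal sketch of the three call paths being compared (the shapes, dtypes, and the Pallas import path are assumptions, not taken from the original script):

```python
# Sketch of the three attention call paths under comparison; shapes, dtypes,
# and the Pallas import path are assumptions, not the original script.
import jax
import jax.numpy as jnp

B, T, N, H = 4, 4096, 16, 128  # assumed batch, seq len, heads, head dim
q, k, v = (jax.random.normal(jax.random.PRNGKey(i), (B, T, N, H), jnp.bfloat16)
           for i in range(3))

# 1. Naive attention, left to XLA to fuse.
def naive_attention(q, k, v):
    logits = jnp.einsum("btnh,bsnh->bnts", q, k) * (H ** -0.5)
    return jnp.einsum("bnts,bsnh->btnh", jax.nn.softmax(logits, axis=-1), v)

# 2. Pallas FlashAttention kernel (the exact FA3 kernel location varies by
#    JAX version; this import is an assumption).
from jax.experimental.pallas.ops.gpu.attention import mha
out_pallas = mha(q, k, v, segment_ids=None)

# 3. The fused interface, explicitly requesting the cuDNN backend.
out_cudnn = jax.nn.dot_product_attention(q, k, v, implementation="cudnn")
```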
JAX/XLA's performance: [plot]
Torch's performance: [plot]
Why the discrepancy? I'd have expected performance to touch 550-600 TFLOP/s. I'm using a few XLA flags, as specified in the script below, but is there anything I'm missing? Or is this about the maximum XLA can deliver on H100s?
Steps to reproduce
Recreate the environment using `uv`. I'm assuming the drivers are installed; if not, you can use the `pytorch/pytorch:2.4.0-cuda12.4.1-cudnn8-runtime` image on the GPU and run a preliminary `apt-get update` and `apt-get upgrade` to set everything up.
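**JAX script**

A minimal sketch of such a harness, assuming bf16 inputs of shape `[B, T, N, H]`; the XLA flag, shapes, and step count below are placeholders rather than the issue's actual values:

```python
# Hedged sketch of a JAX benchmark harness; the flag, shapes, and step count
# are assumptions, not the original script's settings.
import os
os.environ.setdefault("XLA_FLAGS",
                      "--xla_gpu_enable_latency_hiding_scheduler=true")  # placeholder flag

import time
import jax
import jax.numpy as jnp

B, T, N, H = 4, 4096, 16, 128  # assumed batch, seq len, heads, head dim
q, k, v = (jax.random.normal(jax.random.PRNGKey(i), (B, T, N, H), jnp.bfloat16)
           for i in range(3))

attn = jax.jit(lambda q, k, v: jax.nn.dot_product_attention(
    q, k, v, implementation="cudnn"))
attn(q, k, v).block_until_ready()  # compile and warm up outside the timed region

steps = 100
t0 = time.perf_counter()
for _ in range(steps):
    out = attn(q, k, v)
out.block_until_ready()
ms = (time.perf_counter() - t0) / steps * 1e3

# Forward attention does two matmuls: 4 * B * N * T^2 * H FLOPs in total.
print(f"{ms:.3f} ms/step, {4 * B * N * T**2 * H / (ms * 1e-3) / 1e12:.1f} TFLOP/s")
```

**PyTorch Benchmark script**

And a torch counterpart timed with CUDA events (again a sketch with assumed shapes; note SDPA takes `[B, H, T, C]`):

```python
# Hedged sketch of a PyTorch SDPA benchmark; shapes and step count are assumptions.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, num_heads, T, C = 4, 16, 4096, 128  # assumed shapes, [B, H, T, C] layout
q, k, v = (torch.randn(B, num_heads, T, C, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
steps = 100
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # pin the FlashAttention backend
    F.scaled_dot_product_attention(q, k, v)    # warm up
    torch.cuda.synchronize()
    start.record()
    for _ in range(steps):
        out = F.scaled_dot_product_attention(q, k, v)
    end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / steps
print(f"{ms:.3f} ms/step, "
      f"{4 * B * num_heads * T**2 * C / (ms * 1e-3) / 1e12:.1f} TFLOP/s")
```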
Update: I changed the torch script to use `FlopCounterMode`. The results are now more realistic/accurate, but JAX still lags behind despite being explicitly forced to use cuDNN.
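The counting pattern presumably looks something like this (a sketch with assumed shapes, not the updated script itself):

```python
# Sketch: measuring FLOPs with torch's flop counter instead of a hand-derived formula.
import torch
import torch.nn.functional as F
from torch.utils.flop_counter import FlopCounterMode

q = k = v = torch.randn(4, 16, 4096, 128, device="cuda", dtype=torch.bfloat16)
with FlopCounterMode(display=False) as flop_counter:
    F.scaled_dot_product_attention(q, k, v)
print(f"{flop_counter.get_total_flops() / 1e12:.2f} TFLOPs per forward pass")
```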
I think your torch script might not work as expected, since the input format of `torch.nn.functional.scaled_dot_product_attention` is `[B, H, T, C]` (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
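For illustration, if the tensors start out as `[B, T, H, C]` (a hypothetical layout for the JAX side), the torch call would need a transpose on the way in and out:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)  # [B, T, H, C]
q = k = v = x.transpose(1, 2)                  # -> [B, H, T, C], what SDPA expects
out = F.scaled_dot_product_attention(q, k, v)  # [B, H, T, C]
out = out.transpose(1, 2)                      # back to [B, T, H, C] if needed
```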
This is a complicated benchmarking setup, with many things that can potentially go wrong. Can you simplify this to just measuring milliseconds, and also add a correctness test (that PyTorch and JAX give the same output for the same input)?
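A parity check along these lines might look like this (a sketch on small float32 inputs on CPU; the tolerances are a judgment call):

```python
# Sketch of a JAX-vs-PyTorch correctness check on small float32 inputs.
import numpy as np
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F

B, T, N, H = 2, 128, 4, 64
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((B, T, N, H)).astype(np.float32) for _ in range(3))

jax_out = jax.nn.dot_product_attention(jnp.asarray(q), jnp.asarray(k), jnp.asarray(v))

# torch wants [B, H, T, C]: transpose in, run SDPA, transpose back.
torch_out = F.scaled_dot_product_attention(
    *(torch.from_numpy(x).transpose(1, 2) for x in (q, k, v))
).transpose(1, 2)

np.testing.assert_allclose(np.asarray(jax_out), torch_out.numpy(),
                           atol=1e-4, rtol=1e-4)
```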
@sbodenstein I have updated both scripts to report times as well. However, I opted to skip correctness tests, because making the runs reproducible requires sacrificing the very performance I'm trying to measure.
The variance between runs is very low, however, and we can average over multiple steps (`sx`), so this should be a non-issue.