
[GPU] FlashAttention performance lags behind PyTorch #24934

Open
neel04 opened this issue Nov 17, 2024 · 4 comments
Labels
bug Something isn't working

Comments


neel04 commented Nov 17, 2024

Description

I'm benchmarking naive attention in JAX vs. the Pallas version of FA3 vs. the new dot_product_attention interface with the cudnn backend.

  • JAX/XLA's performance:

[screenshot: JAX/XLA benchmark results]

  • Torch's performance:

[screenshot: PyTorch benchmark results]

Why the discrepancy? I'd have expected performance to touch 550-600 TFLOP/s. I'm using a few XLA flags, as specified in the script below, but is there anything I'm missing? Or is this about the maximum XLA can deliver on H100s?
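As a sanity check, here is a minimal sketch (not part of the benchmark scripts below) of how the cuDNN fused-attention path can be verified from the optimized HLO; the exact custom-call target name may differ across jaxlib/cuDNN versions, so treat the substring check as a heuristic.

import jax
import jax.numpy as jnp

def attn(q, k, v):
    return jax.nn.dot_product_attention(q, k, v, implementation='cudnn')

B, T, H, C = 8, 1024, 16, 64
q = k = v = jnp.zeros((B, T, H, C), dtype=jnp.float16)

# Compile and search the optimized HLO for a cuDNN fused-attention custom call
# (typically named something like "__cudnn$fmha..."); if nothing matches,
# XLA has silently fallen back to an unfused implementation.
hlo = jax.jit(attn).lower(q, k, v).compile().as_text()
print('cudnn fMHA custom call present:', 'cudnn$fmha' in hlo)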

Steps to reproduce

  1. Recreate the environment using uv. I'm assuming the drivers are installed; if not, you can use the pytorch/pytorch:2.4.0-cuda12.4.1-cudnn8-runtime image on the GPU and run the preliminary apt-get update and apt-get upgrade to set everything up.
pip3 install uv
uv venv 'main_env' --python 3.11
source main_env/bin/activate

uv pip install -U "jax[cuda12]"
uv pip install -q einops tqdm jaxtyping optax optuna equinox rich
uv pip install -q nvitop pdbpp tabulate
  2. Run either script
**JAX script**

import os, sys

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'

os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_enable_cudnn_fmha=true'
)

import math
import time
from tabulate import tabulate
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_map
from jax.experimental.pallas.ops.gpu.attention import mha as pallas_mha
from functools import partial

class Timer(object):
    def __init__(self, into=None):
        self.into = into

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        if self.into is not None:
            self.into.append(time.time() - self.start)

    def elapsed(self):
        return time.time() - self.start

def cartesian(*lists):
    if lists:
        xs = lists[0]
        for x in xs:
            for rest in cartesian(*lists[1:]):
                yield (x,) + rest
    else:
        yield ()

def cross_attn_flops(B,T,TK,H,C):
    HC = H*C
    # T = TK for self-attention
    flops_fwd = (
        2 * B*H*T*TK*C + # S = Q@K
        3 * B*H*T*TK +  # P=softmax(S)
        2 * B*T*TK*H*C # O = P@V
    )
    return flops_fwd

def attn_flops(B,T,H,C):
    return cross_attn_flops(B,T,T,H,C)

dtype = jnp.float16
print(f'Using dtype: {dtype}')

def convert(xs):
    return tree_map(lambda x: x.astype(dtype), xs)
    
def ref_fwd(q,k,v):
    # reference implementation
    [n, l, h, d] = q.shape
    [n, lk, hk, d] = k.shape
    softmax_scale = 1 / math.sqrt(d)
    S = jnp.einsum('nlhd,nLhd->nhlL',q,k)
    P = jax.nn.softmax(S*softmax_scale, axis=-1)
    o = jnp.einsum('nhlL,nLhd->nlhd',P,v)
    return o.astype(q.dtype)

def jax_dpa_fwd(q, k, v):
    output = jax.nn.dot_product_attention(
        query=q,
        key=k,
        value=v,
        implementation='cudnn'
    )

    return output

# ----

Bx = [8,  16] # batch size
Tx = [1024, 2048] # seqlen
Hx = [16, 32] # number of heads
Cx = [64, 128] # head dim
sx = [2] # steps

def bench_attn(mha):
    @jax.jit
    def bench(q, k, v, steps: jax.Array):
        for i in steps:
            out = mha(q, k, v)
        return out

    times = []
    table = {}

    for B,T,H,C,s in cartesian(Bx,Tx,Hx,Cx,sx):
        q = jax.random.normal(jax.random.PRNGKey(0), [B, T, H, C], dtype=dtype)
        k = jax.random.normal(jax.random.PRNGKey(1), [B, T, H, C], dtype=dtype)
        v = jax.random.normal(jax.random.PRNGKey(2), [B, T, H, C], dtype=dtype)
        steps = jnp.arange(s)

        # Warmup
        for _ in range(2):
            bench(q, k, v, steps).block_until_ready()

        out = []
        for _ in range(8):
            with Timer(out):
                bench(q, k, v, steps).block_until_ready()

        t = sum(out[2:])/len(out[2:])

        flop = attn_flops(B,T,H,C) * s
        tf = flop / t / 1e12 / s
        print(f'flops ({B} {T} {H} {C} / {s}): {tf}T', out)
        table[(B,T,H,C,s)] = tf
        times.append(t)
    
    return table, times

naive_flops, custom_time = bench_attn(ref_fwd)
pallas_flops, pallas_time = bench_attn(partial(pallas_mha, segment_ids=None))
jax_dpa_flops, dpa_time = bench_attn(jax_dpa_fwd)

table = []

for idx, (B,T,H,C,s) in enumerate(cartesian(Bx,Tx,Hx,Cx,sx)):
    n_flops, n_time = naive_flops[(B,T,H,C,s)], custom_time[idx]
    p_flops, p_time = pallas_flops[(B,T,H,C,s)], pallas_time[idx]
    j_flops, j_time = jax_dpa_flops[(B,T,H,C,s)], dpa_time[idx]
    
    table.append((B,T,H,C, n_flops, p_flops, j_flops, attn_flops(B,T,H,C), n_time, p_time, j_time))

print(tabulate(table, headers=['B','T','H','C','TFlop/s (naive)','TFlop/s (pallas)','TFlop/s (jax_dpa)', 'FLOPs', 'Naive time', 'Pallas time', 'DPA time'], floatfmt='.5f'))

**PyTorch Benchmark script**

import torch
torch._dynamo.config.cache_size_limit = 10000  # Increase cache size to 10,000

import time
from tabulate import tabulate
from typing import Optional
from itertools import product
from torch.utils.flop_counter import FlopCounterMode

class Timer:
    def __init__(self, into=None):
        self.into = into

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        if self.into is not None:
            self.into.append(time.time() - self.start)

class OptimizedMHA(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        B, T, H, C = q.shape
        scale = 1.0 / (C ** 0.5)

        # Compute attention scores
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale  # [B, H, T, T]
        attn = torch.softmax(attn, dim=-1)

        # Apply attention to values
        out = torch.matmul(attn, v)  # [B, H, T, C]
        out = out.transpose(1, 2)  # [B, T, H, C]

        return out

class TorchMHA(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
            dropout_p=0.0,
            is_causal=False
        )

def get_flops(model, sample_inputs):
    """Get FLOPs using FlopCounterMode"""
    with FlopCounterMode(model) as flop_counter:
        _ = model(*sample_inputs)
    return flop_counter.get_total_flops()

def bench_implementation(model, q, k, v, warmup=10, steps=100):
    """Benchmark a specific implementation"""
    # Warmup
    for _ in range(warmup):
        with torch.no_grad():
            _ = model(q, k, v)
    torch.cuda.synchronize()

    # Benchmark timing
    times = []
    for _ in range(steps):
        with torch.no_grad(), Timer(times):
            _ = model(q, k, v)
            torch.cuda.synchronize()

    # Calculate statistics
    times = times[10:]  # Discard first 10 runs
    avg_time = sum(times) / len(times)
    return avg_time

def bench_attention():
    # Configuration
    device = torch.device("cuda")
    dtype = torch.float16

    # Test parameters
    Bx = [8, 16]               # batch size
    Tx = [1024, 2048]            # sequence length
    Hx = [16, 32]          # number of heads
    Cx = [64, 128]     # head dimension
    sx = [4]               # steps to run

    # Initialize models
    custom_model = OptimizedMHA().to(device)
    torch_model = TorchMHA().to(device)

    # Compile both models
    compiled_custom = torch.compile(
        custom_model,
        mode="max-autotune-no-cudagraphs",
        fullgraph=True,
    )
    compiled_torch = torch.compile(
        torch_model,
        mode="max-autotune-no-cudagraphs",
        fullgraph=True,
    )

    results = []

    # Run benchmarks for each configuration
    for B, T, H, C, s in product(Bx, Tx, Hx, Cx, sx):
        # Create input tensors
        q = torch.randn(B, T, H, C, device=device, dtype=dtype)
        k = torch.randn(B, T, H, C, device=device, dtype=dtype)
        v = torch.randn(B, T, H, C, device=device, dtype=dtype)

        q = q.transpose(1, 2)  # [B, H, T, C]
        k = k.transpose(1, 2)  # [B, H, T, C]
        v = v.transpose(1, 2)  # [B, H, T, C]

        # Get FLOPs using FlopCounterMode (on CPU with float32)
        model_cpu = OptimizedMHA()
        q_cpu = q.cpu().float()
        k_cpu = k.cpu().float()
        v_cpu = v.cpu().float()
        flops = get_flops(model_cpu, (q_cpu, k_cpu, v_cpu))

        # Benchmark both implementations
        custom_time = bench_implementation(compiled_custom, q, k, v)
        torch_time = bench_implementation(compiled_torch, q, k, v)

        # Calculate TFLOPs/s for both
        custom_tflops = flops / custom_time / 1e12
        torch_tflops = flops / torch_time / 1e12

        # Calculate speedup
        speedup = custom_time / torch_time  # >1 means torch is faster

        print(f"\nConfig (B={B}, T={T}, H={H}, C={C}):")
        print(f"Custom impl: {custom_tflops:.2f} TFlop/s")
        print(f"Torch impl: {torch_tflops:.2f} TFlop/s")
        print(f"Speedup (Torch vs Custom): {speedup:.2f}x")
        print(f"Measured FLOPs: {flops:,}")

        results.append((
            B, T, H, C,
            round(custom_tflops, 2),
            round(torch_tflops, 2),
            round(speedup, 2),
            flops,
            custom_time,
            torch_time
        ))

    # Print results table
    headers = [
        'Batch', 'SeqLen', 'Heads', 'HeadDim',
        'Naive MHA TFlop/s', 'SDPA TFlop/s',
        'Advantage', 'FLOPs', 'Custom Time', 'SDPA Time'
    ]
    print("\nResults:")
    print(tabulate(results, headers=headers, floatfmt='.5f'))

if __name__ == "__main__":
    bench_attention()


System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.11.10 (main, Oct 16 2024, 04:38:48) [Clang 18.1.8 ]
device info: NVIDIA H100 PCIe-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='nifty-orthodox-whale', release='6.8.0-40-generic', version='#40~22.04.3-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 30 17:30:19 UTC 2', machine='x86_64')


$ nvidia-smi
Sun Nov 17 02:04:44 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.4     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 PCIe               On  | 00000000:00:07.0 Off |                    0 |
| N/A   31C    P0              54W / 350W |    467MiB / 81559MiB |      2%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
neel04 added the bug label Nov 17, 2024

neel04 commented Nov 17, 2024

Update: I changed the torch script to use FlopCounterMode. The results are now more realistic/accurate, but JAX still lags behind despite being explicitly forced to use cuDNN.

cc @kaixih @sbodenstein @dfm

@Rick0827

I think your torch script might not work as expected, since the input format of torch.nn.functional.scaled_dot_product_attention is [B, H, T, C] (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
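For illustration, a minimal sketch of the layout fix (shapes chosen to match the benchmark; adjust as needed):

import torch
import torch.nn.functional as F

B, T, H, C = 8, 1024, 16, 64
q = torch.randn(B, T, H, C, device='cuda', dtype=torch.float16)
k = torch.randn(B, T, H, C, device='cuda', dtype=torch.float16)
v = torch.randn(B, T, H, C, device='cuda', dtype=torch.float16)

# scaled_dot_product_attention puts the head dim in front of the sequence dim,
# so tensors created as [B, T, H, C] need a transpose before the call.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
)
out = out.transpose(1, 2)  # back to [B, T, H, C] to match the JAX layout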

@sbodenstein
Contributor

This is a complicated benchmarking setup, with many things potentially going wrong. Can you simplify this to just measuring milliseconds, and also add a correctness test (checking that PyTorch and JAX give the same output for the same input)?
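Something along these lines would do; this is only a rough sketch, assuming fp16 inputs, the default softmax scaling in both frameworks, and loose fp16 tolerances:

import numpy as np
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F

rng = np.random.default_rng(0)
B, T, H, C = 2, 128, 4, 64
q, k, v = (rng.standard_normal((B, T, H, C), dtype=np.float32) for _ in range(3))

# JAX path: [B, T, H, C] layout.
out_jax = jax.nn.dot_product_attention(
    jnp.asarray(q, jnp.float16),
    jnp.asarray(k, jnp.float16),
    jnp.asarray(v, jnp.float16),
    implementation='cudnn',
)

# PyTorch path: transpose to [B, H, T, C] and back.
def to_torch(x):
    return torch.tensor(x, dtype=torch.float16, device='cuda').transpose(1, 2)

out_torch = F.scaled_dot_product_attention(
    to_torch(q), to_torch(k), to_torch(v)
).transpose(1, 2)

# Compare in float32 with tolerances loose enough for fp16 accumulation differences.
np.testing.assert_allclose(
    np.asarray(out_jax, dtype=np.float32),
    out_torch.float().cpu().numpy(),
    atol=2e-2, rtol=2e-2,
)
print('JAX and PyTorch outputs match')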


neel04 commented Nov 19, 2024

@Rick0827 Thank you for pointing that out.

@sbodenstein I have updated both scripts to report times as well. However, I opted to skip correctness tests, since full reproducibility would require sacrificing performance, which I'd rather not do.

The variance between runs is very low, however, and we can average over multiple steps (sx), so this should be a non-issue.
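For reference, a small sketch of how the spread could be summarized from the timing lists the scripts already collect (hypothetical report_spread helper, not in the scripts above):

import statistics

def report_spread(times):
    # times: per-call wall-clock seconds collected by the Timer context manager
    mean_t = statistics.mean(times)
    std_t = statistics.stdev(times)
    print(f'mean {mean_t * 1e3:.3f} ms, std {std_t * 1e3:.3f} ms '
          f'({100 * std_t / mean_t:.2f}% of mean)')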

On A100:

  • JAX:

[screenshot: JAX A100 benchmark results]

  • Torch:

[screenshot: PyTorch A100 benchmark results]
