
[Question]: Discrepancy in Pre-filling Time and Memory Consumption on Single A100 #84

Open
lepangdan opened this issue Nov 15, 2024 · 3 comments

@lepangdan

Describe the issue

I came across your statement in the paper where you mentioned:

"When serving LLaMA-3-8B on a single A100 machine, the model would keep users waiting for 6 minutes to finish the pre-filling stage given a prompt of 300K tokens, and this number increases to 30 minutes for a prompt of 1M tokens."

However, I am also running on a single A100 (80GB), using Hugging Face's LLaMA implementation in SDPA mode. With a 50K-token context the pre-fill takes around 2.5 seconds, and at 100K tokens I run out of memory (OOM).

Could you clarify why there is such a significant discrepancy between your results and mine? Is there something I might be missing or misunderstanding?

Thanks for your help!

@lepangdan lepangdan added the question Further information is requested label Nov 15, 2024
@iofu728 iofu728 self-assigned this Nov 18, 2024
@iofu728
Contributor

iofu728 commented Nov 18, 2024

Hi @lepangdan,

Thanks for your question!

  1. First, I apologize for the error in the Introduction section of our paper. The sentence should read: "3 minutes to finish the pre-filling stage given a prompt of 300K tokens," not 6 minutes. You can also verify this in Figure 1(b). We will update the arXiv and NeurIPS versions ASAP. Thank you for pointing this out!

  2. Regarding your SDPA TTFT (time to first token) being three times faster than what we measured, I suspect the cause is a missing torch.cuda.synchronize(). You can follow the script in our repository (https://github.com/microsoft/MInference/blob/main/experiments/benchmarks/benchmark_e2e.py) and add the following argument at line 125 (#L125):

    attn_implementation="sdpa", # default is flash_attention_2

    Then run:

    python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000

    The TTFT should be around 7.5 seconds. This result can also be cross-verified with the vLLM implementation.

  3. Lastly, the original HF implementation does not support very long context windows; Appendix C.3 details the optimization steps we performed. To reach longer context windows, you can use --attn_type minference_with_dense with our optimized implementation, or use vLLM.

Thanks again for raising these points, and please let me know if you have further questions!
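The synchronization point in item 2 matters because CUDA kernel launches are asynchronous with respect to the host: a timer stopped as soon as the forward call returns may stop before the GPU has finished the work. A minimal timing sketch, assuming PyTorch with CUDA is available (it falls back to a no-op sync otherwise); `time_prefill` is a hypothetical helper, not part of the MInference benchmark script:

```python
import time
from typing import Callable, Optional


def time_prefill(
    run: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 1,
) -> float:
    """Time one call of `run` (e.g. a pre-fill forward pass) in seconds.

    Hypothetical helper: `sync` defaults to torch.cuda.synchronize when a
    CUDA device is available, and to a no-op otherwise.
    """
    if sync is None:
        try:
            import torch

            sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
        except ImportError:
            sync = lambda: None  # no PyTorch installed: nothing to wait for

    for _ in range(warmup):
        run()  # warm-up calls absorb one-off costs such as kernel selection

    sync()  # make sure no earlier GPU work is still queued
    start = time.perf_counter()
    run()
    sync()  # wait until the GPU has actually finished the forward pass
    return time.perf_counter() - start
```

With a Hugging Face model already on the GPU, this would be called as `time_prefill(lambda: model(input_ids))` inside a `torch.no_grad()` block; set `warmup=0` to skip the extra pass for very long prompts. `torch.cuda.Event(enable_timing=True)` is an alternative that records timestamps on the GPU stream itself.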

@lepangdan
Author

lepangdan commented Nov 25, 2024

Hi, @iofu728
Thanks for your helpful reply.

  1. I ran the command python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 300_000, which took 142 s for 300,000 tokens, reasonably close to the corrected figure of 3 minutes.
  2. a) Yes, the issue was caused by the missing torch.cuda.synchronize(). Thank you for your help!
    b) When I ran python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000, the result was 9.6 seconds, which seems reasonably close to the 7.5 seconds you mentioned.
    c) When I run either python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000 or python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 50_000, the printed model shows that the default attention is LlamaSdpaAttention rather than flash_attention_2 (I added print(model) at the end, on line 137 of https://github.com/microsoft/MInference/blob/7a3e5acaaf0e83105d941a4067f53020ca1eba12/experiments/benchmarks/benchmark_e2e.py; output below). When I explicitly add the argument attn_implementation="flash_attention_2" on line 125, the printed model shows LlamaFlashAttention2 instead. According to the paper, the baseline should be flash_attention_2, right? Since the default setting seems to be sdpa, I'm a bit confused. Could you confirm whether I'm misunderstanding something?

Default (no attn_implementation argument):

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)

After adding the argument attn_implementation="flash_attention_2":

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)

Looking forward to your reply.
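The layer shapes in the printouts above also allow a rough, back-of-the-envelope look at why the 100K-token pre-fill in the original question runs out of memory on the stock HF path. A minimal sketch, assuming bf16 storage (2 bytes per value) and a forward pass that keeps the full KV cache and materializes logits for every position; whether logits are upcast to fp32 depends on the transformers version, so that is a parameter here. It ignores attention workspace, allocator overhead, and other activations, and `prefill_memory_gb` is a hypothetical helper:

```python
# Shapes read off the LlamaForCausalLM printout above (LLaMA-3-8B).
NUM_LAYERS = 32        # (0-31): 32 x LlamaDecoderLayer
KV_DIM = 1024          # k_proj / v_proj out_features
INTERMEDIATE = 14336   # gate_proj / up_proj out_features
VOCAB = 128256         # lm_head out_features
BF16_BYTES = 2


def prefill_memory_gb(seq_len: int, logits_bytes: int = BF16_BYTES) -> dict:
    """Rough size in GB (1e9 bytes) of a few large tensors for one sequence."""
    sizes = {
        # K and V for every layer, kept for decoding
        "kv_cache": 2 * NUM_LAYERS * seq_len * KV_DIM * BF16_BYTES,
        # lm_head output, assuming logits are computed for every position
        "logits": seq_len * VOCAB * logits_bytes,
        # a single gate_proj (or up_proj) output inside one MLP
        "mlp_intermediate": seq_len * INTERMEDIATE * BF16_BYTES,
    }
    return {name: n_bytes / 1e9 for name, n_bytes in sizes.items()}


print(prefill_memory_gb(100_000))
```

At 100,000 tokens this gives about 13.1 GB of KV cache, 25.7 GB of bf16 logits (51.3 GB if upcast to fp32) and 2.9 GB per MLP intermediate, on top of roughly 16 GB of bf16 weights (about 8.03B parameters summed from the same shapes); at 50,000 tokens each sequence-dependent term halves. These are estimates only: which tensors are alive at the same time depends on the implementation.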

@iofu728
Contributor

iofu728 commented Nov 26, 2024

Hi @lepangdan,

Thank you for your feedback. The results reported in the paper were obtained with minference_with_dense, since it supports longer contexts. You can also use flash_attn by setting attn_implementation="flash_attention_2".
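To confirm which attention class was actually instantiated without reading through the whole print(model) output, one can collect the class names of the `self_attn` submodules. A minimal sketch; `attention_backends` is a hypothetical helper that relies only on the standard `named_modules()` method of `torch.nn.Module`, and the submodule name `self_attn` is taken from the printouts above:

```python
from typing import Set


def attention_backends(model) -> Set[str]:
    """Return the class names of all `self_attn` submodules of `model`.

    Works with any object exposing torch.nn.Module-style named_modules(),
    e.g. a Hugging Face LlamaForCausalLM.
    """
    return {
        type(module).__name__
        for name, module in model.named_modules()
        if name.split(".")[-1] == "self_attn"  # e.g. "model.layers.0.self_attn"
    }
```

Going by the printouts in this thread, this would be expected to return {"LlamaSdpaAttention"} for the default run and {"LlamaFlashAttention2"} with attn_implementation="flash_attention_2". Recent transformers releases also record the selected backend in model.config._attn_implementation, though that is a private attribute and may change.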
