
How to merge a KV cache into the Diff Attention? #1661

Open
ZetangForward opened this issue Nov 28, 2024 · 3 comments

ZetangForward commented Nov 28, 2024

I notice that the authors only provide the vanilla implementation in the Diff Attention repo.
However, the paper also reports performance in long-context scenarios.
The vanilla implementation of Diff Attention cannot support a 64K context length on an 80GB GPU.
I wonder how the authors achieve long-context inference. Is there a KV cache version of Diff Attention?

donglixp self-assigned this Dec 2, 2024

YTianZHU (Contributor) commented Dec 2, 2024

@ZetangForward Hi, in long-context scenarios we use flash decoding, which supports 64K-length inference on an 80GB GPU.

If you use Diff-Transformer/multihead_flashdiff_1, you can refer to https://aka.ms/flash-diff for flash decoding support.

If you use Diff-Transformer/multihead_flashdiff_2, you can refer to the official flash decoding implementation at https://github.com/Dao-AILab/flash-attention.
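
For reference, below is a minimal sketch (my own illustration, not the authors' code) of what one flash-decoding step with a KV cache could look like for differential attention. It assumes the `flash_attn_with_kvcache` API from the flash-attention package, uses the (batch, seqlen, heads, head_dim) layout that flash-attention expects, and omits the per-head normalization and (1 - lambda_init) scaling described in the paper. The function name and argument layout are hypothetical.

```python
# Hypothetical sketch: one decoding step of differential attention with a
# KV cache, via flash_attn_with_kvcache from the flash-attention package.
import torch
from flash_attn import flash_attn_with_kvcache

def diff_attn_decode_step(q1, q2, k1_new, k2_new, v_new,
                          k1_cache, k2_cache, v_cache,
                          cache_seqlens, lam):
    """Append the new token's K/V to the caches in place, attend over the
    cached prefix with two flash-decoding calls, and take their difference.
    `cache_seqlens` is an int32 tensor of shape (batch,) holding the current
    cache length of each sequence; `lam` is the learned lambda."""
    attn1 = flash_attn_with_kvcache(
        q1, k1_cache, v_cache, k=k1_new, v=v_new,
        cache_seqlens=cache_seqlens, causal=True)
    # The second call rewrites the same V slot with the same values, which is
    # redundant but harmless; only k2_cache receives new information here.
    attn2 = flash_attn_with_kvcache(
        q2, k2_cache, v_cache, k=k2_new, v=v_new,
        cache_seqlens=cache_seqlens, causal=True)
    # Differential combination; the per-head RMSNorm and (1 - lambda_init)
    # scaling from the paper would follow outside this sketch.
    return attn1 - lam * attn2
```

In this sketch the caller would advance `cache_seqlens` by one after each step; prefill could reuse the same call with the whole prompt passed as `k`/`v` and `cache_seqlens` set to zero.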

ZetangForward (Author) commented Dec 2, 2024

OK, thanks. BTW, I want to ask an additional question that is unrelated to the code.

I found that the paper has no explicit objective function to constrain the "differential heads"; the ability to "eliminate noise" is learned automatically through the designed gating mechanism. I am curious whether it is possible to explain intuitively why the vanilla training objective (i.e., next-token prediction with CE loss) can eliminate the differences between a pair of heads (two heads). Does this phenomenon occur in untrained models (e.g., Llama3)?

@YTianZHU

YTianZHU (Contributor) commented Dec 2, 2024

@ZetangForward Although the paired heads are independent in the forward pass, they can perceive each other in the backward pass. The two heads are fused together after the differential attention, so the gradients of the projection weights (Wq, Wk) carry information from both heads. These gradients guide the two heads to learn how to project the input with respect to each other.

A model without any training (i.e., with randomly initialized weights) does not have this ability.
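
A toy illustration of this point (my own sketch, not code from the repo): with a quadratic stand-in loss, the gradient of `Wq1` is propagated through the fused output, which also contains the second head's attention map, so each head's update depends on the other head even though the two softmaxes are computed independently in the forward pass.

```python
# Toy illustration (not from the repo): the two attention maps are computed
# independently, but their outputs are fused by the difference before the
# loss, so each head's gradient carries information about the other head.
import torch

torch.manual_seed(0)
d = 8
x = torch.randn(4, d)                        # 4 tokens, model dim d
Wq1 = torch.randn(d, d, requires_grad=True)
Wq2 = torch.randn(d, d, requires_grad=True)
Wk1 = torch.randn(d, d, requires_grad=True)
Wk2 = torch.randn(d, d, requires_grad=True)
Wv  = torch.randn(d, d, requires_grad=True)
lam = 0.5                                    # stand-in for the learned lambda

def attn_map(q, k):
    return torch.softmax(q @ k.T / d ** 0.5, dim=-1)

a1 = attn_map(x @ Wq1, x @ Wk1)
a2 = attn_map(x @ Wq2, x @ Wk2)
out = (a1 - lam * a2) @ (x @ Wv)             # differential attention output
loss = out.pow(2).mean()                     # stand-in for the downstream CE loss
loss.backward()

# d(loss)/d(Wq1) is propagated through `out`, which also contains a2, so the
# update of head 1 is shaped by what head 2 currently attends to (and vice
# versa), even though the two softmaxes never see each other in forward.
print(Wq1.grad.norm(), Wq2.grad.norm())
```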
