How to merge a KV cache into the Diff Attention? #1661
Comments
@ZetangForward Hi, in long-context scenarios we use flash decoding, which supports 64K-length inference on an 80GB GPU. If you use Diff-Transformer/multihead_flashdiff_1, you can refer to https://aka.ms/flash-diff for flash decoding support. If you use Diff-Transformer/multihead_flashdiff_2, you can refer to the official flash decoding at https://github.com/Dao-AILab/flash-attention
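For concreteness, here is a minimal sketch of what a decoding step with a KV cache could look like for differential attention, built on flash_attn_with_kvcache from the flash-attention package. The tensor layout (paired q/k heads, a single shared v per head pair) and the lambda_full scalar are assumptions for illustration, not the authors' code; GroupNorm, the (1 - lambda_init) scaling, and the output projection are omitted.

```python
# A minimal decoding-step sketch (not the authors' implementation) of differential
# attention on top of flash_attn_with_kvcache. Shapes, the (q1, q2)/(k1, k2) split,
# and lambda_full are assumptions based on the paper's formulation.
import torch
from flash_attn import flash_attn_with_kvcache

def diff_attn_decode_step(q, k, v, k1_cache, k2_cache, v_cache,
                          cache_seqlens, lambda_full):
    """
    q, k: (batch, 1, 2 * num_heads, head_dim)  -- paired projections per head
    v:    (batch, 1, num_heads, head_dim)      -- shared values (assumed layout)
    *_cache: pre-allocated (batch, max_len, num_heads, head_dim) tensors,
             updated in place by flash_attn_with_kvcache
    cache_seqlens: int32 tensor of shape (batch,), number of tokens already cached
    """
    h = q.shape[2] // 2
    q1, q2 = q[:, :, :h], q[:, :, h:]
    k1, k2 = k[:, :, :h], k[:, :, h:]
    # Each call appends the new k/v at position cache_seqlens and attends over the
    # cached tokens plus the new one; the second write of v just overwrites the
    # same values in v_cache.
    a1 = flash_attn_with_kvcache(q1, k1_cache, v_cache, k=k1, v=v,
                                 cache_seqlens=cache_seqlens, causal=True)
    a2 = flash_attn_with_kvcache(q2, k2_cache, v_cache, k=k2, v=v,
                                 cache_seqlens=cache_seqlens, causal=True)
    # Differential combination; the caller increments cache_seqlens afterwards.
    return a1 - lambda_full * a2
```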
OK, thanks. BTW, I want to ask an additional question that is unrelated to the code. I noticed that there is no explicit objective function constraining the "differential heads" in the paper; the ability to "eliminate noise" is learned automatically through the designed gating mechanism. I am curious whether it is possible to explain intuitively why the vanilla training objective (i.e., next-token prediction with CE loss) can eliminate differences between a pair of heads (two heads). Does this phenomenon occur in untrained models (e.g., Llama3)?
@ZetangForward Although the paired heads are independent in the forward pass, they can perceive each other in the backward pass. The two heads are fused together after differential attention, so the gradients of the weights (Wq, Wk) contain information from both heads. These gradients guide the two heads to learn how to project the input relative to each other. A model without any training (i.e., with randomly initialized weights) does not have this ability.
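A toy example can make this coupling concrete: because the loss is computed on the fused differential output, the gradient flowing into each head's projection depends on the other head's output. The sketch below uses a single head pair, a plain learnable scalar in place of the paper's reparametrized lambda, and a squared-error loss as a stand-in for next-token CE; none of these choices come from the authors' code.

```python
# Toy illustration: the paired heads share gradient information because they are
# fused by the differential output before the loss is applied.
import torch

torch.manual_seed(0)
d = 16
x = torch.randn(1, 8, d)                                   # (batch, seq, dim)
Wq1, Wk1 = (torch.randn(d, d, requires_grad=True) for _ in range(2))
Wq2, Wk2 = (torch.randn(d, d, requires_grad=True) for _ in range(2))
Wv = torch.randn(d, d, requires_grad=True)
lam = torch.tensor(0.5, requires_grad=True)                # stand-in for lambda
target = torch.randn(1, 8, d)                              # stand-in training target

def attn(q, k, v):
    return torch.softmax(q @ k.transpose(-1, -2) / d ** 0.5, dim=-1) @ v

v = x @ Wv
out = attn(x @ Wq1, x @ Wk1, v) - lam * attn(x @ Wq2, x @ Wk2, v)
loss = ((out - target) ** 2).mean()                        # any nonlinear loss works
loss.backward()

# Wq1.grad is chain-ruled through dLoss/d(out), which contains head 2's output
# (and vice versa), so each head's update carries information about the other.
print(Wq1.grad.norm(), Wq2.grad.norm(), lam.grad)
```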
I notice the authors only provide the vanilla code in the Diff Attention repo.
However, in the paper, the authors also report performance in long-context scenarios.
The vanilla implementation of Diff Attention cannot support a 64K context length on an 80GB GPU.
I wonder how the authors achieve long-context inference. Is there a KV cache version of Diff Attention?
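As a rough illustration of why this matters, the arithmetic below compares the quadratic memory of materializing attention score matrices at 64K with the linear memory of a KV cache. The model configuration is hypothetical, chosen only to make the orders of magnitude visible, and the cache layout matches the assumed layout in the earlier sketch.

```python
# Back-of-the-envelope arithmetic (hypothetical config, not from the paper) showing
# why full score matrices fail at 64K while a KV cache scales linearly.
seq_len   = 64 * 1024      # 64K context
n_layers  = 28             # assumed
n_heads   = 16             # assumed head pairs
head_dim  = 128            # assumed
bytes_f16 = 2

# Vanilla (non-flash) diff attention materializes two score matrices per head pair:
scores_per_layer = 2 * n_heads * seq_len * seq_len * bytes_f16
print(f"score matrices, one layer: {scores_per_layer / 2**30:.0f} GiB")   # ~256 GiB

# A KV cache stores the two K halves plus V once per layer, linear in seq_len:
kv_per_token = n_layers * (2 * n_heads * head_dim + n_heads * head_dim) * bytes_f16
print(f"KV cache, full 64K context: {kv_per_token * seq_len / 2**30:.0f} GiB")  # ~21 GiB
```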