Why doesn't the backward of the fused attention kernel account for the normalization constant in the softmax function? #4629
jeffwillette asked this question in Q&A (unanswered)
The fused softmax tutorial (https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py) shows a backward implementation.
In this implementation, when calculating the derivative with respect to $V$, we would expect to see the attention matrix $A$, since we are computing $\frac{\partial}{\partial V} (AV)$ and so the gradient with respect to $V$ is $A^\top \, \mathrm{d}O$. However, it looks as if the implementation ignores the normalization constant (the sum over the rows) of the attention matrix, and recomputes $A$ as $\exp\!\big(QK^\top - \max(QK^\top, \text{dim}=1)\big)$ instead of $\frac{\exp\!\big(QK^\top - \max(QK^\top, \text{dim}=1)\big)}{\sum\big(\exp\!\big(QK^\top - \max(QK^\top, \text{dim}=1)\big), \text{dim}=1\big)}$.
triton/python/tutorials/06-fused-attention.py, lines 235 to 245 at commit 0e3cadd
Tests pass, and this appears to be equivalent to the eager attention backward. My question is: why? Is there a line I am missing that incorporates the normalization constant, or is it safely ignored because it doesn't change the output much?
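For reference, here is a minimal eager sketch (PyTorch, with hypothetical shapes and an assumed $1/\sqrt{d}$ scaling; the variable names are mine, not the tutorial's) of the gradient I would expect versus the unnormalized quantity described above:

```python
import torch

# Hypothetical shapes just for illustration; the tutorial's actual block sizes
# and head dimensions don't matter for the point being made.
B, H, N, D = 2, 4, 128, 64
q = torch.randn(B, H, N, D, dtype=torch.float64)
k = torch.randn(B, H, N, D, dtype=torch.float64)
v = torch.randn(B, H, N, D, dtype=torch.float64, requires_grad=True)

scale = D ** -0.5                       # assumed 1/sqrt(d) scaling
s = (q @ k.transpose(-2, -1)) * scale   # attention scores QK^T
a = torch.softmax(s, dim=-1)            # normalized attention matrix A
o = a @ v
do = torch.randn_like(o)                # upstream gradient dO

# Gradient I would expect: dV = A^T dO, with the *normalized* A.
dv_expected = a.transpose(-2, -1) @ do
dv_autograd, = torch.autograd.grad(o, v, do)
assert torch.allclose(dv_expected, dv_autograd)

# What the backward seems to recompute instead: exp(s - rowmax),
# i.e. the softmax numerator without the row-sum division.
p = torch.exp(s - s.amax(dim=-1, keepdim=True))
dv_unnormalized = p.transpose(-2, -1) @ do

print(torch.allclose(dv_expected, dv_unnormalized))  # False in general
```

Since the two quantities differ by the row sums in general, I don't see how dropping the normalization could be harmless, which is why the passing tests surprise me.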