Skip to content

Commit

Permalink
Fix the bug the softmax_lse is not the right size.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jun 19, 2024
1 parent a08e4b4 commit bd40f0b
Showing 1 changed file with 5 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,12 @@ static void _ccv_nnc_scaled_dot_product_attention_tensor_auto_forw(const ccv_nnc
if (output_size == 1)
return;
assert(output_size > 1);
// This is saved softmax_lse, which would be in 32F if exists.
outputs[1] = inputs[0];
if (q_nd == 4)
{
// Switch head v.s. sequence length.
outputs[1].dim[1] = outputs[1].dim[2];
outputs[1].dim[2] = inputs[0].dim[1];
}
outputs[1].dim[q_nd - 1] = inputs[1].dim[k_nd - 3]; // saved softmax should have sequence length of query x key.
outputs[1].dim[q_nd - 3] = inputs[0].dim[q_nd - 2];
outputs[1].dim[q_nd - 2] = inputs[0].dim[q_nd - 3];
outputs[1].dim[q_nd - 1] = 0;
outputs[1].datatype = CCV_32F;
}
}

Expand Down

0 comments on commit bd40f0b

Please sign in to comment.