Skip to content

Commit

Permalink
Revert the previously incorrect fix for the batch size limitation in SDP (scaled dot product attention).
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Dec 7, 2023
1 parent 738b3f2 commit aaec1fb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,10 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
.batch_dims_q = { 0 },
.batch_dims_mask = { 0 },
};
// The matrix offsets can only be 4096 bytes, hence, our batch size can only be 128 at most. We use 64 for better alignment.
const int split_batch_size = ccv_min(batch_size, 64);
const int residual_batch_size = batch_size % split_batch_size;
if (attention_is_batched) {
params.batch_dims_q[0] = split_batch_size;
params.batch_dims_q[0] = batch_size;
params.batch_dims_q[1] = 0;
params.batch_dims_mask[0] = attn_mask ? amdim[0] : split_batch_size;
params.batch_dims_mask[0] = attn_mask ? amdim[0] : batch_size;
params.batch_dims_mask[1] = 0;
}
ccv_nnc_mfa_prepare_attention(context, params);
Expand All @@ -205,36 +202,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
o->dataof,
attn_mask ? attn_mask->dataof : 0,
};
int i;
if (batch_size <= split_batch_size)
{
ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets);
} else {
const int batch_count = batch_size / split_batch_size;
const uint64_t byte_stride_mask = R * C * data_type_size;
for (i = 0; i < batch_count; i++)
{
if (i > 0)
{
tensor_offsets[0] = q->dataof + i * split_batch_size * byte_stride_mask;
tensor_offsets[1] = k->dataof + i * split_batch_size * byte_stride_mask;
tensor_offsets[2] = v->dataof + i * split_batch_size * byte_stride_mask;
tensor_offsets[3] = o->dataof + i * split_batch_size * byte_stride_mask;
}
ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets);
}
if (residual_batch_size > 0)
{
tensor_offsets[0] = q->dataof + batch_count * split_batch_size * byte_stride_mask;
tensor_offsets[1] = k->dataof + batch_count * split_batch_size * byte_stride_mask;
tensor_offsets[2] = v->dataof + batch_count * split_batch_size * byte_stride_mask;
tensor_offsets[3] = o->dataof + batch_count * split_batch_size * byte_stride_mask;
params.batch_dims_q[0] = residual_batch_size;
params.batch_dims_mask[0] = attn_mask ? amdim[0] : residual_batch_size;
ccv_nnc_mfa_prepare_attention(context, params);
ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets);
}
}
ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets);

// NNC notation:
// D = C * W^T + bias
Expand Down
22 changes: 12 additions & 10 deletions lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,18 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
byte_stride_block_mask = grid_size.width * grid_size.height * 1;
}

simd::ulong4 matrix_offsets[batch_sizes[0]];
for (int i = 0; i < batch_sizes[0]; ++i) {
matrix_offsets[i] = simd::ulong4 {
i * byte_stride_mask,
i * byte_stride_block_mask,
0,
0,
};
}
encoder->setBytes(matrix_offsets, batch_sizes[0] * 32, 10);
if (hash.masked) {
simd::ulong4 matrix_offsets[batch_sizes[0]];
for (int i = 0; i < batch_sizes[0]; ++i) {
matrix_offsets[i] = simd::ulong4 {
i * byte_stride_mask,
i * byte_stride_block_mask,
0,
0,
};
}
encoder->setBytes(matrix_offsets, batch_sizes[0] * 32, 10);
}
}

if (params.masked) {
Expand Down

0 comments on commit aaec1fb

Please sign in to comment.