diff --git a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m index 8e79407f1..4932d027c 100644 --- a/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m +++ b/lib/nnc/cmd/blas/mps/ccv_nnc_gemm_mps.m @@ -108,10 +108,8 @@ static int _ccv_nnc_gemm_forw(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint b_batch_size = b_nd < 3 ? 1 : b->info.dim[b_nd - 3]; for (i = 0; i < b_nd - 3; i++) b_batch_size *= b->info.dim[i]; - assert(a_batch_size == b_batch_size || a_batch_size == 1); if (a_batch_size == 1 && b_batch_size > 1) a_batch_inc = 0; - assert(w_batch_size == a_batch_size || w_batch_size == 1); if (w_batch_size == 1 && b_batch_size > 1) w_batch_inc = 0; @autoreleasepool { diff --git a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m index 8d0ed0fd8..80060e17d 100644 --- a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m +++ b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m @@ -130,13 +130,16 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c (bias ? (q->info.datatype == bias->info.datatype) : 1); assert(is_same_dtype); + uint16_t data_type_size = UINT16_MAX; uint32_t mtl_data_type = UINT32_MAX; switch (q->info.datatype) { case CCV_16F: { + data_type_size = 2; mtl_data_type = 16; break; } case CCV_32F: { + data_type_size = 4; mtl_data_type = 3; break; } @@ -171,10 +174,13 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c .batch_dims_q = { 0 }, .batch_dims_mask = { 0 }, }; + // The matrix offsets can only be 4096 bytes, hence, our batch size can only be 128 at most. We use 64 for better alignment. + const int split_batch_size = ccv_min(batch_size, 64); + const int residual_batch_size = batch_size % split_batch_size; if (attention_is_batched) { - params.batch_dims_q[0] = batch_size; + params.batch_dims_q[0] = split_batch_size; params.batch_dims_q[1] = 0; - params.batch_dims_mask[0] = attn_mask ? amdim[0] : batch_size; + params.batch_dims_mask[0] = attn_mask ? amdim[0] : split_batch_size; params.batch_dims_mask[1] = 0; } ccv_nnc_mfa_prepare_attention(context, params); @@ -199,7 +205,36 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c o->dataof, attn_mask ? attn_mask->dataof : 0, }; - ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets); + int i; + if (batch_size <= split_batch_size) + { + ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets); + } else { + const int batch_count = batch_size / split_batch_size; + const uint64_t byte_stride_mask = R * C * data_type_size; + for (i = 0; i < batch_count; i++) + { + if (i > 0) + { + tensor_offsets[0] = q->dataof + i * split_batch_size * byte_stride_mask; + tensor_offsets[1] = k->dataof + i * split_batch_size * byte_stride_mask; + tensor_offsets[2] = v->dataof + i * split_batch_size * byte_stride_mask; + tensor_offsets[3] = o->dataof + i * split_batch_size * byte_stride_mask; + } + ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets); + } + if (residual_batch_size > 0) + { + tensor_offsets[0] = q->dataof + batch_count * split_batch_size * byte_stride_mask; + tensor_offsets[1] = k->dataof + batch_count * split_batch_size * byte_stride_mask; + tensor_offsets[2] = v->dataof + batch_count * split_batch_size * byte_stride_mask; + tensor_offsets[3] = o->dataof + batch_count * split_batch_size * byte_stride_mask; + params.batch_dims_q[0] = residual_batch_size; + params.batch_dims_mask[0] = attn_mask ? amdim[0] : residual_batch_size; + ccv_nnc_mfa_prepare_attention(context, params); + ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets); + } + } // NNC notation: // D = C * W^T + bias @@ -242,17 +277,17 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c // The C matrix of the GEMM cannot be transposed, so the assume the C matrix // is NHWC. assert(c->info.format == CCV_TENSOR_FORMAT_NHWC); - int M = cdim[2]; + int M = cdim[1] * cdim[2]; int N = cdim[3]; int K = H * D; if (o_nd == 3) { assert(adim[1] == attention_batch_size); - assert(adim[2] == M); + assert(adim[1] * adim[2] == M); } else { assert(adim[0] == attention_batch_size); - assert(adim[1] == M); + assert(adim[0] * adim[1] == M); } if (H > 1) { assert(adim[2] * adim[3] == K); @@ -318,7 +353,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c .D_trans = false, .alpha = (float)1.0, .beta = (float)0.0, - .batched = (gemm_is_batched ? 1 : 0), + .batched = 0, .fused_activation_function = 0, .fused_bias = (bias ? 1 : 0), @@ -326,10 +361,6 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c .batch_dims_b = { 0 }, .batch_dims_d = { 0 }, }; - if (gemm_is_batched) { - params.batch_dims_a[0] = gemm_batch_size; - params.batch_dims_a[1] = 0; - } ccv_nnc_mfa_prepare_gemm(context, params); mtl_buffer_t* bias_buffer = NULL; diff --git a/lib/nnc/mfa/ccv_nnc_mfa_normalization.cpp b/lib/nnc/mfa/ccv_nnc_mfa_normalization.cpp index a5073bb9f..95807db7a 100644 --- a/lib/nnc/mfa/ccv_nnc_mfa_normalization.cpp +++ b/lib/nnc/mfa/ccv_nnc_mfa_normalization.cpp @@ -201,7 +201,7 @@ kernel void normalization( uint3 tgid [[threadgroup_position_in_grid]], ushort sidx [[simdgroup_index_in_threadgroup]], - ushort lid [[thread_index_in_threadgroup]] + uint lid [[thread_index_in_threadgroup]] ) { uint threadgroup_index = tgid.z * sequence_count + tgid.x; {