Skip to content

Commit

Permalink
Fix cases with GEMM broadcast (only advance `a` / `w` by the batch stride when their leading dimension is greater than 1).
Browse files: browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 30, 2023
1 parent c9d2a6c commit 81c7ea2
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions lib/nnc/cmd/blas/gpu/ccv_nnc_gemm_gpu_cublas.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,8 +36,8 @@ static inline void _ccv_nnc_gbmm_and_bias(cublasHandle_t cublas, const void* con
for (i = 0; i < dim; i++)
{
_ccv_nnc_gbmm_and_bias(cublas, ones,
a_nd > 3 ? a + CCV_GET_DATA_TYPE_SIZE(a_datatype) * i * astride[0] : a, a_datatype, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
w_nd > 3 ? w + CCV_GET_DATA_TYPE_SIZE(w_datatype) * i * wstride[0] : w, w_datatype, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
(a_nd > 3 && adim[0] > 1) ? a + CCV_GET_DATA_TYPE_SIZE(a_datatype) * i * astride[0] : a, a_datatype, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
(w_nd > 3 && wdim[0] > 1) ? w + CCV_GET_DATA_TYPE_SIZE(w_datatype) * i * wstride[0] : w, w_datatype, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
bias_nd > 3 ? bias + CCV_GET_DATA_TYPE_SIZE(bias_datatype) * i * biasstride[0] : bias, bias_datatype, bias_nd > 3 ? bias_nd - 1 : bias_nd, bias_nd > 3 ? biasdim + 1 : biasdim, bias_nd > 3 ? biasstride + 1 : biasstride,
b + CCV_GET_DATA_TYPE_SIZE(b_datatype) * i * bstride[0], b_datatype, b_nd - 1, bdim + 1, bstride + 1, b_batch_size, transa, transb, lda_inc, ldb_inc, a_batch_inc, w_batch_inc, bias_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, bias_rows_inc, b_rows_inc);
}
Expand Down Expand Up @@ -66,8 +66,8 @@ static inline void _ccv_nnc_gbmm(cublasHandle_t cublas, const unsigned char* con
for (i = 0; i < dim; i++)
{
_ccv_nnc_gbmm(cublas,
a_nd > 3 ? a + CCV_GET_DATA_TYPE_SIZE(a_datatype) * i * astride[0] : a, a_datatype, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
w_nd > 3 ? w + CCV_GET_DATA_TYPE_SIZE(w_datatype) * i * wstride[0] : w, w_datatype, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
(a_nd > 3 && adim[0] > 1) ? a + CCV_GET_DATA_TYPE_SIZE(a_datatype) * i * astride[0] : a, a_datatype, a_nd > 3 ? a_nd - 1 : a_nd, a_nd > 3 ? adim + 1 : adim, a_nd > 3 ? astride + 1 : astride,
(w_nd > 3 && wdim[0] > 1) ? w + CCV_GET_DATA_TYPE_SIZE(w_datatype) * i * wstride[0] : w, w_datatype, w_nd > 3 ? w_nd - 1 : w_nd, w_nd > 3 ? wdim + 1 : wdim, w_nd > 3 ? wstride + 1 : wstride,
b + CCV_GET_DATA_TYPE_SIZE(b_datatype) * i * bstride[0], b_datatype, b_nd - 1, bdim + 1, bstride + 1, b_batch_size, transa, transb, lda_inc, ldb_inc, a_batch_inc, w_batch_inc, b_batch_inc, b_rows, b_cols, a_cols, b_rows_inc);
}
}
Expand Down

0 comments on commit 81c7ea2

Please sign in to comment.