diff --git a/src/cublas.cu b/src/cublas.cu index e39b9b0..f155026 100644 --- a/src/cublas.cu +++ b/src/cublas.cu @@ -264,8 +264,8 @@ cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa, CULiP_exp_stats b_stats; snprintf(a_stats.name, a_stats.name_length - 1, "A"); snprintf(b_stats.name, b_stats.name_length - 1, "B"); - a_stats.stats = exp_stats(A, (transa == CUBLAS_OP_N ? m : k), (transb == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype); - b_stats.stats = exp_stats(B, (transa == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype); + a_stats.stats = exp_stats(A, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype); + b_stats.stats = exp_stats(B, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype); mtk::cu_exp_statistics::to_json(a_stats.stats); mtk::cu_exp_statistics::to_json(b_stats.stats); CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&a_stats); diff --git a/src/cublas.gemm.template.h b/src/cublas.gemm.template.h index ee517ed..dd1be75 100644 --- a/src/cublas.gemm.template.h +++ b/src/cublas.gemm.template.h @@ -43,8 +43,8 @@ cublasStatus_t CULIP_FUNC_NAME(cublasHandle_t handle, cublasOperation_t transa, CULiP_exp_stats b_stats; snprintf(a_stats.name, a_stats.name_length - 1, "A"); snprintf(b_stats.name, b_stats.name_length - 1, "B"); - a_stats.stats = mtk::cu_exp_statistics::take_matrix_statistics(A, (transa == CUBLAS_OP_N ? m : k), (transb == CUBLAS_OP_N ? k : m), lda, cuda_stream); - b_stats.stats = mtk::cu_exp_statistics::take_matrix_statistics(B, (transa == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream); + a_stats.stats = mtk::cu_exp_statistics::take_matrix_statistics(A, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream); + b_stats.stats = mtk::cu_exp_statistics::take_matrix_statistics(B, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream); mtk::cu_exp_statistics::to_json(a_stats.stats); mtk::cu_exp_statistics::to_json(b_stats.stats); CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&a_stats); diff --git a/src/cublas.gemm_strided_batched.template.h b/src/cublas.gemm_strided_batched.template.h index e12cd64..207ec36 100644 --- a/src/cublas.gemm_strided_batched.template.h +++ b/src/cublas.gemm_strided_batched.template.h @@ -57,8 +57,8 @@ cublasStatus_t CULIP_FUNC_NAME (cublasHandle_t handle, snprintf(a_stats.name, a_stats.name_length - 1, "A"); snprintf(b_stats.name, b_stats.name_length - 1, "B"); for (std::uint32_t i = 0; i < batchCount; i++) { - a_stats.stats += mtk::cu_exp_statistics::take_matrix_statistics(A + i * strideA, (transa == CUBLAS_OP_N ? m : k), (transb == CUBLAS_OP_N ? k : m), lda, cuda_stream); - b_stats.stats += mtk::cu_exp_statistics::take_matrix_statistics(B + i * strideB, (transa == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream); + a_stats.stats += mtk::cu_exp_statistics::take_matrix_statistics(A + i * strideA, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream); + b_stats.stats += mtk::cu_exp_statistics::take_matrix_statistics(B + i * strideB, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream); } mtk::cu_exp_statistics::to_json(a_stats.stats); mtk::cu_exp_statistics::to_json(b_stats.stats);