Skip to content

Commit

Permalink
Fix arguments of take_matrix_statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
enp1s0 committed Aug 25, 2022
1 parent be28095 commit daff5fc
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa,
CULiP_exp_stats b_stats;
snprintf(a_stats.name, a_stats.name_length - 1, "A");
snprintf(b_stats.name, b_stats.name_length - 1, "B");
a_stats.stats = exp_stats(A, (transa == CUBLAS_OP_N ? m : k), (transb == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype);
b_stats.stats = exp_stats(B, (transa == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype);
a_stats.stats = exp_stats(A, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream, Atype);
b_stats.stats = exp_stats(B, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream, Btype);
mtk::cu_exp_statistics::to_json(a_stats.stats);
mtk::cu_exp_statistics::to_json(b_stats.stats);
CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&a_stats);
Expand Down
4 changes: 2 additions & 2 deletions src/cublas.gemm.template.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ cublasStatus_t CULIP_FUNC_NAME(cublasHandle_t handle, cublasOperation_t transa,
CULiP_exp_stats b_stats;
snprintf(a_stats.name, a_stats.name_length - 1, "A");
snprintf(b_stats.name, b_stats.name_length - 1, "B");
a_stats.stats = mtk::cu_exp_statistics::take_matrix_statistics(A, (transa == CUBLAS_OP_N ? m : k), (transb == CUBLAS_OP_N ? k : m), lda, cuda_stream);
b_stats.stats = mtk::cu_exp_statistics::take_matrix_statistics(B, (transa == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream);
a_stats.stats = mtk::cu_exp_statistics::take_matrix_statistics(A, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream);
b_stats.stats = mtk::cu_exp_statistics::take_matrix_statistics(B, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream);
mtk::cu_exp_statistics::to_json(a_stats.stats);
mtk::cu_exp_statistics::to_json(b_stats.stats);
CULiP_launch_function(cuda_stream, &CULiP_print_exp_stats_result, (void*)&a_stats);
Expand Down
4 changes: 2 additions & 2 deletions src/cublas.gemm_strided_batched.template.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ cublasStatus_t CULIP_FUNC_NAME (cublasHandle_t handle,
snprintf(a_stats.name, a_stats.name_length - 1, "A");
snprintf(b_stats.name, b_stats.name_length - 1, "B");
for (std::uint32_t i = 0; i < batchCount; i++) {
a_stats.stats += mtk::cu_exp_statistics::take_matrix_statistics(A + i * strideA, (transa == CUBLAS_OP_N ? m : k), (transb == CUBLAS_OP_N ? k : m), lda, cuda_stream);
b_stats.stats += mtk::cu_exp_statistics::take_matrix_statistics(B + i * strideB, (transa == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream);
a_stats.stats += mtk::cu_exp_statistics::take_matrix_statistics(A + i * strideA, (transa == CUBLAS_OP_N ? m : k), (transa == CUBLAS_OP_N ? k : m), lda, cuda_stream);
b_stats.stats += mtk::cu_exp_statistics::take_matrix_statistics(B + i * strideB, (transb == CUBLAS_OP_N ? k : n), (transb == CUBLAS_OP_N ? n : k), ldb, cuda_stream);
}
mtk::cu_exp_statistics::to_json(a_stats.stats);
mtk::cu_exp_statistics::to_json(b_stats.stats);
Expand Down

0 comments on commit daff5fc

Please sign in to comment.