Skip to content

Commit

Permalink
Merge branch '9-complex-gemm' into 'master'
Browse files Browse the repository at this point in the history
Add complex gemm

See merge request mutsuki/CULiP!11
  • Loading branch information
enp1s0 committed May 5, 2021
2 parents 577f3e7 + 1484ec9 commit cc0a090
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 128 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ CULiP is a library for profiling the execution time of CUDA official library fun
- `cublasDgemm`
- `cublasSgemm`
- `cublasHgemm`
- `cublasCgemm`
- `cublasZgemm`
- `cublasGemmEx`

## Dependencies
Expand Down
4 changes: 3 additions & 1 deletion include/CULiP/cublas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ enum CULiP_cublas_control_t {
CULiP_cublasDgemm = 0,
CULiP_cublasSgemm = 1,
CULiP_cublasHgemm = 2,
CULiP_cublasGemmEx = 3,
CULiP_cublasCgemm = 3,
CULiP_cublasZgemm = 4,
CULiP_cublasGemmEx = 5,
CULiP_cublas_enum_length
};

Expand Down
153 changes: 34 additions & 119 deletions src/cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,125 +80,40 @@ extern "C" const char* CULiP_get_cublasComputeType_t_string(const cublasComputeT
// cuBLAS functions
// -------------------------------------------------

// Interposed cublasSgemm.
// Forwards every argument to the function pointer obtained from
// CULiP_get_function_pointer and returns that call's status unchanged.
// When profiling is active, the call is bracketed by CULiP_record_timestamp
// launches on the handle's stream, followed by CULiP_print_profile_result.
cublasStatus_t cublasSgemm(cublasHandle_t handle, cublasOperation_t transa,
                           cublasOperation_t transb, int m, int n, int k,
                           const float *alpha, const float *A, int lda,
                           const float *B, int ldb, const float *beta, float *C,
                           int ldc) {
  // Pointer type mirroring this function's own signature.
  typedef cublasStatus_t (*sgemm_func_t)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float*, const float*, int, const float*, int, const float*, float*, int);

  // Profile only when this function's control entry is 0; the enable check
  // is evaluated only in that case (same short-circuit as a plain &&).
  int profiling_flag = 0;
  if (CULiP_profiling_control_array[CULiP_cublasSgemm] == 0) {
    profiling_flag = CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME) ? 1 : 0;
  }

  // Resolve the function to forward to, keyed by this function's own name.
  sgemm_func_t forward_func;
  *(void**)(&forward_func) = CULiP_get_function_pointer(CULIP_CUBLAS_LIBRARY_NAME, CULIP_CUBLAS_ENV_NAME, __func__, &CULiP_cublas_lib_handle_cache);

  cudaStream_t stream;
  struct CULiP_profile_result record;

  if (profiling_flag) {
    // Stream currently attached to the handle
    cublasGetStream(handle, &stream);

    // Label the record with the function name and the problem size
    snprintf(record.function_name, record.function_name_length - 1, "%s-m%d-n%d-k%d", __func__, m, n ,k);

    // Start timestamp
    CULiP_launch_function(stream, &CULiP_record_timestamp, (void*)&record.start_timestamp);
  }

  // Forward the call
  const cublasStatus_t status = forward_func(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
  CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Debug][%s] executed\n", __func__));

  if (profiling_flag) {
    // End timestamp
    CULiP_launch_function(stream, &CULiP_record_timestamp, (void*)&record.end_timestamp);

    // Report the record
    CULiP_launch_function(stream, &CULiP_print_profile_result, (void*)&record);
  }

  return status;
}

// Interposed cublasDgemm.
// Forwards every argument to the function pointer obtained from
// CULiP_get_function_pointer and returns that call's status unchanged.
// When profiling is active, the call is bracketed by CULiP_record_timestamp
// launches on the handle's stream, followed by CULiP_print_profile_result.
cublasStatus_t cublasDgemm(cublasHandle_t handle, cublasOperation_t transa,
                           cublasOperation_t transb, int m, int n, int k,
                           const double *alpha, const double *A, int lda,
                           const double *B, int ldb, const double *beta, double *C,
                           int ldc) {
  // Pointer type mirroring this function's own signature.
  typedef cublasStatus_t (*dgemm_func_t)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const double*, const double*, int, const double*, int, const double*, double*, int);

  // Profile only when this function's control entry is 0; the enable check
  // is evaluated only in that case (same short-circuit as a plain &&).
  int profiling_flag = 0;
  if (CULiP_profiling_control_array[CULiP_cublasDgemm] == 0) {
    profiling_flag = CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME) ? 1 : 0;
  }

  // Resolve the function to forward to, keyed by this function's own name.
  dgemm_func_t forward_func;
  *(void**)(&forward_func) = CULiP_get_function_pointer(CULIP_CUBLAS_LIBRARY_NAME, CULIP_CUBLAS_ENV_NAME, __func__, &CULiP_cublas_lib_handle_cache);

  cudaStream_t stream;
  struct CULiP_profile_result record;

  if (profiling_flag) {
    // Stream currently attached to the handle
    cublasGetStream(handle, &stream);

    // Label the record with the function name and the problem size
    snprintf(record.function_name, record.function_name_length - 1, "%s-m%d-n%d-k%d", __func__, m, n ,k);

    // Start timestamp
    CULiP_launch_function(stream, &CULiP_record_timestamp, (void*)&record.start_timestamp);
  }

  // Forward the call
  const cublasStatus_t status = forward_func(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
  CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Debug][%s] executed\n", __func__));

  if (profiling_flag) {
    // End timestamp
    CULiP_launch_function(stream, &CULiP_record_timestamp, (void*)&record.end_timestamp);

    // Report the record
    CULiP_launch_function(stream, &CULiP_print_profile_result, (void*)&record);
  }

  return status;
}

// Interposed cublasHgemm.
// Forwards every argument to the function pointer obtained from
// CULiP_get_function_pointer and returns that call's status unchanged.
// When profiling is active, the call is bracketed by CULiP_record_timestamp
// launches on the handle's stream, followed by CULiP_print_profile_result.
cublasStatus_t cublasHgemm(cublasHandle_t handle, cublasOperation_t transa,
                           cublasOperation_t transb, int m, int n, int k,
                           const half *alpha, const half *A, int lda,
                           const half *B, int ldb, const half *beta, half *C,
                           int ldc) {
  // Pointer type mirroring this function's own signature.
  typedef cublasStatus_t (*hgemm_func_t)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const half*, const half*, int, const half*, int, const half*, half*, int);

  // Profile only when this function's control entry is 0; the enable check
  // is evaluated only in that case (same short-circuit as a plain &&).
  int profiling_flag = 0;
  if (CULiP_profiling_control_array[CULiP_cublasHgemm] == 0) {
    profiling_flag = CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME) ? 1 : 0;
  }

  // Resolve the function to forward to, keyed by this function's own name.
  hgemm_func_t forward_func;
  *(void**)(&forward_func) = CULiP_get_function_pointer(CULIP_CUBLAS_LIBRARY_NAME, CULIP_CUBLAS_ENV_NAME, __func__, &CULiP_cublas_lib_handle_cache);

  cudaStream_t stream;
  struct CULiP_profile_result record;

  if (profiling_flag) {
    // Stream currently attached to the handle
    cublasGetStream(handle, &stream);

    // Label the record with the function name and the problem size
    snprintf(record.function_name, record.function_name_length - 1, "%s-m%d-n%d-k%d", __func__, m, n ,k);

    // Start timestamp
    CULiP_launch_function(stream, &CULiP_record_timestamp, (void*)&record.start_timestamp);
  }

  // Forward the call
  const cublasStatus_t status = forward_func(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
  CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Debug][%s] executed\n", __func__));

  if (profiling_flag) {
    // End timestamp
    CULiP_launch_function(stream, &CULiP_record_timestamp, (void*)&record.end_timestamp);

    // Report the record
    CULiP_launch_function(stream, &CULiP_print_profile_result, (void*)&record);
  }

  return status;
}
// GEMM wrappers generated from a shared template.
// Each section defines CULIP_FUNC_NAME (the function to define) and
// CULIP_TYPE (the element type used for alpha, A, B, beta and C), includes
// cublas.gemm.template, and then undefines both macros so the next section
// starts from a clean state.

// SGEMM
#define CULIP_FUNC_NAME cublasSgemm
#define CULIP_TYPE float
#include "cublas.gemm.template"
#undef CULIP_FUNC_NAME
#undef CULIP_TYPE

// DGEMM
#define CULIP_FUNC_NAME cublasDgemm
#define CULIP_TYPE double
#include "cublas.gemm.template"
#undef CULIP_FUNC_NAME
#undef CULIP_TYPE

// HGEMM
#define CULIP_FUNC_NAME cublasHgemm
#define CULIP_TYPE half
#include "cublas.gemm.template"
#undef CULIP_FUNC_NAME
#undef CULIP_TYPE

// CGEMM
#define CULIP_FUNC_NAME cublasCgemm
#define CULIP_TYPE cuComplex
#include "cublas.gemm.template"
#undef CULIP_FUNC_NAME
#undef CULIP_TYPE

// ZGEMM
#define CULIP_FUNC_NAME cublasZgemm
#define CULIP_TYPE cuDoubleComplex
#include "cublas.gemm.template"
#undef CULIP_FUNC_NAME
#undef CULIP_TYPE

cublasStatus_t cublasGemmEx(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
Expand Down
39 changes: 39 additions & 0 deletions src/cublas.gemm.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Profiling wrapper shared by the cuBLAS GEMM entry points.
//
// This file is meant to be #include-d once per wrapped function, with two
// macros defined by the includer:
//   CULIP_FUNC_NAME : the cuBLAS function to define (e.g. cublasSgemm)
//   CULIP_TYPE      : element type of alpha, A, B, beta and C (e.g. float)
//
// The generated function forwards all of its arguments to the function
// pointer returned by CULiP_get_function_pointer and returns that call's
// status unchanged. When profiling is active it additionally hands
// CULiP_record_timestamp (before and after the call) and
// CULiP_print_profile_result to CULiP_launch_function on the handle's stream.

// Maps the GEMM element type to that function's entry in
// CULiP_profiling_control_array, so every generated wrapper is gated by its
// own entry instead of always by the Sgemm one.
// The lookup is keyed on CULIP_TYPE rather than built by token-pasting
// CULIP_FUNC_NAME because the function name may itself be a macro
// (cublas_v2.h remaps several names to *_v2 variants), which would produce a
// non-existent enumerator. Guarded so that repeated inclusion defines the
// helpers only once.
#ifndef CULIP_CUBLAS_GEMM_CONTROL_INDEX_DEFINED
#define CULIP_CUBLAS_GEMM_CONTROL_INDEX_DEFINED
static inline int CULiP_cublas_gemm_control_index(const float*)           {return CULiP_cublasSgemm;}
static inline int CULiP_cublas_gemm_control_index(const double*)          {return CULiP_cublasDgemm;}
static inline int CULiP_cublas_gemm_control_index(const half*)            {return CULiP_cublasHgemm;}
static inline int CULiP_cublas_gemm_control_index(const cuComplex*)       {return CULiP_cublasCgemm;}
static inline int CULiP_cublas_gemm_control_index(const cuDoubleComplex*) {return CULiP_cublasZgemm;}
#endif

cublasStatus_t CULIP_FUNC_NAME(cublasHandle_t handle, cublasOperation_t transa,
                               cublasOperation_t transb, int m, int n, int k,
                               const CULIP_TYPE *alpha, const CULIP_TYPE *A, int lda,
                               const CULIP_TYPE *B, int ldb, const CULIP_TYPE *beta, CULIP_TYPE *C,
                               int ldc) {
  // Profile only when this function's own control entry is 0 and the
  // environment-based check reports profiling as enabled.
  const int profiling_flag = (CULiP_profiling_control_array[CULiP_cublas_gemm_control_index(alpha)] == 0) && CULiP_is_profiling_enabled(CULIP_CUBLAS_DISABLE_ENV_NAME);

  // Get the function pointer
  cublasStatus_t (*cublas_lib_func)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const CULIP_TYPE*, const CULIP_TYPE*, int, const CULIP_TYPE*, int, const CULIP_TYPE*, CULIP_TYPE*, int);
  *(void**)(&cublas_lib_func) = CULiP_get_function_pointer(CULIP_CUBLAS_LIBRARY_NAME, CULIP_CUBLAS_ENV_NAME, __func__, &CULiP_cublas_lib_handle_cache);

  cudaStream_t cuda_stream;
  // NOTE(review): profile_result lives on this stack frame while its address
  // is handed to CULiP_launch_function; confirm those launches have finished
  // using it before this function returns.
  struct CULiP_profile_result profile_result;

  if (profiling_flag) {
    // Get current cuda stream
    cublasGetStream(handle, &cuda_stream);

    // Profile result structure: label with the function name and problem size
    snprintf(profile_result.function_name, profile_result.function_name_length - 1, "%s-m%d-n%d-k%d", __func__, m, n ,k);

    // Record start timestamp
    CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.start_timestamp);
  }

  // Call the function
  const cublasStatus_t result = (*cublas_lib_func)(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
  CULIBPROFILER_DEBUG_PRINT(printf("[CULiP Debug][%s] executed\n", __func__));

  if (profiling_flag) {
    // Record end timestamp
    CULiP_launch_function(cuda_stream, &CULiP_record_timestamp, (void*)&profile_result.end_timestamp);

    // Print result
    CULiP_launch_function(cuda_stream, &CULiP_print_profile_result, (void*)&profile_result);
  }

  return result;
}
57 changes: 49 additions & 8 deletions tests/cublas_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@ cublasStatus_t gemm<half , op_gemm>(cublasHandle_t handle, cublasOperation_t tr
int ldc) {
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
// op_gemm specialization for cuComplex: delegates to cublasCgemm.
template <>
cublasStatus_t gemm<cuComplex, op_gemm>(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const cuComplex *alpha,
    const cuComplex *A, int lda,
    const cuComplex *B, int ldb,
    const cuComplex *beta,
    cuComplex *C, int ldc) {
  const cublasStatus_t status = cublasCgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
  return status;
}
// op_gemm specialization for cuDoubleComplex: delegates to cublasZgemm.
template <>
cublasStatus_t gemm<cuDoubleComplex, op_gemm>(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const cuDoubleComplex *alpha,
    const cuDoubleComplex *A, int lda,
    const cuDoubleComplex *B, int ldb,
    const cuDoubleComplex *beta,
    cuDoubleComplex *C, int ldc) {
  const cublasStatus_t status = cublasZgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
  return status;
}
// -----------------------------------------------------
// op_gemmEx
// -----------------------------------------------------
Expand Down Expand Up @@ -66,12 +82,33 @@ cublasStatus_t gemm<half , op_gemmEx>(cublasHandle_t handle, cublasOperation_t
int ldc) {
return cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, beta, C, CUDA_R_16F, ldc, CUDA_R_16F, CUBLAS_GEMM_DEFAULT);
}
// op_gemmEx specialization for cuComplex: delegates to cublasGemmEx with
// CUDA_C_32F used for the A, B and C data types and for the compute type.
template <>
cublasStatus_t gemm<cuComplex, op_gemmEx>(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const cuComplex *alpha,
    const cuComplex *A, int lda,
    const cuComplex *B, int ldb,
    const cuComplex *beta,
    cuComplex *C, int ldc) {
  const cublasStatus_t status = cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_C_32F, lda,
      B, CUDA_C_32F, ldb,
      beta,
      C, CUDA_C_32F, ldc,
      CUDA_C_32F, CUBLAS_GEMM_DEFAULT);
  return status;
}
// op_gemmEx specialization for cuDoubleComplex: delegates to cublasGemmEx with
// CUDA_C_64F used for the A, B and C data types and for the compute type.
template <>
cublasStatus_t gemm<cuDoubleComplex, op_gemmEx>(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const cuDoubleComplex *alpha,
    const cuDoubleComplex *A, int lda,
    const cuDoubleComplex *B, int ldb,
    const cuDoubleComplex *beta,
    cuDoubleComplex *C, int ldc) {
  const cublasStatus_t status = cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_C_64F, lda,
      B, CUDA_C_64F, ldb,
      beta,
      C, CUDA_C_64F, ldc,
      CUDA_C_64F, CUBLAS_GEMM_DEFAULT);
  return status;
}

// Builds a value of type T from a real scalar.
// The generic version is a plain static_cast; the complex specializations
// place the scalar in the first component and 0 in the second.
template <class T>
T convert(const double a) {
  return static_cast<T>(a);
}

template <>
cuComplex convert<cuComplex>(const double a) {
  const cuComplex value = make_float2(a, 0);
  return value;
}

template <>
cuDoubleComplex convert<cuDoubleComplex>(const double a) {
  const cuDoubleComplex value = make_double2(a, 0);
  return value;
}

template <class T, class Op>
void gemm_test() {
const std::size_t n = 1lu << 10;
const auto alpha = static_cast<T>(1);
const auto beta = static_cast<T>(0);
const auto alpha = convert<T>(1);
const auto beta = convert<T>(0);

T* mat_a;
T* mat_b;
Expand Down Expand Up @@ -102,12 +139,16 @@ void gemm_test() {
}

// Runs gemm_test for every supported element type, first through the
// type-specific gemm routines (op_gemm) and then through cublasGemmEx
// (op_gemmEx). Each <type, op> pair is exercised exactly once; the earlier
// list repeated the double/float/half instantiations.
void test_all() {
  gemm_test<double         , op_gemm  >();
  gemm_test<float          , op_gemm  >();
  gemm_test<half           , op_gemm  >();
  gemm_test<cuComplex      , op_gemm  >();
  gemm_test<cuDoubleComplex, op_gemm  >();
  gemm_test<double         , op_gemmEx>();
  gemm_test<float          , op_gemmEx>();
  gemm_test<half           , op_gemmEx>();
  gemm_test<cuComplex      , op_gemmEx>();
  gemm_test<cuDoubleComplex, op_gemmEx>();
}

int main(){
Expand Down

0 comments on commit cc0a090

Please sign in to comment.