From 9f4b28216b67f5d38ba1a339e425fcf313c89657 Mon Sep 17 00:00:00 2001
From: Liu Liu
Date: Tue, 2 Jan 2024 17:46:48 -0500
Subject: [PATCH] Using 16F for internal compute doesn't speed anything up.

---
 lib/nnc/cmd/blas/gpu/ccv_nnc_gemm_gpu_cublas.cu | 5 +++++
 lib/nnc/gpu/ccv_nnc_compat.cu | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/lib/nnc/cmd/blas/gpu/ccv_nnc_gemm_gpu_cublas.cu b/lib/nnc/cmd/blas/gpu/ccv_nnc_gemm_gpu_cublas.cu
index 51ac7cb99..6b8e0b96d 100644
--- a/lib/nnc/cmd/blas/gpu/ccv_nnc_gemm_gpu_cublas.cu
+++ b/lib/nnc/cmd/blas/gpu/ccv_nnc_gemm_gpu_cublas.cu
@@ -23,6 +23,7 @@ static inline void _ccv_nnc_gbmm_and_bias(cublasHandle_t cublas, const void* con
 			one = &one_f16;
 			break;
 		case CUBLAS_COMPUTE_32F:
+		case CUBLAS_COMPUTE_32F_FAST_TF32:
 			one = &one_f32;
 			break;
 		case CUBLAS_COMPUTE_64F:
@@ -75,6 +76,7 @@ static inline void _ccv_nnc_gbmm(cublasHandle_t cublas, const unsigned char* con
 			one = &one_f16;
 			break;
 		case CUBLAS_COMPUTE_32F:
+		case CUBLAS_COMPUTE_32F_FAST_TF32:
 			one = &one_f32;
 			break;
 		case CUBLAS_COMPUTE_64F:
@@ -249,6 +251,7 @@ static inline void _ccv_nnc_gbmm_dbias(cublasHandle_t cublas, const int flags, c
 			one = &one_f16;
 			break;
 		case CUBLAS_COMPUTE_32F:
+		case CUBLAS_COMPUTE_32F_FAST_TF32:
 			one = &one_f32;
 			break;
 		case CUBLAS_COMPUTE_64F:
@@ -304,6 +307,7 @@ static inline void _ccv_nnc_gbmm_dw(cublasHandle_t cublas, const int flags, cons
 			one = &one_f16;
 			break;
 		case CUBLAS_COMPUTE_32F:
+		case CUBLAS_COMPUTE_32F_FAST_TF32:
 			one = &one_f32;
 			break;
 		case CUBLAS_COMPUTE_64F:
@@ -387,6 +391,7 @@ static inline void _ccv_nnc_gbmm_h(cublasHandle_t cublas, const int flags, const
 			one = &one_f16;
 			break;
 		case CUBLAS_COMPUTE_32F:
+		case CUBLAS_COMPUTE_32F_FAST_TF32:
 			one = &one_f32;
 			break;
 		case CUBLAS_COMPUTE_64F:
diff --git a/lib/nnc/gpu/ccv_nnc_compat.cu b/lib/nnc/gpu/ccv_nnc_compat.cu
index b9798c562..fcc05ff5f 100644
--- a/lib/nnc/gpu/ccv_nnc_compat.cu
+++ b/lib/nnc/gpu/ccv_nnc_compat.cu
@@ -767,7 +767,7 @@ cublasComputeType_t ccv_nnc_cuda_compute_datatype(int datatype)
 		case CCV_32S:
 			return CUBLAS_COMPUTE_32F;
 		case CCV_16F:
-			return CUBLAS_COMPUTE_16F;
+			return CUBLAS_COMPUTE_32F;
 		case CCV_32F:
 			return CUBLAS_COMPUTE_32F;
 		case CCV_64F: