Skip to content

Commit

Permalink
[HIPIFY][6.0.0][BLAS] Support for ROCm HIP 6.0.0 - Step 20 - half a…
Browse files Browse the repository at this point in the history
…nd `bfloat16` Functions

+ `__half` -> `__half` -> `rocblas_half`
+ `__nv_bfloat16` -> `hip_bfloat16` -> `rocblas_bfloat16`
+ [rocBLAS] New functions: `rocblas_hs(h|s)gemv_batched` and `rocblas_ts(s|t)gemv_batched`
+ [fix] Removed a non-existing type `nv_bfloat16`
+ Updated synthetic tests, the regenerated hipify-perl, and docs
  • Loading branch information
emankov committed Oct 22, 2023
1 parent 90d8251 commit 5f4e5b1
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 23 deletions.
13 changes: 7 additions & 6 deletions bin/hipify-perl
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,8 @@ sub rocSubstitutions {
subst("cublasGetStream_v2", "rocblas_get_stream", "library");
subst("cublasGetVector", "rocblas_get_vector", "library");
subst("cublasGetVectorAsync", "rocblas_get_vector_async", "library");
subst("cublasHSHgemvBatched", "rocblas_hshgemv_batched", "library");
subst("cublasHSSgemvBatched", "rocblas_hssgemv_batched", "library");
subst("cublasHgemm", "rocblas_hgemm", "library");
subst("cublasHgemmBatched", "rocblas_hgemm_batched", "library");
subst("cublasHgemmStridedBatched", "rocblas_hgemm_strided_batched", "library");
Expand Down Expand Up @@ -1575,6 +1577,8 @@ sub rocSubstitutions {
subst("cublasStrsm_v2", "rocblas_strsm", "library");
subst("cublasStrsv", "rocblas_strsv", "library");
subst("cublasStrsv_v2", "rocblas_strsv", "library");
subst("cublasTSSgemvBatched", "rocblas_tssgemv_batched", "library");
subst("cublasTSTgemvBatched", "rocblas_tstgemv_batched", "library");
subst("cublasZaxpy", "rocblas_zaxpy", "library");
subst("cublasZaxpy_v2", "rocblas_zaxpy", "library");
subst("cublasZcopy", "rocblas_zcopy", "library");
Expand Down Expand Up @@ -2031,6 +2035,8 @@ sub rocSubstitutions {
subst("cusparseZgtsvInterleavedBatch_bufferSizeExt", "rocsparse_zgtsv_interleaved_batch_buffer_size", "library");
subst("cusparseZnnz", "rocsparse_znnz", "library");
subst("cusparseZnnz_compress", "rocsparse_znnz_compress", "library");
subst("__half", "rocblas_half", "device_type");
subst("__nv_bfloat16", "rocblas_bfloat16", "device_type");
subst("cublas.h", "rocblas.h", "include_cuda_main_header");
subst("cublas_v2.h", "rocblas.h", "include_cuda_main_header_v2");
subst("bsric02Info", "_rocsparse_mat_info", "type");
Expand Down Expand Up @@ -4052,6 +4058,7 @@ sub simpleSubstitutions {
subst("__half2", "__half2", "device_type");
subst("__half2_raw", "__half2_raw", "device_type");
subst("__half_raw", "__half_raw", "device_type");
subst("__nv_bfloat16", "hip_bfloat16", "device_type");
subst("caffe2\/core\/common_cudnn.h", "caffe2\/core\/hip\/common_miopen.h", "include");
subst("caffe2\/operators\/spatial_batch_norm_op.h", "caffe2\/operators\/hip\/spatial_batch_norm_op_miopen.hip", "include");
subst("channel_descriptor.h", "hip\/channel_descriptor.h", "include");
Expand Down Expand Up @@ -6757,7 +6764,6 @@ sub warnUnsupportedFunctions {
"nvrtcGetLTOIRSize",
"nvrtcGetLTOIR",
"nv_bfloat162",
"nv_bfloat16",
"memoryBarrier",
"libraryPropertyType_t",
"libraryPropertyType",
Expand Down Expand Up @@ -7969,7 +7975,6 @@ sub warnUnsupportedFunctions {
"__nv_bfloat16_raw",
"__nv_bfloat162_raw",
"__nv_bfloat162",
"__nv_bfloat16",
"__curand_umul",
"__NV_SATFINITE",
"__NV_NOSAT",
Expand Down Expand Up @@ -9958,11 +9963,9 @@ sub warnRocOnlyUnsupportedFunctions {
"cublasTSTgemvStridedBatched_64",
"cublasTSTgemvStridedBatched",
"cublasTSTgemvBatched_64",
"cublasTSTgemvBatched",
"cublasTSSgemvStridedBatched_64",
"cublasTSSgemvStridedBatched",
"cublasTSSgemvBatched_64",
"cublasTSSgemvBatched",
"cublasSwapEx_64",
"cublasSwapEx",
"cublasStrttp",
Expand Down Expand Up @@ -10095,11 +10098,9 @@ sub warnRocOnlyUnsupportedFunctions {
"cublasHSSgemvStridedBatched_64",
"cublasHSSgemvStridedBatched",
"cublasHSSgemvBatched_64",
"cublasHSSgemvBatched",
"cublasHSHgemvStridedBatched_64",
"cublasHSHgemvStridedBatched",
"cublasHSHgemvBatched_64",
"cublasHSHgemvBatched",
"cublasGetVersion_v2",
"cublasGetVersion",
"cublasGetVector_64",
Expand Down
8 changes: 4 additions & 4 deletions docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -797,11 +797,11 @@
|`cublasDtrsm_64`|12.0| | | | | | | | | | | | | | |
|`cublasDtrsm_v2`| | | |`hipblasDtrsm`|1.8.2| | | | |`rocblas_dtrsm`|1.5.0| | | | |
|`cublasDtrsm_v2_64`|12.0| | | | | | | | | | | | | | |
|`cublasHSHgemvBatched`|11.6| | | | | | | | | | | | | | |
|`cublasHSHgemvBatched`|11.6| | | | | | | | |`rocblas_hshgemv_batched`|6.0.0| | | |6.0.0|
|`cublasHSHgemvBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasHSHgemvStridedBatched`|11.6| | | | | | | | | | | | | | |
|`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasHSSgemvBatched`|11.6| | | | | | | | | | | | | | |
|`cublasHSSgemvBatched`|11.6| | | | | | | | |`rocblas_hssgemv_batched`|6.0.0| | | |6.0.0|
|`cublasHSSgemvBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasHSSgemvStridedBatched`|11.6| | | | | | | | | | | | | | |
|`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | |
Expand Down Expand Up @@ -845,11 +845,11 @@
|`cublasStrsm_64`|12.0| | | | | | | | | | | | | | |
|`cublasStrsm_v2`| | | |`hipblasStrsm`|1.8.2| | | | |`rocblas_strsm`|1.5.0| | | | |
|`cublasStrsm_v2_64`|12.0| | | | | | | | | | | | | | |
|`cublasTSSgemvBatched`|11.6| | | | | | | | | | | | | | |
|`cublasTSSgemvBatched`|11.6| | | | | | | | |`rocblas_tssgemv_batched`|6.0.0| | | |6.0.0|
|`cublasTSSgemvBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasTSSgemvStridedBatched`|11.6| | | | | | | | | | | | | | |
|`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasTSTgemvBatched`|11.6| | | | | | | | | | | | | | |
|`cublasTSTgemvBatched`|11.6| | | | | | | | |`rocblas_tstgemv_batched`|6.0.0| | | |6.0.0|
|`cublasTSTgemvBatched_64`|12.0| | | | | | | | | | | | | | |
|`cublasTSTgemvStridedBatched`|11.6| | | | | | | | | | | | | | |
|`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | |
Expand Down
8 changes: 4 additions & 4 deletions docs/tables/CUBLAS_API_supported_by_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -797,11 +797,11 @@
|`cublasDtrsm_64`|12.0| | | | | | | | |
|`cublasDtrsm_v2`| | | |`rocblas_dtrsm`|1.5.0| | | | |
|`cublasDtrsm_v2_64`|12.0| | | | | | | | |
|`cublasHSHgemvBatched`|11.6| | | | | | | | |
|`cublasHSHgemvBatched`|11.6| | |`rocblas_hshgemv_batched`|6.0.0| | | |6.0.0|
|`cublasHSHgemvBatched_64`|12.0| | | | | | | | |
|`cublasHSHgemvStridedBatched`|11.6| | | | | | | | |
|`cublasHSHgemvStridedBatched_64`|12.0| | | | | | | | |
|`cublasHSSgemvBatched`|11.6| | | | | | | | |
|`cublasHSSgemvBatched`|11.6| | |`rocblas_hssgemv_batched`|6.0.0| | | |6.0.0|
|`cublasHSSgemvBatched_64`|12.0| | | | | | | | |
|`cublasHSSgemvStridedBatched`|11.6| | | | | | | | |
|`cublasHSSgemvStridedBatched_64`|12.0| | | | | | | | |
Expand Down Expand Up @@ -845,11 +845,11 @@
|`cublasStrsm_64`|12.0| | | | | | | | |
|`cublasStrsm_v2`| | | |`rocblas_strsm`|1.5.0| | | | |
|`cublasStrsm_v2_64`|12.0| | | | | | | | |
|`cublasTSSgemvBatched`|11.6| | | | | | | | |
|`cublasTSSgemvBatched`|11.6| | |`rocblas_tssgemv_batched`|6.0.0| | | |6.0.0|
|`cublasTSSgemvBatched_64`|12.0| | | | | | | | |
|`cublasTSSgemvStridedBatched`|11.6| | | | | | | | |
|`cublasTSSgemvStridedBatched_64`|12.0| | | | | | | | |
|`cublasTSTgemvBatched`|11.6| | | | | | | | |
|`cublasTSTgemvBatched`|11.6| | |`rocblas_tstgemv_batched`|6.0.0| | | |6.0.0|
|`cublasTSTgemvBatched_64`|12.0| | | | | | | | |
|`cublasTSTgemvStridedBatched`|11.6| | | | | | | | |
|`cublasTSTgemvStridedBatched_64`|12.0| | | | | | | | |
Expand Down
3 changes: 1 addition & 2 deletions docs/tables/CUDA_Device_API_supported_by_HIP.md
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@
|`__half2`| | | |`__half2`|1.6.0| | | | |
|`__half2_raw`| | | |`__half2_raw`|1.9.0| | | | |
|`__half_raw`| | | |`__half_raw`|1.9.0| | | | |
|`__nv_bfloat16`|11.0| | | | | | | | |
|`__nv_bfloat16`|11.0| | |`hip_bfloat16`|3.5.0| | | | |
|`__nv_bfloat162`|11.0| | | | | | | | |
|`__nv_bfloat162_raw`|11.0| | | | | | | | |
|`__nv_bfloat16_raw`|11.0| | | | | | | | |
Expand All @@ -826,7 +826,6 @@
|`__nv_fp8x4_e5m2`|11.8| | | | | | | | |
|`__nv_fp8x4_storage_t`|11.8| | | | | | | | |
|`__nv_saturation_t`|11.8| | | | | | | | |
|`nv_bfloat16`|11.0| | | | | | | | |
|`nv_bfloat162`|11.0| | | | | | | | |


Expand Down
12 changes: 8 additions & 4 deletions src/CUDA2HIP_BLAS_API_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,13 @@ const std::map<llvm::StringRef, hipCounter> CUDA_BLAS_FUNCTION_MAP {
{"cublasCgemvBatched_64", {"hipblasCgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasZgemvBatched", {"hipblasZgemvBatched_v2", "rocblas_zgemv_batched", CONV_LIB_FUNC, API_BLAS, 7}},
{"cublasZgemvBatched_64", {"hipblasZgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasHSHgemvBatched", {"hipblasHSHgemvBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasHSHgemvBatched", {"hipblasHSHgemvBatched", "rocblas_hshgemv_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}},
{"cublasHSHgemvBatched_64", {"hipblasHSHgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasHSSgemvBatched", {"hipblasHSSgemvBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasHSSgemvBatched", {"hipblasHSSgemvBatched", "rocblas_hssgemv_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}},
{"cublasHSSgemvBatched_64", {"hipblasHSSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasTSTgemvBatched", {"hipblasTSTgemvBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasTSTgemvBatched", {"hipblasTSTgemvBatched", "rocblas_tstgemv_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}},
{"cublasTSTgemvBatched_64", {"hipblasTSTgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasTSSgemvBatched", {"hipblasTSSgemvBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasTSSgemvBatched", {"hipblasTSSgemvBatched", "rocblas_tssgemv_batched", CONV_LIB_FUNC, API_BLAS, 7, HIP_UNSUPPORTED}},
{"cublasTSSgemvBatched_64", {"hipblasTSSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasSgemvStridedBatched", {"hipblasSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
{"cublasSgemvStridedBatched_64", {"hipblasSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, 7, UNSUPPORTED}},
Expand Down Expand Up @@ -2096,6 +2096,10 @@ const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_FUNCTION_VER_MAP {
{"rocblas_dtrmm", {HIP_3050, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_ctrmm", {HIP_3050, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_ztrmm", {HIP_3050, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_hshgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_hssgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_tstgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
{"rocblas_tssgemv_batched", {HIP_6000, HIP_0, HIP_0, HIP_LATEST}},
};

const std::map<llvm::StringRef, hipAPIChangedVersions> HIP_BLAS_FUNCTION_CHANGED_VER_MAP {
Expand Down
9 changes: 6 additions & 3 deletions src/CUDA2HIP_Device_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@ THE SOFTWARE.
// Maps the names of CUDA Device/Host types to the corresponding HIP types
const std::map<llvm::StringRef, hipCounter> CUDA_DEVICE_TYPE_NAME_MAP {
// float16 Precision Device types
{"__half", {"__half", "", CONV_DEVICE_TYPE, API_RUNTIME, 2}},
{"__half", {"__half", "rocblas_half", CONV_DEVICE_TYPE, API_RUNTIME, 2}},
{"__half_raw", {"__half_raw", "", CONV_DEVICE_TYPE, API_RUNTIME, 2}},
{"__half2", {"__half2", "", CONV_DEVICE_TYPE, API_RUNTIME, 2}},
{"__half2_raw", {"__half2_raw", "", CONV_DEVICE_TYPE, API_RUNTIME, 2}},
// Bfloat16 Precision Device types
{"__nv_bfloat16", {"__hip_bfloat16", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}},
{"nv_bfloat16", {"hip_bfloat16", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}},
{"__nv_bfloat16", {"hip_bfloat16", "rocblas_bfloat16", CONV_DEVICE_TYPE, API_RUNTIME, 2}},
{"__nv_bfloat16_raw", {"__hip_bfloat16_raw", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}},
{"__nv_bfloat162", {"__hip_bfloat162", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}},
{"nv_bfloat162", {"hip_bfloat162", "", CONV_DEVICE_TYPE, API_RUNTIME, 2, UNSUPPORTED}},
Expand Down Expand Up @@ -83,4 +82,8 @@ const std::map<llvm::StringRef, hipAPIversions> HIP_DEVICE_TYPE_NAME_VER_MAP {
{"__half2", {HIP_1060, HIP_0, HIP_0 }},
{"__half_raw", {HIP_1090, HIP_0, HIP_0 }},
{"__half2_raw", {HIP_1090, HIP_0, HIP_0 }},
{"hip_bfloat16", {HIP_3050, HIP_0, HIP_0 }},

{"rocblas_half", {HIP_1050, HIP_0, HIP_0 }},
{"rocblas_bfloat16", {HIP_3050, HIP_0, HIP_0 }},
};
43 changes: 43 additions & 0 deletions tests/unit_tests/synthetic/libraries/cublas2rocblas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,29 @@ int main() {
const float** const fBarray_const = const_cast<const float**>(fBarray);
float** fCarray = 0;
float** fTauarray = 0;
float** fyarray = 0;

// CHECK: rocblas_half** hAarray = 0;
__half** hAarray = 0;
// CHECK: const rocblas_half** const hAarray_const = const_cast<const rocblas_half**>(hAarray);
const __half** const hAarray_const = const_cast<const __half**>(hAarray);
// CHECK: rocblas_half** hxarray = 0;
__half** hxarray = 0;
// CHECK: const rocblas_half** const hxarray_const = const_cast<const rocblas_half**>(hxarray_const);
const __half** const hxarray_const = const_cast<const __half**>(hxarray_const);
// CHECK: rocblas_half** hyarray = 0;
__half** hyarray = 0;

// CHECK: rocblas_bfloat16** bf16Aarray = 0;
__nv_bfloat16** bf16Aarray = 0;
// CHECK: const rocblas_bfloat16** const bf16Aarray_const = const_cast<const rocblas_bfloat16**>(bf16Aarray);
const __nv_bfloat16** const bf16Aarray_const = const_cast<const __nv_bfloat16**>(bf16Aarray);
// CHECK: rocblas_bfloat16** bf16xarray = 0;
__nv_bfloat16** bf16xarray = 0;
// CHECK: const rocblas_bfloat16** const bf16xarray_const = const_cast<const rocblas_bfloat16**>(bf16xarray_const);
const __nv_bfloat16** const bf16xarray_const = const_cast<const __nv_bfloat16**>(bf16xarray_const);
// CHECK: rocblas_bfloat16** bf16yarray = 0;
__nv_bfloat16** bf16yarray = 0;

double da = 0;
double dA = 0;
Expand Down Expand Up @@ -1770,6 +1793,26 @@ int main() {
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_zgemv_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_int m, rocblas_int n, const rocblas_double_complex* alpha, const rocblas_double_complex* A, rocblas_int lda, rocblas_stride strideA, const rocblas_double_complex* x, rocblas_int incx, rocblas_stride stridex, const rocblas_double_complex* beta, rocblas_double_complex* y, rocblas_int incy, rocblas_stride stridey, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_zgemv_strided_batched(blasHandle, blasOperation, m, n, &dcomplexa, &dcomplexA, lda, strideA, &dcomplexx, incx, stridex, &dcomplexb, &dcomplexy, incy, stridey, batchCount);
blasStatus = cublasZgemvStridedBatched(blasHandle, blasOperation, m, n, &dcomplexa, &dcomplexA, lda, strideA, &dcomplexx, incx, stridex, &dcomplexb, &dcomplexy, incy, stridey, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSHgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __half* const Aarray[], int lda, const __half* const xarray[], int incx, const float* beta, __half* const yarray[], int incy, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hshgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const x[], rocblas_int incx, const float* beta, rocblas_half* const y[], rocblas_int incy, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hshgemv_batched(blasHandle, blasOperation, m, n, &fa, hAarray_const, lda, hxarray_const, incx, &fb, hyarray, incy, batchCount);
blasStatus = cublasHSHgemvBatched(blasHandle, blasOperation, m, n, &fa, hAarray_const, lda, hxarray_const, incx, &fb, hyarray, incy, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHSSgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __half* const Aarray[], int lda, const __half* const xarray[], int incx, const float* beta, float* const yarray[], int incy, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hssgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const x[], rocblas_int incx, const float* beta, float* const y[], rocblas_int incy, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hssgemv_batched(blasHandle, blasOperation, m, n, &fa, hAarray_const, lda, hxarray_const, incx, &fb, fyarray, incy, batchCount);
blasStatus = cublasHSSgemvBatched(blasHandle, blasOperation, m, n, &fa, hAarray_const, lda, hxarray_const, incx, &fb, fyarray, incy, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSTgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __nv_bfloat16* const Aarray[], int lda, const __nv_bfloat16* const xarray[], int incx, const float* beta, __nv_bfloat16* const yarray[], int incy, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_tstgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* const A[], rocblas_int lda, const rocblas_bfloat16* const x[], rocblas_int incx, const float* beta, rocblas_bfloat16* const y[], rocblas_int incy, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_tstgemv_batched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, bf16yarray, incy, batchCount);
blasStatus = cublasTSTgemvBatched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, bf16yarray, incy, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasTSSgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const __nv_bfloat16* const Aarray[], int lda, const __nv_bfloat16* const xarray[], int incx, const float* beta, float* const yarray[], int incy, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_tssgemv_batched(rocblas_handle handle, rocblas_operation trans, rocblas_int m, rocblas_int n, const float* alpha, const rocblas_bfloat16* const A[], rocblas_int lda, const rocblas_bfloat16* const x[], rocblas_int incx, const float* beta, float* const y[], rocblas_int incy, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_tssgemv_batched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, fyarray, incy, batchCount);
blasStatus = cublasTSSgemvBatched(blasHandle, blasOperation, m, n, &fa, bf16Aarray_const, lda, bf16xarray_const, incx, &fb, fyarray, incy, batchCount);
#endif

return 0;
Expand Down

0 comments on commit 5f4e5b1

Please sign in to comment.