Skip to content

Commit

Permalink
[HIPIFY][ROCm#1086][rocBLAS][tests] Added tests on rocblas_half fun…
Browse files Browse the repository at this point in the history
…ctions
  • Loading branch information
emankov committed Oct 23, 2023
1 parent 5f4e5b1 commit 6920a21
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 6 deletions.
31 changes: 28 additions & 3 deletions tests/unit_tests/synthetic/libraries/cublas2rocblas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,31 @@ int main() {
float** fTauarray = 0;
float** fyarray = 0;

// CHECK: rocblas_half* ha = 0;
__half* ha = 0;
// CHECK: rocblas_half* hA = 0;
__half* hA = 0;
// CHECK: rocblas_half* hb = 0;
__half* hb = 0;
// CHECK: rocblas_half* hB = 0;
__half* hB = 0;
// CHECK: rocblas_half* hc = 0;
__half* hc = 0;
// CHECK: rocblas_half* hC = 0;
__half* hC = 0;

// CHECK: rocblas_half** hAarray = 0;
__half** hAarray = 0;
// CHECK: const rocblas_half** const hAarray_const = const_cast<const rocblas_half**>(hAarray);
const __half** const hAarray_const = const_cast<const __half**>(hAarray);
// CHECK: rocblas_half** hBarray = 0;
__half** hBarray = 0;
// CHECK: const rocblas_half** const hBarray_const = const_cast<const rocblas_half**>(hBarray);
const __half** const hBarray_const = const_cast<const __half**>(hBarray);
// CHECK: rocblas_half** hCarray = 0;
__half** hCarray = 0;
// CHECK: const rocblas_half** const hCarray_const = const_cast<const rocblas_half**>(hCarray);
const __half** const hCarray_const = const_cast<const __half**>(hCarray);
// CHECK: rocblas_half** hxarray = 0;
__half** hxarray = 0;
// CHECK: const rocblas_half** const hxarray_const = const_cast<const rocblas_half**>(hxarray_const);
Expand Down Expand Up @@ -1240,9 +1261,10 @@ int main() {
blasStatus = cublasZgemm_v2(blasHandle, transa, transb, m, n, k, &dcomplexa, &dcomplexA, lda, &dcomplexB, ldb, &dcomplexb, &dcomplexC, ldc);

// TODO: #1281
// TODO: __half -> rocblas_half
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* A, int lda, const __half* B, int ldb, const __half* beta, __half* C, int ldc);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* A, rocblas_int lda, const rocblas_half* B, rocblas_int ldb, const rocblas_half* beta, rocblas_half* C, rocblas_int ldc);
// CHECK: blasStatus = rocblas_hgemm(blasHandle, transa, transb, m, n, k, ha, hA, lda, hB, ldb, hb, hC, ldc);
blasStatus = cublasHgemm(blasHandle, transa, transb, m, n, k, ha, hA, lda, hB, ldb, hb, hC, ldc);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* const Aarray[], int lda, const float* const Barray[], int ldb, const float* beta, float* const Carray[], int ldc, int batchCount);
Expand All @@ -1257,9 +1279,10 @@ int main() {
blasStatus = cublasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);

// TODO: #1281
// TODO: __half -> rocblas_half
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const B[], rocblas_int ldb, const rocblas_half* beta, rocblas_half* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hgemm_batched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const Barray[], int ldb, const cuComplex* beta, cuComplex* const Carray[], int ldc, int batchCount);
Expand Down Expand Up @@ -1638,9 +1661,11 @@ int main() {
// CHECK: blasStatus = rocblas_zgemm_strided_batched(blasHandle, transa, transb, m, n, k, &dcomplexa, &dcomplexA, lda, strideA, &dcomplexB, ldb, strideB, &dcomplexb, &dcomplexC, ldc, strideC, batchCount);
blasStatus = cublasZgemmStridedBatched(blasHandle, transa, transb, m, n, k, &dcomplexa, &dcomplexA, lda, strideA, &dcomplexB, ldb, strideB, &dcomplexb, &dcomplexC, ldc, strideC, batchCount);

// TODO: __half -> rocblas_half
// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmStridedBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* A, int lda, long long int strideA, const __half* B, int ldb, long long int strideB, const __half* beta, __half* C, int ldc, long long int strideC, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* A, rocblas_int lda, rocblas_stride stride_a, const rocblas_half* B, rocblas_int ldb, rocblas_stride stride_b, const rocblas_half* beta, rocblas_half* C, rocblas_int ldc, rocblas_stride stride_c, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hgemm_strided_batched(blasHandle, transa, transb, m, n, k, ha, hA, lda, strideA, hB, ldb, strideB, hb, hC, ldc, strideC, batchCount);
blasStatus = cublasHgemmStridedBatched(blasHandle, transa, transb, m, n, k, ha, hA, lda, strideA, hB, ldb, strideB, hb, hC, ldc, strideC, batchCount);

void* aptr = nullptr;
void* Aptr = nullptr;
Expand Down
42 changes: 39 additions & 3 deletions tests/unit_tests/synthetic/libraries/cublas2rocblas_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,38 @@ int main() {
float** fCarray = 0;
float** fTauarray = 0;

// CHECK: rocblas_half* ha = 0;
__half* ha = 0;
// CHECK: rocblas_half* hA = 0;
__half* hA = 0;
// CHECK: rocblas_half* hb = 0;
__half* hb = 0;
// CHECK: rocblas_half* hB = 0;
__half* hB = 0;
// CHECK: rocblas_half* hc = 0;
__half* hc = 0;
// CHECK: rocblas_half* hC = 0;
__half* hC = 0;

// CHECK: rocblas_half** hAarray = 0;
__half** hAarray = 0;
// CHECK: const rocblas_half** const hAarray_const = const_cast<const rocblas_half**>(hAarray);
const __half** const hAarray_const = const_cast<const __half**>(hAarray);
// CHECK: rocblas_half** hBarray = 0;
__half** hBarray = 0;
// CHECK: const rocblas_half** const hBarray_const = const_cast<const rocblas_half**>(hBarray);
const __half** const hBarray_const = const_cast<const __half**>(hBarray);
// CHECK: rocblas_half** hCarray = 0;
__half** hCarray = 0;
// CHECK: const rocblas_half** const hCarray_const = const_cast<const rocblas_half**>(hCarray);
const __half** const hCarray_const = const_cast<const __half**>(hCarray);
// CHECK: rocblas_half** hxarray = 0;
__half** hxarray = 0;
// CHECK: const rocblas_half** const hxarray_const = const_cast<const rocblas_half**>(hxarray_const);
const __half** const hxarray_const = const_cast<const __half**>(hxarray_const);
// CHECK: rocblas_half** hyarray = 0;
__half** hyarray = 0;

double da = 0;
double dA = 0;
double db = 0;
Expand Down Expand Up @@ -1345,9 +1377,10 @@ int main() {
blasStatus = cublasZgemm_v2(blasHandle, transa, transb, m, n, k, &dcomplexa, &dcomplexA, lda, &dcomplexB, ldb, &dcomplexb, &dcomplexC, ldc);

// TODO: #1281
// TODO: __half -> rocblas_half
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* A, int lda, const __half* B, int ldb, const __half* beta, __half* C, int ldc);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* A, rocblas_int lda, const rocblas_half* B, rocblas_int ldb, const rocblas_half* beta, rocblas_half* C, rocblas_int ldc);
// CHECK: blasStatus = rocblas_hgemm(blasHandle, transa, transb, m, n, k, ha, hA, lda, hB, ldb, hb, hC, ldc);
blasStatus = cublasHgemm(blasHandle, transa, transb, m, n, k, ha, hA, lda, hB, ldb, hb, hC, ldc);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* const Aarray[], int lda, const float* const Barray[], int ldb, const float* beta, float* const Carray[], int ldc, int batchCount);
Expand All @@ -1362,9 +1395,10 @@ int main() {
blasStatus = cublasDgemmBatched(blasHandle, transa, transb, m, n, k, &da, dAarray_const, lda, dBarray_const, ldb, &db, dCarray, ldc, batchCount);

// TODO: #1281
// TODO: __half -> rocblas_half
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* const Aarray[], int lda, const __half* const Barray[], int ldb, const __half* beta, __half* const Carray[], int ldc, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* const A[], rocblas_int lda, const rocblas_half* const B[], rocblas_int ldb, const rocblas_half* beta, rocblas_half* const C[], rocblas_int ldc, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hgemm_batched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);
blasStatus = cublasHgemmBatched(blasHandle, transa, transb, m, n, k, ha, hAarray_const, lda, hBarray_const, ldb, hb, hCarray, ldc, batchCount);

// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const Barray[], int ldb, const cuComplex* beta, cuComplex* const Carray[], int ldc, int batchCount);
Expand Down Expand Up @@ -1767,9 +1801,11 @@ int main() {
// CHECK: blasStatus = rocblas_zgemm_strided_batched(blasHandle, transa, transb, m, n, k, &dcomplexa, &dcomplexA, lda, strideA, &dcomplexB, ldb, strideB, &dcomplexb, &dcomplexC, ldc, strideC, batchCount);
blasStatus = cublasZgemmStridedBatched(blasHandle, transa, transb, m, n, k, &dcomplexa, &dcomplexA, lda, strideA, &dcomplexB, ldb, strideB, &dcomplexb, &dcomplexC, ldc, strideC, batchCount);

// TODO: __half -> rocblas_half
// TODO: #1281
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmStridedBatched(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* A, int lda, long long int strideA, const __half* B, int ldb, long long int strideB, const __half* beta, __half* C, int ldc, long long int strideC, int batchCount);
// ROC: ROCBLAS_EXPORT rocblas_status rocblas_hgemm_strided_batched(rocblas_handle handle, rocblas_operation transA, rocblas_operation transB, rocblas_int m, rocblas_int n, rocblas_int k, const rocblas_half* alpha, const rocblas_half* A, rocblas_int lda, rocblas_stride stride_a, const rocblas_half* B, rocblas_int ldb, rocblas_stride stride_b, const rocblas_half* beta, rocblas_half* C, rocblas_int ldc, rocblas_stride stride_c, rocblas_int batch_count);
// CHECK: blasStatus = rocblas_hgemm_strided_batched(blasHandle, transa, transb, m, n, k, ha, hA, lda, strideA, hB, ldb, strideB, hb, hC, ldc, strideC, batchCount);
blasStatus = cublasHgemmStridedBatched(blasHandle, transa, transb, m, n, k, ha, hA, lda, strideA, hB, ldb, strideB, hb, hC, ldc, strideC, batchCount);

void* aptr = nullptr;
void* Aptr = nullptr;
Expand Down

0 comments on commit 6920a21

Please sign in to comment.