From 50666a3b6f439d027aa6a08c72e116bc8e8907ba Mon Sep 17 00:00:00 2001
From: JackAKirk
Date: Fri, 27 Sep 2024 03:29:51 -0700
Subject: [PATCH 1/2] Add gemv_batch cublas impl

Signed-off-by: JackAKirk
---
 src/blas/backends/cublas/cublas_batch.cpp | 67 ++++++++++++++---------
 1 file changed, 42 insertions(+), 25 deletions(-)

diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp
index 009bb9541..0744e2975 100644
--- a/src/blas/backends/cublas/cublas_batch.cpp
+++ b/src/blas/backends/cublas/cublas_batch.cpp
@@ -493,35 +493,52 @@ sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t
     throw unimplemented("blas", "gemv_batch", "for column_major layout");
 }
 
-sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha,
-                       const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta,
-                       float **y, int64_t *incy, int64_t group_count, int64_t *groupsize,
-                       const std::vector<sycl::event> &dependencies) {
-    throw unimplemented("blas", "gemv_batch", "for column_major layout");
+template <typename Func, typename T>
+inline sycl::event gemv_batch(const char *func_name, Func func, sycl::queue &queue, transpose *trans, int64_t *m,
+                              int64_t *n, T *alpha, const T **a, int64_t *lda, const T **x,
+                              int64_t *incx, T *beta, T **y, int64_t *incy, int64_t group_count,
+                              int64_t *group_size, const std::vector<sycl::event> &dependencies) {
+    using cuDataType = typename CudaEquivalentType<T>::Type;
+    for (int64_t i = 0; i < group_count; i++) {
+        overflow_check(m[i], n[i], lda[i], incx[i], incy[i], group_size[i]);
+    }
+    auto done = queue.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(dependencies);
+        onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) {
+            auto handle = sc.get_handle(queue);
+            int64_t offset = 0;
+            cublasStatus_t err;
+            auto **a_ = reinterpret_cast<const cuDataType **>(a);
+            auto **x_ = reinterpret_cast<const cuDataType **>(x);
+            auto **y_ = reinterpret_cast<cuDataType **>(y);
+            for (int64_t i = 0; i < group_count; i++) {
+                CUBLAS_ERROR_FUNC_T_SYNC(
+                    func_name, func, err, handle, get_cublas_operation(trans[i]),
+                    (int)m[i], (int)n[i],
+                    (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], x_ + offset, (int)incx[i],
+                    (cuDataType *)&beta[i], y_ + offset, (int)incy[i], (int)group_size[i]);
+                offset += group_size[i];
+            }
+        });
+    });
+    return done;
 }
 
-sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha,
-                       const double **a, int64_t *lda, const double **x, int64_t *incx,
-                       double *beta, double **y, int64_t *incy, int64_t group_count,
-                       int64_t *groupsize, const std::vector<sycl::event> &dependencies) {
-    throw unimplemented("blas", "gemv_batch", "for column_major layout");
-}
+#define GEMV_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE)                                              \
+    sycl::event gemv_batch(                                                                        \
+        sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, TYPE *alpha, const TYPE **a, \
+        int64_t *lda, const TYPE **x, int64_t *incx, TYPE *beta, TYPE **y, int64_t *incy,          \
+        int64_t group_count, int64_t *group_size, const std::vector<sycl::event> &dependencies) {  \
+        return gemv_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, x, incx, beta, y, \
+                          incy, group_count, group_size, dependencies);                            \
+    }
 
-sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n,
-                       std::complex<float> *alpha, const std::complex<float> **a, int64_t *lda,
-                       const std::complex<float> **x, int64_t *incx, std::complex<float> *beta,
-                       std::complex<float> **y, int64_t *incy, int64_t group_count,
-                       int64_t *groupsize, const std::vector<sycl::event> &dependencies) {
-    throw unimplemented("blas", "gemv_batch", "for column_major layout");
column_major layout"); -} +GEMV_BATCH_LAUNCHER_USM(float, cublasSgemvBatched) +GEMV_BATCH_LAUNCHER_USM(double, cublasDgemvBatched) +GEMV_BATCH_LAUNCHER_USM(std::complex, cublasCgemvBatched) +GEMV_BATCH_LAUNCHER_USM(std::complex, cublasZgemvBatched) -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +#undef GEMV_BATCH_LAUNCHER_USM sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a, int64_t lda, int64_t stride_a, const float *x, int64_t incx, From 44853da5005a20515f25c0783803ea7dc31aff8a Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 9 Oct 2024 06:43:31 -0700 Subject: [PATCH 2/2] Use new native_enqueue. Signed-off-by: JackAKirk --- src/blas/backends/cublas/cublas_batch.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index 810570778..2975e6c58 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -521,7 +521,7 @@ inline sycl::event gemv_batch(const char *func_name, Func func, sycl::queue &que auto **x_ = reinterpret_cast(x); auto **y_ = reinterpret_cast(y); for (int64_t i = 0; i < group_count; i++) { - CUBLAS_ERROR_FUNC_T_SYNC( + cublas_native_named_func( func_name, func, err, handle, get_cublas_operation(trans[i]), (int)m[i], (int)n[i], (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], x_ + offset, (int)incx[i], @@ -533,13 +533,14 @@ inline sycl::event gemv_batch(const char *func_name, Func func, sycl::queue &que return done; } -#define GEMV_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemv_batch( \ - sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, TYPE *alpha, const TYPE **a, \ - int64_t *lda, const TYPE **x, int64_t *incx, TYPE *beta, TYPE **y, int64_t *incy, \ - int64_t group_count, int64_t *group_size, const std::vector &dependencies) { \ - return gemv_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, x, incx, beta, y, \ - incy, group_count, group_size, dependencies); \ +#define GEMV_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event gemv_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, \ + TYPE *alpha, const TYPE **a, int64_t *lda, const TYPE **x, \ + int64_t *incx, TYPE *beta, TYPE **y, int64_t *incy, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemv_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, \ + x, incx, beta, y, incy, group_count, group_size, dependencies); \ } GEMV_BATCH_LAUNCHER_USM(float, cublasSgemvBatched)