-
Notifications
You must be signed in to change notification settings - Fork 160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BLAS][portBLAS] Add bindings for half and some gemm_batch group APIs #576
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
#include <CL/sycl.hpp> | ||
#endif | ||
|
||
#include "portblas_common.hpp" | ||
#include "oneapi/mkl/exceptions.hpp" | ||
#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" | ||
|
||
|
@@ -32,19 +33,33 @@ namespace blas { | |
namespace portblas { | ||
namespace column_major { | ||
|
||
constexpr bool is_column_major() { | ||
return true; | ||
} | ||
|
||
// BUFFER | ||
void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, | ||
std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, | ||
sycl::buffer<sycl::half, 1> &a, std::int64_t lda, sycl::buffer<sycl::half, 1> &b, | ||
std::int64_t ldb, sycl::half beta, sycl::buffer<sycl::half, 1> &c, std::int64_t ldc) { | ||
#ifdef ENABLE_PORTBLAS_HALF | ||
CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, | ||
ldc); | ||
#else | ||
throw unimplemented("blas", "gemm", " half"); | ||
#endif | ||
} | ||
|
||
void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, | ||
std::int64_t m, std::int64_t n, std::int64_t k, float alpha, | ||
sycl::buffer<sycl::half, 1> &a, std::int64_t lda, sycl::buffer<sycl::half, 1> &b, | ||
std::int64_t ldb, float beta, sycl::buffer<float, 1> &c, std::int64_t ldc) { | ||
#ifdef ENABLE_PORTBLAS_HALF | ||
CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, | ||
ldc); | ||
#else | ||
throw unimplemented("blas", "gemm", " for different argument data types"); | ||
#endif | ||
} | ||
|
||
// USM | ||
|
@@ -53,31 +68,56 @@ sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl: | |
const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, | ||
sycl::half beta, sycl::half *c, std::int64_t ldc, | ||
const std::vector<sycl::event> &dependencies) { | ||
#ifdef ENABLE_PORTBLAS_HALF | ||
CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, | ||
c, ldc, dependencies); | ||
#else | ||
throw unimplemented("blas", "gemm", " for USM"); | ||
#endif | ||
} | ||
|
||
sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, | ||
std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, | ||
std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, | ||
std::int64_t ldc, const std::vector<sycl::event> &dependencies) { | ||
#ifdef ENABLE_PORTBLAS_HALF | ||
CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, | ||
c, ldc, dependencies); | ||
#else | ||
throw unimplemented("blas", "gemm", " for USM"); | ||
#endif | ||
} | ||
} // namespace column_major | ||
|
||
namespace row_major { | ||
|
||
constexpr bool is_column_major() { | ||
return false; | ||
} | ||
|
||
// BUFFER | ||
void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, | ||
std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, | ||
sycl::buffer<sycl::half, 1> &a, std::int64_t lda, sycl::buffer<sycl::half, 1> &b, | ||
std::int64_t ldb, sycl::half beta, sycl::buffer<sycl::half, 1> &c, std::int64_t ldc) { | ||
#ifdef ENABLE_PORTBLAS_HALF | ||
CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, | ||
ldc); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, we don't support any row major operator in portBLAS, so they are generally not enabled. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are already checks that throw unimplemented when |
||
#else | ||
throw unimplemented("blas", "gemm", " half"); | ||
#endif | ||
} | ||
|
||
void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, | ||
std::int64_t m, std::int64_t n, std::int64_t k, float alpha, | ||
sycl::buffer<sycl::half, 1> &a, std::int64_t lda, sycl::buffer<sycl::half, 1> &b, | ||
std::int64_t ldb, float beta, sycl::buffer<float, 1> &c, std::int64_t ldc) { | ||
#ifdef ENABLE_PORTBLAS_HALF | ||
CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, | ||
ldc); | ||
#else | ||
throw unimplemented("blas", "gemm", " for different argument data types"); | ||
#endif | ||
} | ||
|
||
// USM | ||
|
@@ -86,14 +126,24 @@ sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl: | |
const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, | ||
sycl::half beta, sycl::half *c, std::int64_t ldc, | ||
const std::vector<sycl::event> &dependencies) { | ||
#ifdef ENABLE_PORTBLAS_HALF | ||
CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, | ||
c, ldc, dependencies); | ||
#else | ||
throw unimplemented("blas", "gemm", " for USM"); | ||
#endif | ||
} | ||
|
||
sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, | ||
std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, | ||
std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, | ||
std::int64_t ldc, const std::vector<sycl::event> &dependencies) { | ||
#ifdef ENABLE_PORTBLAS_HALF | ||
CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, | ||
c, ldc, dependencies); | ||
#else | ||
throw unimplemented("blas", "gemm", " for USM"); | ||
#endif | ||
} | ||
|
||
} // namespace row_major | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to #554 and PR #571 this var should be named with the prefix
ONEAPI_ONEMKL_
. Could you update it?I know it would be the only one named correctly now, but I hope that PR will be merged soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I was planning to update this PR once #571 is merged. It will be easier to do once I can add
ONEAPI_ONEMKL_ENABLE_PORTBLAS_HALF
to the list you introduce in https://github.com/oneapi-src/oneMKL/pull/571/files#diff-148715d6ea0c0ea0a346af3f6bd610d010d490eca35ac6a9b408748f7ca9e3f4R54