Skip to content

Commit

Permalink
Add oneMKL GEMM support for SYCL
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Dec 7, 2023
1 parent 4ce619e commit 763c7d7
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 12 deletions.
13 changes: 1 addition & 12 deletions blas/tpls/KokkosBlas2_gemv_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,17 +777,6 @@ KOKKOSBLAS2_CGEMV_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIPSpace, false)
namespace KokkosBlas {
namespace Impl {

inline oneapi::mkl::transpose mode_kk_to_onemkl(char mode_kk) {
switch (toupper(mode_kk)) {
case 'N': return oneapi::mkl::transpose::nontrans;
case 'T': return oneapi::mkl::transpose::trans;
case 'C': return oneapi::mkl::transpose::conjtrans;
default:;
}
throw std::invalid_argument(
"Invalid mode for oneMKL (should be one of N, T, C)");
}

template <typename T, bool is_complex = false>
struct kokkos_to_std_type_map {
using type = T;
Expand Down Expand Up @@ -829,7 +818,7 @@ struct kokkos_to_std_type_map<T, true> {
bool row_major = std::is_same<Kokkos::LayoutRight, LAYOUT>::value; \
const std::int64_t M = A.extent(0); \
const std::int64_t N = A.extent(1); \
oneapi::mkl::transpose trans = mode_kk_to_onemkl(kk_trans[0]); \
oneapi::mkl::transpose trans = trans_mode_kk_to_onemkl(kk_trans[0]); \
const std::int64_t LDA = row_major ? A.stride(0) : A.stride(1); \
std::string label = "KokkosBlas::gemv[TPL_ONEMKL," + \
Kokkos::ArithTraits<SCALAR>::name() + "]"; \
Expand Down
40 changes: 40 additions & 0 deletions blas/tpls/KokkosBlas3_gemm_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,46 @@ KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_ROCBLAS(Kokkos::complex<float>,
Kokkos::LayoutRight, Kokkos::HIPSpace)

#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)

#define KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(SCALAR, LAYOUT, MEMSPACE) \
template <> \
struct gemm_tpl_spec_avail< \
Kokkos::Experimental::SYCL, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > > { \
enum : bool { value = true }; \
};

KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(double, Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(float, Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex<double>, Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex<float>, Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace)

KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(double, Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(float, Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex<double>,
Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace)
KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex<float>, Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace)

#endif

} // namespace Impl
} // namespace KokkosBlas

Expand Down
140 changes: 140 additions & 0 deletions blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,4 +501,144 @@ KOKKOSBLAS3_CGEMM_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIPSpace, false)
} // namespace KokkosBlas
#endif // KOKKOSKERNELS_ENABLE_TPL_ROCBLAS

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
#include <KokkosBlas_tpl_spec.hpp>
#include <oneapi/mkl/blas.hpp>

namespace KokkosBlas::Impl {

/*!
SCALAR_TYPE is the Kokkos Kernels type
TPL_SCALAR_TYPE is the type MKL accents for SCALAR_TYPE
*/
#define KOKKOSBLAS3_XGEMM_MKL(SCALAR_TYPE, TPL_SCALAR_TYPE, LAYOUT, MEM_SPACE, \
ETI_SPEC_AVAIL) \
template <> \
struct GEMM< \
Kokkos::Experimental::SYCL, \
Kokkos::View<const SCALAR_TYPE**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<const SCALAR_TYPE**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<SCALAR_TYPE**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
true, ETI_SPEC_AVAIL> { \
typedef SCALAR_TYPE SCALAR; \
typedef Kokkos::View< \
const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > \
AViewType; \
typedef Kokkos::View< \
const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > \
BViewType; \
typedef Kokkos::View< \
SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > \
CViewType; \
\
static void gemm(const typename CViewType::execution_space& space, \
const char transA[], const char transB[], \
typename AViewType::const_value_type& alpha, \
const AViewType& A, const BViewType& B, \
typename CViewType::const_value_type& beta, \
const CViewType& C) { \
Kokkos::Profiling::pushRegion("KokkosBlas::gemm[TPL_MKL," #SCALAR_TYPE \
"]"); \
\
const bool A_t = (transA[0] != 'N') && (transA[0] != 'n'); \
const int64_t M = static_cast<int64_t>(C.extent(0)); \
const int64_t N = static_cast<int64_t>(C.extent(1)); \
const int64_t K = static_cast<int64_t>(A.extent(A_t ? 0 : 1)); \
\
constexpr bool is_lr = std::is_same<Kokkos::LayoutRight, LAYOUT>::value; \
\
const int64_t ast = is_lr ? A.stride(0) : A.stride(1); \
const int64_t lda = ast == 0 ? 1 : ast; \
const int64_t bst = is_lr ? B.stride(0) : B.stride(1); \
const int64_t ldb = bst == 0 ? 1 : bst; \
const int64_t cst = is_lr ? C.stride(0) : C.stride(1); \
const int64_t ldc = cst == 0 ? 1 : cst; \
\
oneapi::mkl::transpose transa = trans_mode_kk_to_onemkl(transA[0]); \
oneapi::mkl::transpose transb = trans_mode_kk_to_onemkl(transB[0]); \
oneapi::mkl::blas::compute_mode mode = \
oneapi::mkl::blas::compute_mode::standard; \
\
if constexpr (!is_lr) { \
oneapi::mkl::blas::column_major::gemm( \
space.sycl_queue(), transa, transb, M, N, K, alpha, \
reinterpret_cast<const TPL_SCALAR_TYPE*>(A.data()), lda, \
reinterpret_cast<const TPL_SCALAR_TYPE*>(B.data()), ldb, beta, \
reinterpret_cast<TPL_SCALAR_TYPE*>(C.data()), ldc, mode); \
} else { \
oneapi::mkl::blas::row_major::gemm( \
space.sycl_queue(), transa, transb, M, N, K, alpha, \
reinterpret_cast<const TPL_SCALAR_TYPE*>(A.data()), lda, \
reinterpret_cast<const TPL_SCALAR_TYPE*>(B.data()), ldb, beta, \
reinterpret_cast<TPL_SCALAR_TYPE*>(C.data()), ldc, mode); \
} \
\
Kokkos::Profiling::popRegion(); \
} \
};

#define KOKKOSBLAS3_DGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
KOKKOSBLAS3_XGEMM_MKL(double, double, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL)

#define KOKKOSBLAS3_SGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
KOKKOSBLAS3_XGEMM_MKL(float, float, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL)

#define KOKKOSBLAS3_ZGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
KOKKOSBLAS3_XGEMM_MKL(Kokkos::complex<double>, std::complex<double>, LAYOUT, \
MEM_SPACE, ETI_SPEC_AVAIL)

#define KOKKOSBLAS3_CGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
KOKKOSBLAS3_XGEMM_MKL(Kokkos::complex<float>, std::complex<float>, LAYOUT, \
MEM_SPACE, ETI_SPEC_AVAIL)

KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)

KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)

KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)

KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutLeft,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, true)
KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutRight,
Kokkos::Experimental::SYCLDeviceUSMSpace, false)
} // namespace KokkosBlas::Impl
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL && KOKKOS_ENABLE_SYCL

#endif
23 changes: 23 additions & 0 deletions blas/tpls/KokkosBlas_tpl_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,4 +231,27 @@ struct MagmaSingleton {
} // namespace KokkosBlas
#endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
#include <oneapi/mkl/types.hpp>

namespace KokkosBlas {
namespace Impl {

/// \brief This function converts KK transpose mode to MKL transpose mode
inline oneapi::mkl::transpose trans_mode_kk_to_onemkl(char mode_kk) {
switch (toupper(mode_kk)) {
case 'N': return oneapi::mkl::transpose::nontrans;
case 'T': return oneapi::mkl::transpose::trans;
case 'C': return oneapi::mkl::transpose::conjtrans;
default:;
}
throw std::invalid_argument(
"Invalid mode for oneMKL (should be one of N, T, C)");
}

} // namespace Impl
} // namespace KokkosBlas

#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

#endif // KOKKOSBLAS_TPL_SPEC_HPP_

0 comments on commit 763c7d7

Please sign in to comment.