[WIP] GEMM: add support for mixed input scalar types #1557

Draft · wants to merge 26 commits into base: develop

Commits (26)
76f494c  GEMM: Move Serial, Team and TeamVector implementations to KokkosBlas (Sep 7, 2022)
7ac1184  GEMM: move unit tests to Blas (Sep 7, 2022)
51185f3  GEMM: move MKL implementation of SerialGemm to dedicated TPL header (Sep 9, 2022)
016095b  GEMM: connect TeamVectorGemm to the selective interface (Sep 9, 2022)
d945717  GEMM: implicit MemberType (Sep 9, 2022)
88de7f2  GEMM: bring back batched interfaces for backward compatibility (Sep 22, 2022)
38c9989  MKL: move utils to common header + fix macro duplication (Sep 22, 2022)
de3d26c  GEMM: fix "batched" in names (Sep 27, 2022)
2105dd9  GEMM: refactor crossing of A/B matrix transposes (Sep 13, 2022)
5b6712b  GEMM: implement ConjTranspose (Sep 15, 2022)
fc5bbb1  TeamGemv: move {Team,TeamVector}Internal to KokkosBlas2_team_gemv_int… (Sep 10, 2022)
9f99fc4  TeamGemv: rename impl header (Sep 10, 2022)
13a16d1  TeamGemv: remove unused headers (Sep 10, 2022)
8d06e07  Gemv: move functor-level interfaces to the top-level header (Sep 10, 2022)
2bcaff5  Gemv: implicit MemberType (Sep 26, 2022)
643e8be  Merge branch 'gemm-transpose-refactoring' into gemv-transpose-refacto… (Sep 28, 2022)
4e145e3  Merge branch 'fix-gemv-blas-headers' into gemv-transpose-refactoring (Sep 28, 2022)
d989d26  GEMV: refactor A matrix transpose (Sep 27, 2022)
7728904  Merge branch 'gemm-add-conjtranspose' into threadvector-kernels (Sep 28, 2022)
900226a  Merge branch 'gemv-transpose-refactoring' into threadvector-kernels (Sep 28, 2022)
cc8e867  SET/SCAL: add ThreadVector implementations and unit tests (Sep 27, 2022)
61a1993  GEMV: add ThreadVector implementation and unit test (Sep 27, 2022)
d079ea6  GEMM: add ThreadVector implementation and unit test (Sep 27, 2022)
1823f05  GEMM: Support for mixed scalar types in functor-level kernels (Sep 27, 2022)
cb81dff  GEMM: fix compilation errors in InnerGemmFixB (Sep 29, 2022)
c0e0e25  GEMM: add mixed scalar support to InnerGemmFix{A,B} (Sep 29, 2022)
@@ -12,7 +12,6 @@
#include "KokkosBatched_Schur_Serial_Internal.hpp"
#include "KokkosBatched_RightEigenvectorFromSchur_Serial_Internal.hpp"
#include "KokkosBatched_LeftEigenvectorFromSchur_Serial_Internal.hpp"
#include "KokkosBatched_Gemm_Serial_Internal.hpp"

namespace KokkosBatched {

344 changes: 5 additions & 339 deletions batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp
@@ -43,343 +43,9 @@
#define __KOKKOSBATCHED_GEMM_SERIAL_IMPL_HPP__

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Gemm_Serial_Internal.hpp"
#include "KokkosBlas3_gemm.hpp"

namespace KokkosBatched {
/********************* BEGIN functor-level routines *********************/
///
/// Serial Impl
/// ===========

///
/// Implemented:
/// NT/NT, T/NT, NT/T, T/T
///
/// Not yet implemented (ConjTranspose):
/// CT/NT, NT/CT, CT/CT
///

///
/// NT/NT
///

#if defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL__) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_BATCHED__) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::NoTranspose, Trans::NoTranspose,
Algo::Gemm::CompactMKL>::invoke(const ScalarType alpha,
const AViewType &A,
const BViewType &B,
const ScalarType beta,
const CViewType &C) {
typedef typename CViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

const int m = C.extent(0), n = C.extent(1), k = A.extent(1);

static_assert(is_vector<vector_type>::value, "value type is not vector type");
static_assert(
vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
const MKL_COMPACT_PACK format =
vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
int r_val = 0;
if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) {
mkl_dgemm_compact(MKL_COL_MAJOR, MKL_NOTRANS, MKL_NOTRANS, m, n, k, alpha,
(const double *)A.data(), A.stride_1(),
(const double *)B.data(), B.stride_1(), beta,
(double *)C.data(), C.stride_1(), format,
(MKL_INT)vector_type::vector_length);
} else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) {
mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_NOTRANS, MKL_NOTRANS, m, n, k, alpha,
(const double *)A.data(), A.stride_0(),
(const double *)B.data(), B.stride_0(), beta,
(double *)C.data(), C.stride_0(), format,
(MKL_INT)vector_type::vector_length);
} else {
r_val = -1;
}
return r_val;
}
#endif

template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::NoTranspose, Trans::NoTranspose,
Algo::Gemm::Unblocked>::invoke(const ScalarType alpha,
const AViewType &A,
const BViewType &B,
const ScalarType beta,
const CViewType &C) {
// C = beta C + alpha A B
// C (m x n), A(m x k), B(k x n)
return SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(),
A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(),
C.stride_0(), C.stride_1());
}

template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::NoTranspose, Trans::NoTranspose, Algo::Gemm::Blocked>::invoke(
const ScalarType alpha, const AViewType &A, const BViewType &B,
const ScalarType beta, const CViewType &C) {
// C = beta C + alpha A B
// C (m x n), A(m x k), B(k x n)
return SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(),
A.stride_1(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(),
C.stride_0(), C.stride_1());
}

///
/// T/NT
///

#if defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL__) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_BATCHED__) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::Transpose, Trans::NoTranspose,
Algo::Gemm::CompactMKL>::invoke(const ScalarType alpha,
const AViewType &A,
const BViewType &B,
const ScalarType beta,
const CViewType &C) {
typedef typename CViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

const int m = C.extent(0), n = C.extent(1), k = A.extent(0);

static_assert(is_vector<vector_type>::value, "value type is not vector type");
static_assert(
vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
const MKL_COMPACT_PACK format =
vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
int r_val = 0;
if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) {
mkl_dgemm_compact(MKL_COL_MAJOR, MKL_TRANS, MKL_NOTRANS, m, n, k, alpha,
(const double *)A.data(), A.stride_1(),
(const double *)B.data(), B.stride_1(), beta,
(double *)C.data(), C.stride_1(), format,
(MKL_INT)vector_type::vector_length);
} else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) {
mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_TRANS, MKL_NOTRANS, m, n, k, alpha,
(const double *)A.data(), A.stride_0(),
(const double *)B.data(), B.stride_0(), beta,
(double *)C.data(), C.stride_0(), format,
(MKL_INT)vector_type::vector_length);
} else {
r_val = -1;
}
return r_val;
}
#endif

template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::Transpose, Trans::NoTranspose, Algo::Gemm::Unblocked>::invoke(
const ScalarType alpha, const AViewType &A, const BViewType &B,
const ScalarType beta, const CViewType &C) {
// C = beta C + alpha A^T B
// C (m x n), A(k x m), B(k x n)
return SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(),
C.stride_0(), C.stride_1());
}

template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::Transpose, Trans::NoTranspose, Algo::Gemm::Blocked>::invoke(
const ScalarType alpha, const AViewType &A, const BViewType &B,
const ScalarType beta, const CViewType &C) {
// C = beta C + alpha A^T B
// C (m x n), A(k x m), B(k x n)
return SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
A.stride_0(), B.data(), B.stride_0(), B.stride_1(), beta, C.data(),
C.stride_0(), C.stride_1());
}

///
/// NT/T
///

#if defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL__) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_BATCHED__) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::NoTranspose, Trans::Transpose,
Algo::Gemm::CompactMKL>::invoke(const ScalarType alpha,
const AViewType &A,
const BViewType &B,
const ScalarType beta,
const CViewType &C) {
typedef typename CViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

const int m = C.extent(0), n = C.extent(1), k = A.extent(1);

static_assert(is_vector<vector_type>::value, "value type is not vector type");
static_assert(
vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
const MKL_COMPACT_PACK format =
vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
int r_val = 0;
if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) {
mkl_dgemm_compact(MKL_COL_MAJOR, MKL_NOTRANS, MKL_TRANS, m, n, k, alpha,
(const double *)A.data(), A.stride_1(),
(const double *)B.data(), B.stride_1(), beta,
(double *)C.data(), C.stride_1(), format,
(MKL_INT)vector_type::vector_length);
} else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) {
mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_NOTRANS, MKL_TRANS, m, n, k, alpha,
(const double *)A.data(), A.stride_0(),
(const double *)B.data(), B.stride_0(), beta,
(double *)C.data(), C.stride_0(), format,
(MKL_INT)vector_type::vector_length);
} else {
r_val = -1;
}
return r_val;
}
#endif

template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::NoTranspose, Trans::Transpose, Algo::Gemm::Unblocked>::invoke(
const ScalarType alpha, const AViewType &A, const BViewType &B,
const ScalarType beta, const CViewType &C) {
// C = beta C + alpha A B^T
// C (m x n), A(m x k), B(n x k)
return SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(),
A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(),
C.stride_0(), C.stride_1());
}

template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::NoTranspose, Trans::Transpose, Algo::Gemm::Blocked>::invoke(
const ScalarType alpha, const AViewType &A, const BViewType &B,
const ScalarType beta, const CViewType &C) {
// C = beta C + alpha A B^T
// C (m x n), A(m x k), B(n x k)
return SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
C.extent(0), C.extent(1), A.extent(1), alpha, A.data(), A.stride_0(),
A.stride_1(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(),
C.stride_0(), C.stride_1());
}

///
/// T/T
///

#if defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL__) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_BATCHED__) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::CompactMKL>::invoke(
const ScalarType alpha, const AViewType &A, const BViewType &B,
const ScalarType beta, const CViewType &C) {
typedef typename CViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

const int m = C.extent(0), n = C.extent(1), k = A.extent(0);

static_assert(is_vector<vector_type>::value, "value type is not vector type");
static_assert(
vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
const MKL_COMPACT_PACK format =
vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
int r_val = 0;
if (A.stride_0() == 1 && B.stride_0() == 1 && C.stride_0() == 1) {
mkl_dgemm_compact(MKL_COL_MAJOR, MKL_TRANS, MKL_TRANS, m, n, k, alpha,
(const double *)A.data(), A.stride_1(),
(const double *)B.data(), B.stride_1(), beta,
(double *)C.data(), C.stride_1(), format,
(MKL_INT)vector_type::vector_length);
} else if (A.stride_1() == 1 && B.stride_1() == 1 && C.stride_1() == 1) {
mkl_dgemm_compact(MKL_ROW_MAJOR, MKL_TRANS, MKL_TRANS, m, n, k, alpha,
(const double *)A.data(), A.stride_0(),
(const double *)B.data(), B.stride_0(), beta,
(double *)C.data(), C.stride_0(), format,
(MKL_INT)vector_type::vector_length);
} else {
r_val = -1;
}
return r_val;
}
#endif

template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::Unblocked>::invoke(
const ScalarType alpha, const AViewType &A, const BViewType &B,
const ScalarType beta, const CViewType &C) {
// C = beta C + alpha A^T B^T
// C (m x n), A(k x m), B(n x k)
return SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(),
C.stride_0(), C.stride_1());
}

template <>
template <typename ScalarType, typename AViewType, typename BViewType,
typename CViewType>
KOKKOS_INLINE_FUNCTION int
SerialGemm<Trans::Transpose, Trans::Transpose, Algo::Gemm::Blocked>::invoke(
const ScalarType alpha, const AViewType &A, const BViewType &B,
const ScalarType beta, const CViewType &C) {
// C = beta C + alpha A^T B^T
// C (m x n), A(k x m), B(n x k)
return SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
C.extent(0), C.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
A.stride_0(), B.data(), B.stride_1(), B.stride_0(), beta, C.data(),
C.stride_0(), C.stride_1());
}
/********************* END functor-level routines *********************/

namespace Impl {
/********************* BEGIN non-functor-level routines *********************/
template <class ArgTransA, class ArgTransB, class ArgMode, class ArgBatchSzDim,
@@ -467,9 +133,9 @@ class BatchedSerialGemm {
// matrix transpositions, here we must perform the GEMM on:
// row_vec x col_vec, which is svA_row' x svB_col to compute the element
// of C.
KokkosBatched::SerialGemm<Trans::Transpose, Trans::NoTranspose,
ArgMode>::invoke(alpha, svA_row, svB_col, beta,
svC_ele);
KokkosBlas::SerialGemm<Trans::Transpose, Trans::NoTranspose,
ArgMode>::invoke(alpha, svA_row, svB_col, beta,
svC_ele);
}

KOKKOS_INLINE_FUNCTION
@@ -481,7 +147,7 @@ class BatchedSerialGemm {
auto svC =
subview_wrapper(C, i, Kokkos::ALL(), Kokkos::ALL(), batch_layout_tag);

KokkosBatched::SerialGemm<ArgTransA, ArgTransB, ArgMode>::invoke(
KokkosBlas::SerialGemm<ArgTransA, ArgTransB, ArgMode>::invoke(
alpha, svA, svB, beta, svC);
}
};
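
Commit 88de7f2 ("bring back batched interfaces for backward compatibility") suggests the old KokkosBatched entry points remain available after the move to KokkosBlas. A hypothetical shim along these lines, not shown in this diff, would keep existing call sites compiling:

```cpp
// Hypothetical backward-compatibility shim (not taken from this diff):
// forward the old KokkosBatched::SerialGemm name to the relocated
// KokkosBlas::SerialGemm implementation.
#include "KokkosBlas3_gemm.hpp"

namespace KokkosBatched {
template <typename ArgTransA, typename ArgTransB, typename ArgAlgo>
using SerialGemm = KokkosBlas::SerialGemm<ArgTransA, ArgTransB, ArgAlgo>;
}  // namespace KokkosBatched
```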