Skip to content

Commit

Permalink
mode_kk_to_onemkl and trans_mode_kk_to_mkl -> trans_mode_kk_to_onemkl
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Dec 7, 2023
1 parent c241523 commit 257ad0c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 22 deletions.
13 changes: 1 addition & 12 deletions blas/tpls/KokkosBlas2_gemv_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,17 +777,6 @@ KOKKOSBLAS2_CGEMV_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIPSpace, false)
namespace KokkosBlas {
namespace Impl {

inline oneapi::mkl::transpose mode_kk_to_onemkl(char mode_kk) {
switch (toupper(mode_kk)) {
case 'N': return oneapi::mkl::transpose::nontrans;
case 'T': return oneapi::mkl::transpose::trans;
case 'C': return oneapi::mkl::transpose::conjtrans;
default:;
}
throw std::invalid_argument(
"Invalid mode for oneMKL (should be one of N, T, C)");
}

template <typename T, bool is_complex = false>
struct kokkos_to_std_type_map {
using type = T;
Expand Down Expand Up @@ -829,7 +818,7 @@ struct kokkos_to_std_type_map<T, true> {
bool row_major = std::is_same<Kokkos::LayoutRight, LAYOUT>::value; \
const std::int64_t M = A.extent(0); \
const std::int64_t N = A.extent(1); \
oneapi::mkl::transpose trans = mode_kk_to_onemkl(kk_trans[0]); \
oneapi::mkl::transpose trans = trans_mode_kk_to_onemkl(kk_trans[0]); \
const std::int64_t LDA = row_major ? A.stride(0) : A.stride(1); \
std::string label = "KokkosBlas::gemv[TPL_ONEMKL," + \
Kokkos::ArithTraits<SCALAR>::name() + "]"; \
Expand Down
4 changes: 2 additions & 2 deletions blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,8 @@ TPL_SCALAR_TYPE is the type MKL accents for SCALAR_TYPE
const int64_t cst = is_lr ? C.stride(0) : C.stride(1); \
const int64_t ldc = cst == 0 ? 1 : cst; \
\
oneapi::mkl::transpose transa = trans_mode_kk_to_mkl(transA); \
oneapi::mkl::transpose transb = trans_mode_kk_to_mkl(transB); \
oneapi::mkl::transpose transa = trans_mode_kk_to_onemkl(transA[0]); \
oneapi::mkl::transpose transb = trans_mode_kk_to_onemkl(transB[0]); \
oneapi::mkl::blas::compute_mode mode = oneapi::mkl::blas::compute_mode::standard; \
\
if constexpr (!is_lr) { \
Expand Down
17 changes: 9 additions & 8 deletions blas/tpls/KokkosBlas_tpl_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,15 @@ namespace Impl {


/// \brief This function converts KK transpose mode to MKL transpose mode
inline oneapi::mkl::transpose trans_mode_kk_to_mkl(const char kkMode[]) {
oneapi::mkl::transpose trans;
if ((kkMode[0] == 'N') || (kkMode[0] == 'n'))
return oneapi::mkl::transpose::N;
else if ((kkMode[0] == 'T') || (kkMode[0] == 't'))
return oneapi::mkl::transpose::T;
else
return oneapi::mkl::transpose::C;
inline oneapi::mkl::transpose trans_mode_kk_to_onemkl(char mode_kk) {
switch (toupper(mode_kk)) {
case 'N': return oneapi::mkl::transpose::nontrans;
case 'T': return oneapi::mkl::transpose::trans;
case 'C': return oneapi::mkl::transpose::conjtrans;
default:;
}
throw std::invalid_argument(
"Invalid mode for oneMKL (should be one of N, T, C)");
}

} // namespace Impl
Expand Down

0 comments on commit 257ad0c

Please sign in to comment.