Skip to content

Commit

Permalink
TPL: revise BLAS1 dot implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jczhang07 committed Aug 23, 2023
1 parent 87fee56 commit 708e86c
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 134 deletions.
24 changes: 16 additions & 8 deletions blas/src/KokkosBlas1_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,22 @@ dot(const execution_space& space, const XVector& x, const YVector& y) {
XVector_Internal X = x;
YVector_Internal Y = y;

// Even though RVector is the template parameter, Dot::dot has an overload
// that accepts RVector_Internal (with the special accumulator, if dot_type is
// 32-bit precision). Impl::Dot needs to support both cases, and it's easier
// to do this with overloading than by extending the ETI to deal with two
// different scalar types.
Impl::DotSpecialAccumulator<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>::dot(space, R,
X, Y);
bool useFallback = false;
if (useFallback) {
// Even though RVector is the template parameter, Dot::dot has an overload
// that accepts RVector_Internal (with the special accumulator, if dot_type
// is 32-bit precision). Impl::Dot needs to support both cases, and it's
// easier to do this with overloading than by extending the ETI to deal with
// two different scalar types.
Impl::DotSpecialAccumulator<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>::dot(space,
R, X,
Y);
} else {
using impl_type = Impl::Dot<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>;
impl_type::dot(space, R, X, Y);
}
space.fence();
// mfh 22 Jan 2020: We need the line below because
// Kokkos::complex<T> lacks a constructor that takes a
Expand Down
34 changes: 19 additions & 15 deletions blas/tpls/KokkosBlas1_dot_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,

#endif

// cuBLAS
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
// double
#define KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(SCALAR, LAYOUT, EXECSPACE, \
MEMSPACE) \
#define KOKKOSBLAS1_DOT_TPL_SPEC(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
template <> \
struct dot_tpl_spec_avail< \
EXECSPACE, \
Expand All @@ -77,19 +73,27 @@ KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,
enum : bool { value = true }; \
};

KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(double, Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(float, Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<double>,
Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<float>,
Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
// Convenience wrapper: for one (LAYOUT, EXECSPACE, MEMSPACE) combination,
// invokes KOKKOSBLAS1_DOT_TPL_SPEC once per supported scalar type -- float,
// double, Kokkos::complex<float> and Kokkos::complex<double> -- so that a
// dot_tpl_spec_avail specialization is emitted for each of them.
// Intended to be invoked inside the backend-specific #if/#ifdef guards.
#define KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(float, LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(double, LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(Kokkos::complex<float>, LAYOUT, EXECSPACE, \
MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(Kokkos::complex<double>, LAYOUT, EXECSPACE, MEMSPACE)

// cuBLAS: only when the cuBLAS TPL is enabled, flag the TPL dot as available
// for LayoutLeft vectors on the Kokkos::Cuda execution space with
// Kokkos::CudaSpace memory (all four scalar types covered by the macro).
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
#endif

// rocBLAS: only when the rocBLAS TPL is enabled, flag the TPL dot as
// available for LayoutLeft vectors on the HIP execution space with HIPSpace
// memory (all four scalar types covered by the macro).
// NOTE(review): this uses the Kokkos::Experimental:: spelling of HIP/HIPSpace;
// newer Kokkos releases presumably expose these as Kokkos::HIP /
// Kokkos::HIPSpace -- confirm against the minimum supported Kokkos version.
#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCBLAS
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Experimental::HIP,
Kokkos::Experimental::HIPSpace)
#endif

// MKL on SYCL: requires BOTH the MKL TPL and the Kokkos SYCL backend to be
// enabled. Flags the TPL dot as available for LayoutLeft vectors on the SYCL
// execution space with SYCLDeviceUSMSpace memory (all four scalar types
// covered by the macro).
#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Experimental::SYCL,
Kokkos::Experimental::SYCLDeviceUSMSpace)
#endif
} // namespace Impl
} // namespace KokkosBlas
#endif
Loading

0 comments on commit 708e86c

Please sign in to comment.