Skip to content

Commit

Permalink
TPL: revise BLAS1 dot implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jczhang07 committed Aug 24, 2023
1 parent 87fee56 commit b746154
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 142 deletions.
44 changes: 28 additions & 16 deletions blas/src/KokkosBlas1_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,37 @@ dot(const execution_space& space, const XVector& x, const YVector& y) {
Kokkos::View<result_type, default_layout, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

result_type result{};
RVector_Result R = RVector_Result(&result);
XVector_Internal X = x;
YVector_Internal Y = y;

// Even though RVector is the template parameter, Dot::dot has an overload
// that accepts RVector_Internal (with the special accumulator, if dot_type is
// 32-bit precision). Impl::Dot needs to support both cases, and it's easier
// to do this with overloading than by extending the ETI to deal with two
// different scalar types.
Impl::DotSpecialAccumulator<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>::dot(space, R,
X, Y);
space.fence();
// mfh 22 Jan 2020: We need the line below because
// Kokkos::complex<T> lacks a constructor that takes a
// Kokkos::complex<U> with U != T.
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
bool useFallback = false;
if (useFallback) {
// Even though RVector is the template parameter, Dot::dot has an overload
// that accepts RVector_Internal (with the special accumulator, if dot_type
// is 32-bit precision). Impl::Dot needs to support both cases, and it's
// easier to do this with overloading than by extending the ETI to deal with
// two different scalar types.
result_type result{};
RVector_Result R = RVector_Result(&result);
Impl::DotSpecialAccumulator<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>::dot(space,
R, X,
Y);
space.fence();
// mfh 22 Jan 2020: We need the line below because
// Kokkos::complex<T> lacks a constructor that takes a
// Kokkos::complex<U> with U != T.
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
} else {
dot_type result{};
RVector_Internal R = RVector_Internal(&result);
Impl::Dot<execution_space, RVector_Internal, XVector_Internal,
YVector_Internal>::dot(space, R, X, Y);
space.fence();
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
}
}

/// \brief Return the dot product of the two vectors x and y.
Expand Down
34 changes: 19 additions & 15 deletions blas/tpls/KokkosBlas1_dot_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,

#endif

// cuBLAS
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
// double
#define KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(SCALAR, LAYOUT, EXECSPACE, \
MEMSPACE) \
#define KOKKOSBLAS1_DOT_TPL_SPEC(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
template <> \
struct dot_tpl_spec_avail< \
EXECSPACE, \
Expand All @@ -77,19 +73,27 @@ KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,
enum : bool { value = true }; \
};

KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(double, Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(float, Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<double>,
Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<float>,
Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
#define KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(float, LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(double, LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(Kokkos::complex<float>, LAYOUT, EXECSPACE, \
MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(Kokkos::complex<double>, LAYOUT, EXECSPACE, MEMSPACE)

#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCBLAS
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Experimental::HIP,
Kokkos::Experimental::HIPSpace)
#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Experimental::SYCL,
Kokkos::Experimental::SYCLDeviceUSMSpace)
#endif
} // namespace Impl
} // namespace KokkosBlas
#endif
Loading

0 comments on commit b746154

Please sign in to comment.