diff --git a/lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp b/lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp index 09d9a7ca68..bdf4ab43b2 100644 --- a/lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp +++ b/lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp @@ -45,229 +45,81 @@ inline void gesv_print_specialization() { namespace KokkosLapack { namespace Impl { -#define KOKKOSLAPACK_DGESV_LAPACK(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ +template +void lapackGesvWrapper(const AViewType& A, const BViewType& B, + const IPIVViewType& IPIV) { + using Scalar = typename AViewType::non_const_value_type; + + const bool with_pivot = + !((IPIV.extent(0) == 0) && (IPIV.data() == nullptr)); + + const int N = static_cast(A.extent(1)); + const int AST = static_cast(A.stride(1)); + const int LDA = (AST == 0) ? 1 : AST; + const int BST = static_cast(B.stride(1)); + const int LDB = (BST == 0) ? 1 : BST; + const int NRHS = static_cast(B.extent(1)); + + int info = 0; + + if (with_pivot) { + if constexpr (Kokkos::ArithTraits::is_complex) { + using MagType = typename Kokkos::ArithTraits::mag_type; + + HostLapack>::gesv( + N, NRHS, reinterpret_cast*>(A.data()), LDA, + IPIV.data(), reinterpret_cast*>(B.data()), + LDB, info); + } else { + HostLapack::gesv(N, NRHS, A.data(), LDA, IPIV.data(), + B.data(), LDB, info); + } + } +} + +#define KOKKOSLAPACK_GESV_LAPACK(SCALAR, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ template \ - struct GESV< \ - Kokkos::View, \ + struct GESV< \ + ExecSpace, \ + Kokkos::View, \ Kokkos::MemoryTraits>, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits>, \ Kokkos::View, \ Kokkos::MemoryTraits>, \ true, ETI_SPEC_AVAIL> { \ - typedef double SCALAR; \ - typedef Kokkos::View, \ - Kokkos::MemoryTraits> \ - AViewType; \ - typedef Kokkos::View, \ - Kokkos::MemoryTraits> \ - BViewType; \ - typedef Kokkos::View< \ - int*, LAYOUT, \ - Kokkos::Device, \ - Kokkos::MemoryTraits> \ - PViewType; \ + using AViewType = Kokkos::View, \ + Kokkos::MemoryTraits>; \ + using BViewType = Kokkos::View, \ + Kokkos::MemoryTraits>; \ + using PViewType = Kokkos::View>; \ \ - static void gesv(const AViewType& A, const BViewType& B, \ - const PViewType& IPIV) { \ - Kokkos::Profiling::pushRegion("KokkosLapack::gesv[TPL_LAPACK,double]"); \ + static void gesv(const ExecSpace& /* space */, const AViewType& A, \ + const BViewType& B, const PViewType& IPIV) { \ + Kokkos::Profiling::pushRegion("KokkosLapack::gesv[TPL_LAPACK,"#SCALAR"]"); \ gesv_print_specialization(); \ - const bool with_pivot = \ - !((IPIV.extent(0) == 0) && (IPIV.data() == nullptr)); \ - \ - const int N = static_cast(A.extent(1)); \ - const int AST = static_cast(A.stride(1)); \ - const int LDA = (AST == 0) ? 1 : AST; \ - const int BST = static_cast(B.stride(1)); \ - const int LDB = (BST == 0) ? 1 : BST; \ - const int NRHS = static_cast(B.extent(1)); \ - \ - int info = 0; \ - \ - if (with_pivot) { \ - HostLapack::gesv(N, NRHS, A.data(), LDA, IPIV.data(), \ - B.data(), LDB, info); \ - } \ - Kokkos::Profiling::popRegion(); \ - } \ - }; - -#define KOKKOSLAPACK_SGESV_LAPACK(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ - template \ - struct GESV< \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - true, ETI_SPEC_AVAIL> { \ - typedef float SCALAR; \ - typedef Kokkos::View, \ - Kokkos::MemoryTraits> \ - AViewType; \ - typedef Kokkos::View, \ - Kokkos::MemoryTraits> \ - BViewType; \ - typedef Kokkos::View< \ - int*, LAYOUT, \ - Kokkos::Device, \ - Kokkos::MemoryTraits> \ - PViewType; \ - \ - static void gesv(const AViewType& A, const BViewType& B, \ - const PViewType& IPIV) { \ - Kokkos::Profiling::pushRegion("KokkosLapack::gesv[TPL_LAPACK,float]"); \ - gesv_print_specialization(); \ - const bool with_pivot = \ - !((IPIV.extent(0) == 0) && (IPIV.data() == nullptr)); \ - \ - const int N = static_cast(A.extent(1)); \ - const int AST = static_cast(A.stride(1)); \ - const int LDA = (AST == 0) ? 1 : AST; \ - const int BST = static_cast(B.stride(1)); \ - const int LDB = (BST == 0) ? 1 : BST; \ - const int NRHS = static_cast(B.extent(1)); \ - \ - int info = 0; \ - \ - if (with_pivot) { \ - HostLapack::gesv(N, NRHS, A.data(), LDA, IPIV.data(), B.data(), \ - LDB, info); \ - } \ - Kokkos::Profiling::popRegion(); \ - } \ - }; - -#define KOKKOSLAPACK_ZGESV_LAPACK(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ - template \ - struct GESV**, LAYOUT, \ - Kokkos::Device, \ - Kokkos::MemoryTraits>, \ - Kokkos::View**, LAYOUT, \ - Kokkos::Device, \ - Kokkos::MemoryTraits>, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - true, ETI_SPEC_AVAIL> { \ - typedef Kokkos::complex SCALAR; \ - typedef Kokkos::View, \ - Kokkos::MemoryTraits> \ - AViewType; \ - typedef Kokkos::View, \ - Kokkos::MemoryTraits> \ - BViewType; \ - typedef Kokkos::View< \ - int*, LAYOUT, \ - Kokkos::Device, \ - Kokkos::MemoryTraits> \ - PViewType; \ - \ - static void gesv(const AViewType& A, const BViewType& B, \ - const PViewType& IPIV) { \ - Kokkos::Profiling::pushRegion( \ - "KokkosLapack::gesv[TPL_LAPACK,complex]"); \ - gesv_print_specialization(); \ - const bool with_pivot = \ - !((IPIV.extent(0) == 0) && (IPIV.data() == nullptr)); \ - \ - const int N = static_cast(A.extent(1)); \ - const int AST = static_cast(A.stride(1)); \ - const int LDA = (AST == 0) ? 1 : AST; \ - const int BST = static_cast(B.stride(1)); \ - const int LDB = (BST == 0) ? 1 : BST; \ - const int NRHS = static_cast(B.extent(1)); \ - \ - int info = 0; \ - \ - if (with_pivot) { \ - HostLapack>::gesv( \ - N, NRHS, reinterpret_cast*>(A.data()), LDA, \ - IPIV.data(), reinterpret_cast*>(B.data()), \ - LDB, info); \ - } \ - Kokkos::Profiling::popRegion(); \ - } \ - }; - -#define KOKKOSLAPACK_CGESV_LAPACK(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ - template \ - struct GESV**, LAYOUT, \ - Kokkos::Device, \ - Kokkos::MemoryTraits>, \ - Kokkos::View**, LAYOUT, \ - Kokkos::Device, \ - Kokkos::MemoryTraits>, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - true, ETI_SPEC_AVAIL> { \ - typedef Kokkos::complex SCALAR; \ - typedef Kokkos::View, \ - Kokkos::MemoryTraits> \ - AViewType; \ - typedef Kokkos::View, \ - Kokkos::MemoryTraits> \ - BViewType; \ - typedef Kokkos::View< \ - int*, LAYOUT, \ - Kokkos::Device, \ - Kokkos::MemoryTraits> \ - PViewType; \ - \ - static void gesv(const AViewType& A, const BViewType& B, \ - const PViewType& IPIV) { \ - Kokkos::Profiling::pushRegion( \ - "KokkosLapack::gesv[TPL_LAPACK,complex]"); \ - gesv_print_specialization(); \ - const bool with_pivot = \ - !((IPIV.extent(0) == 0) && (IPIV.data() == nullptr)); \ - \ - const int N = static_cast(A.extent(1)); \ - const int AST = static_cast(A.stride(1)); \ - const int LDA = (AST == 0) ? 1 : AST; \ - const int BST = static_cast(B.stride(1)); \ - const int LDB = (BST == 0) ? 1 : BST; \ - const int NRHS = static_cast(B.extent(1)); \ - \ - int info = 0; \ - \ - if (with_pivot) { \ - HostLapack>::gesv( \ - N, NRHS, reinterpret_cast*>(A.data()), LDA, \ - IPIV.data(), reinterpret_cast*>(B.data()), \ - LDB, info); \ - } \ + lapackGesvWrapper(A, B, IPIV); \ Kokkos::Profiling::popRegion(); \ } \ }; -KOKKOSLAPACK_DGESV_LAPACK(Kokkos::LayoutLeft, Kokkos::HostSpace, true) -KOKKOSLAPACK_DGESV_LAPACK(Kokkos::LayoutLeft, Kokkos::HostSpace, false) +KOKKOSLAPACK_GESV_LAPACK(float, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSLAPACK_GESV_LAPACK(float, Kokkos::LayoutLeft, Kokkos::HostSpace, false) -KOKKOSLAPACK_SGESV_LAPACK(Kokkos::LayoutLeft, Kokkos::HostSpace, true) -KOKKOSLAPACK_SGESV_LAPACK(Kokkos::LayoutLeft, Kokkos::HostSpace, false) +KOKKOSLAPACK_GESV_LAPACK(double, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSLAPACK_GESV_LAPACK(double, Kokkos::LayoutLeft, Kokkos::HostSpace, false) -KOKKOSLAPACK_ZGESV_LAPACK(Kokkos::LayoutLeft, Kokkos::HostSpace, true) -KOKKOSLAPACK_ZGESV_LAPACK(Kokkos::LayoutLeft, Kokkos::HostSpace, false) +KOKKOSLAPACK_GESV_LAPACK(Kokkos::complex, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSLAPACK_GESV_LAPACK(Kokkos::complex, Kokkos::LayoutLeft, Kokkos::HostSpace, false) -KOKKOSLAPACK_CGESV_LAPACK(Kokkos::LayoutLeft, Kokkos::HostSpace, true) -KOKKOSLAPACK_CGESV_LAPACK(Kokkos::LayoutLeft, Kokkos::HostSpace, false) +KOKKOSLAPACK_GESV_LAPACK(Kokkos::complex, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSLAPACK_GESV_LAPACK(Kokkos::complex, Kokkos::LayoutLeft, Kokkos::HostSpace, false) } // namespace Impl } // namespace KokkosLapack @@ -546,7 +398,7 @@ namespace Impl { template -void rocsolverGesvAdapter(const ExecutionSpace& space, const IPIVViewType& IPIV, +void rocsolverGesvWrapper(const ExecutionSpace& space, const IPIVViewType& IPIV, const AViewType& A, const BViewType& B) { using Scalar = typename BViewType::non_const_value_type; using ALayout_t = typename AViewType::array_layout; @@ -620,7 +472,7 @@ void rocsolverGesvAdapter(const ExecutionSpace& space, const IPIVViewType& IPIV, "KokkosLapack::gesv[TPL_ROCSOLVER," #SCALAR "]"); \ gesv_print_specialization(); \ \ - rocsolverGesvAdapter(space, IPIV, A, B); \ + rocsolverGesvWrapper(space, IPIV, A, B); \ Kokkos::Profiling::popRegion(); \ } \ }; diff --git a/lapack/unit_test/Test_Lapack_gesv.hpp b/lapack/unit_test/Test_Lapack_gesv.hpp index 0a8012baf4..1ac24d4d2d 100644 --- a/lapack/unit_test/Test_Lapack_gesv.hpp +++ b/lapack/unit_test/Test_Lapack_gesv.hpp @@ -36,7 +36,7 @@ namespace Test { - template +template void impl_test_gesv(const char* mode, const char* padding, int N) { using execution_space = typename Device::execution_space; using ScalarA = typename ViewTypeA::value_type; @@ -84,9 +84,9 @@ void impl_test_gesv(const char* mode, const char* padding, int N) { Kokkos::deep_copy(h_X0, X0); // Allocate IPIV view on host - using ViewTypeP = std::conditional, - Kokkos::View>::type; + using ViewTypeP = typename std::conditional, + Kokkos::View>::type; ViewTypeP ipiv; int Nt = 0; if (mode[0] == 'Y') { @@ -144,7 +144,7 @@ void impl_test_gesv(const char* mode, const char* padding, int N) { ASSERT_EQ(test_flag, true); } -template +template void impl_test_gesv_mrhs(const char* mode, const char* padding, int N, int nrhs) { using execution_space = typename Device::execution_space; @@ -193,9 +193,9 @@ void impl_test_gesv_mrhs(const char* mode, const char* padding, int N, Kokkos::deep_copy(h_X0, X0); // Allocate IPIV view on host - using ViewTypeP = std::conditional, - Kokkos::View>::type; + using ViewTypeP = typename std::conditional, + Kokkos::View>::type; ViewTypeP ipiv; int Nt = 0; if (mode[0] == 'Y') {