From 5713084c59bf0fd89914ff3cee7235c3f810b62e Mon Sep 17 00:00:00 2001 From: yasahi-hpc <57478230+yasahi-hpc@users.noreply.github.com> Date: Wed, 24 Jul 2024 11:14:01 +0200 Subject: [PATCH] Check View rank and FFT rank consistency (#121) * Add a constant for maximum FFT dimension * Check view rank and fft rank consistency in all APIs --------- Co-authored-by: Yuuichi Asahi --- common/src/KokkosFFT_Helpers.hpp | 28 +++++++++++------ common/src/KokkosFFT_common_types.hpp | 3 ++ fft/src/KokkosFFT_Plans.hpp | 8 +++++ fft/src/KokkosFFT_Transform.hpp | 44 +++++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 10 deletions(-) diff --git a/common/src/KokkosFFT_Helpers.hpp b/common/src/KokkosFFT_Helpers.hpp index e53cadff..f0a9ec25 100644 --- a/common/src/KokkosFFT_Helpers.hpp +++ b/common/src/KokkosFFT_Helpers.hpp @@ -14,10 +14,6 @@ namespace KokkosFFT { namespace Impl { template auto get_shift(const ViewType& inout, axis_type _axes, int direction = 1) { - static_assert(DIM > 0, - "get_shift: Rank of shift axes must be " - "larger than or equal to 1."); - // Convert the input axes to be in the range of [0, rank-1] std::vector axes; for (std::size_t i = 0; i < DIM; i++) { @@ -132,9 +128,6 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift, template void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout, axis_type axes) { - static_assert(ViewType::rank() >= DIM, - "fftshift_impl: Rank of View must be larger thane " - "or equal to the Rank of shift axes."); auto shift = get_shift(inout, axes); roll(exec_space, inout, shift, axes); } @@ -142,9 +135,6 @@ void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout, template void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout, axis_type axes) { - static_assert(ViewType::rank() >= DIM, - "ifftshift_impl: Rank of View must be larger " - "thane or equal to the Rank of shift axes."); auto shift = get_shift(inout, axes, -1); roll(exec_space, inout, shift, axes); } @@ -229,6 +219,9 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout, "Kokkos::Complex, or Kokkos::Complex. " "Layout must be either LayoutLeft or LayoutRight. " "ExecutionSpace must be able to access data in ViewType"); + static_assert(ViewType::rank() >= 1, + "fftshift: View rank must be larger than or equal to 1"); + if (axes) { axis_type<1> _axes{axes.value()}; KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes); @@ -253,6 +246,12 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout, "Kokkos::Complex, or Kokkos::Complex. " "Layout must be either LayoutLeft or LayoutRight. " "ExecutionSpace must be able to access data in ViewType"); + static_assert( + DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM, + "fftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM"); + static_assert(ViewType::rank() >= DIM, + "fftshift: View rank must be larger than or equal to the Rank " + "of FFT axes"); KokkosFFT::Impl::fftshift_impl(exec_space, inout, axes); } @@ -269,6 +268,8 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout, "Kokkos::Complex, or Kokkos::Complex. " "Layout must be either LayoutLeft or LayoutRight. " "ExecutionSpace must be able to access data in ViewType"); + static_assert(ViewType::rank() >= 1, + "ifftshift: View rank must be larger than or equal to 1"); if (axes) { axis_type<1> _axes{axes.value()}; KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes); @@ -293,6 +294,13 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout, "Kokkos::Complex, or Kokkos::Complex. " "Layout must be either LayoutLeft or LayoutRight. " "ExecutionSpace must be able to access data in ViewType"); + static_assert( + DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM, + "ifftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM"); + static_assert(ViewType::rank() >= DIM, + "ifftshift: View rank must be larger than or equal to the Rank " + "of FFT axes"); + KokkosFFT::Impl::ifftshift_impl(exec_space, inout, axes); } } // namespace KokkosFFT diff --git a/common/src/KokkosFFT_common_types.hpp b/common/src/KokkosFFT_common_types.hpp index 2c7c2eee..bc726743 100644 --- a/common/src/KokkosFFT_common_types.hpp +++ b/common/src/KokkosFFT_common_types.hpp @@ -34,6 +34,9 @@ enum class Direction { backward, }; +//! Maximum FFT dimension allowed in KokkosFFT +constexpr std::size_t MAX_FFT_DIM = 3; + } // namespace KokkosFFT #endif \ No newline at end of file diff --git a/fft/src/KokkosFFT_Plans.hpp b/fft/src/KokkosFFT_Plans.hpp index 7a4aef18..500a2e84 100644 --- a/fft/src/KokkosFFT_Plans.hpp +++ b/fft/src/KokkosFFT_Plans.hpp @@ -167,6 +167,8 @@ class Plan { "(LayoutLeft/LayoutRight), " "and the same rank. ExecutionSpace must be accessible to the data in " "InViewType and OutViewType."); + static_assert(InViewType::rank() >= 1, + "Plan::Plan: View rank must be larger than or equal to 1"); if (KokkosFFT::Impl::is_real_v && m_direction != KokkosFFT::Direction::forward) { @@ -220,6 +222,12 @@ class Plan { "(LayoutLeft/LayoutRight), " "and the same rank. ExecutionSpace must be accessible to the data in " "InViewType and OutViewType."); + static_assert( + DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM, + "Plan::Plan: the Rank of FFT axes must be between 1 and MAX_FFT_DIM"); + static_assert(InViewType::rank() >= DIM, + "Plan::Plan: View rank must be larger than or equal to the " + "Rank of FFT axes"); if (std::is_floating_point::value && m_direction != KokkosFFT::Direction::forward) { diff --git a/fft/src/KokkosFFT_Transform.hpp b/fft/src/KokkosFFT_Transform.hpp index 6e9ca1f9..23dcd003 100644 --- a/fft/src/KokkosFFT_Transform.hpp +++ b/fft/src/KokkosFFT_Transform.hpp @@ -139,6 +139,8 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 1, + "fft: View rank must be larger than or equal to 1"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axis, n); @@ -165,6 +167,8 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 1, + "ifft: View rank must be larger than or equal to 1"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward, axis, n); @@ -191,6 +195,8 @@ void rfft(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 1, + "rfft: View rank must be larger than or equal to 1"); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; @@ -224,6 +230,8 @@ void irfft(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 1, + "irfft: View rank must be larger than or equal to 1"); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; @@ -255,6 +263,8 @@ void hfft(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 1, + "hfft: View rank must be larger than or equal to 1"); // [TO DO] // allow real type as input, need to obtain complex view type from in view @@ -295,6 +305,8 @@ void ihfft(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 1, + "ihfft: View rank must be larger than or equal to 1"); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; @@ -332,6 +344,8 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 2, + "fft2: View rank must be larger than or equal to 2"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axes, s); @@ -359,6 +373,8 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 2, + "ifft2: View rank must be larger than or equal to 2"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward, axes, s); @@ -386,6 +402,9 @@ void rfft2(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 2, + "rfft2: View rank must be larger than or equal to 2"); + using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; @@ -418,6 +437,8 @@ void irfft2(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(InViewType::rank() >= 2, + "irfft2: View rank must be larger than or equal to 2"); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; @@ -453,6 +474,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert(DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM, + "fftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM"); + static_assert( + InViewType::rank() >= DIM, + "fftn: View rank must be larger than or equal to the Rank of FFT axes"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axes, s); @@ -481,6 +507,12 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert( + DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM, + "ifftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM"); + static_assert( + InViewType::rank() >= DIM, + "ifftn: View rank must be larger than or equal to the Rank of FFT axes"); KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward, axes, s); @@ -509,6 +541,12 @@ void rfftn(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert( + DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM, + "rfftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM"); + static_assert( + InViewType::rank() >= DIM, + "rfftn: View rank must be larger than or equal to the Rank of FFT axes"); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; @@ -543,6 +581,12 @@ void irfftn(const ExecutionSpace& exec_space, const InViewType& in, "type (float/double), the same layout (LayoutLeft/LayoutRight), and the " "same rank. ExecutionSpace must be accessible to the data in InViewType " "and OutViewType."); + static_assert( + DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM, + "irfftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM"); + static_assert( + InViewType::rank() >= DIM, + "irfftn: View rank must be larger than or equal to the Rank of FFT axes"); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type;