Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check View rank and FFT rank consistency #121

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM = 1>
auto get_shift(const ViewType& inout, axis_type<DIM> _axes, int direction = 1) {
static_assert(DIM > 0,
"get_shift: Rank of shift axes must be "
"larger than or equal to 1.");

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> axes;
for (std::size_t i = 0; i < DIM; i++) {
Expand Down Expand Up @@ -132,19 +128,13 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(ViewType::rank() >= DIM,
"fftshift_impl: Rank of View must be larger thane "
"or equal to the Rank of shift axes.");
auto shift = get_shift(inout, axes);
roll(exec_space, inout, shift, axes);
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(ViewType::rank() >= DIM,
"ifftshift_impl: Rank of View must be larger "
"thane or equal to the Rank of shift axes.");
auto shift = get_shift(inout, axes, -1);
roll(exec_space, inout, shift, axes);
}
Expand Down Expand Up @@ -229,6 +219,9 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(ViewType::rank() >= 1,
"fftshift: View rank must be larger than or equal to 1");

if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
Expand All @@ -253,6 +246,12 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"fftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"fftshift: View rank must be larger than or equal to the Rank "
"of FFT axes");
KokkosFFT::Impl::fftshift_impl(exec_space, inout, axes);
}

Expand All @@ -269,6 +268,8 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(ViewType::rank() >= 1,
"ifftshift: View rank must be larger than or equal to 1");
if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
Expand All @@ -293,6 +294,13 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"ifftshift: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"ifftshift: View rank must be larger than or equal to the Rank "
"of FFT axes");

KokkosFFT::Impl::ifftshift_impl(exec_space, inout, axes);
}
} // namespace KokkosFFT
Expand Down
3 changes: 3 additions & 0 deletions common/src/KokkosFFT_common_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ enum class Direction {
backward,
};

//! Maximum FFT dimension allowed in KokkosFFT
constexpr std::size_t MAX_FFT_DIM = 3;

} // namespace KokkosFFT

#endif
8 changes: 8 additions & 0 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class Plan {
"(LayoutLeft/LayoutRight), "
"and the same rank. ExecutionSpace must be accessible to the data in "
"InViewType and OutViewType.");
static_assert(InViewType::rank() >= 1,
"Plan::Plan: View rank must be larger than or equal to 1");

if (KokkosFFT::Impl::is_real_v<in_value_type> &&
m_direction != KokkosFFT::Direction::forward) {
Expand Down Expand Up @@ -220,6 +222,12 @@ class Plan {
"(LayoutLeft/LayoutRight), "
"and the same rank. ExecutionSpace must be accessible to the data in "
"InViewType and OutViewType.");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"Plan::Plan: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(InViewType::rank() >= DIM,
"Plan::Plan: View rank must be larger than or equal to the "
"Rank of FFT axes");

if (std::is_floating_point<in_value_type>::value &&
m_direction != KokkosFFT::Direction::forward) {
Expand Down
44 changes: 44 additions & 0 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"fft: View rank must be larger than or equal to 1");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axis, n);
Expand All @@ -165,6 +167,8 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"ifft: View rank must be larger than or equal to 1");

KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axis, n);
Expand All @@ -191,6 +195,8 @@ void rfft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"rfft: View rank must be larger than or equal to 1");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -224,6 +230,8 @@ void irfft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"irfft: View rank must be larger than or equal to 1");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -255,6 +263,8 @@ void hfft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"hfft: View rank must be larger than or equal to 1");

// [TO DO]
// allow real type as input, need to obtain complex view type from in view
Expand Down Expand Up @@ -295,6 +305,8 @@ void ihfft(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 1,
"ihfft: View rank must be larger than or equal to 1");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -332,6 +344,8 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 2,
"fft2: View rank must be larger than or equal to 2");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axes, s);
Expand Down Expand Up @@ -359,6 +373,8 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 2,
"ifft2: View rank must be larger than or equal to 2");

KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axes, s);
Expand Down Expand Up @@ -386,6 +402,9 @@ void rfft2(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 2,
"rfft2: View rank must be larger than or equal to 2");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

Expand Down Expand Up @@ -418,6 +437,8 @@ void irfft2(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(InViewType::rank() >= 2,
"irfft2: View rank must be larger than or equal to 2");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -453,6 +474,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"fftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(
InViewType::rank() >= DIM,
"fftn: View rank must be larger than or equal to the Rank of FFT axes");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axes, s);
Expand Down Expand Up @@ -481,6 +507,12 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"ifftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(
InViewType::rank() >= DIM,
"ifftn: View rank must be larger than or equal to the Rank of FFT axes");

KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axes, s);
Expand Down Expand Up @@ -509,6 +541,12 @@ void rfftn(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"rfftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(
InViewType::rank() >= DIM,
"rfftn: View rank must be larger than or equal to the Rank of FFT axes");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down Expand Up @@ -543,6 +581,12 @@ void irfftn(const ExecutionSpace& exec_space, const InViewType& in,
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"irfftn: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(
InViewType::rank() >= DIM,
"irfftn: View rank must be larger than or equal to the Rank of FFT axes");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
Expand Down
Loading