From 650239813dcdf355fb42fb04d62ffa1748be9c20 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Sun, 10 Nov 2024 12:00:06 -0800 Subject: [PATCH] Add local FFT (#4224) --- Docs/sphinx_documentation/source/FFT.rst | 37 ++- Src/Base/AMReX.cpp | 8 + Src/Base/AMReX_GpuDevice.cpp | 15 - Src/FFT/AMReX_FFT.H | 8 + Src/FFT/AMReX_FFT.cpp | 82 ++++++ Src/FFT/AMReX_FFT_Helper.H | 239 ++++++++++++++-- Src/FFT/AMReX_FFT_LocalR2C.H | 333 +++++++++++++++++++++++ Src/FFT/AMReX_FFT_R2C.H | 2 +- Src/FFT/CMakeLists.txt | 1 + Src/FFT/Make.package | 1 + Tests/FFT/R2C/main.cpp | 45 ++- 11 files changed, 724 insertions(+), 47 deletions(-) create mode 100644 Src/FFT/AMReX_FFT_LocalR2C.H diff --git a/Docs/sphinx_documentation/source/FFT.rst b/Docs/sphinx_documentation/source/FFT.rst index 8d6205f43a..2a5957e40b 100644 --- a/Docs/sphinx_documentation/source/FFT.rst +++ b/Docs/sphinx_documentation/source/FFT.rst @@ -7,10 +7,10 @@ FFT::R2C Class ============== Class template `FFT::R2C` supports discrete Fourier transforms between real -and complex data. The name R2C indicates that the forward transform converts -real data to complex data, while the backward transform converts complex -data to real data. It should be noted that both directions of transformation -are supported, not just from real to complex. +and complex data across MPI processes. The name R2C indicates that the +forward transform converts real data to complex data, while the backward +transform converts complex data to real data. It should be noted that both +directions of transformation are supported, not just from real to complex. The implementation utilizes cuFFT, rocFFT, oneMKL and FFTW, for CUDA, HIP, SYCL and CPU builds, respectively. Because the parallel communication is @@ -18,8 +18,7 @@ handled by AMReX, it does not need the parallel version of FFTW. Furthermore, there is no constraint on the domain decomposition such as one Box per process. 
This class performs parallel FFT on AMReX's parallel data containers (e.g., :cpp:`MultiFab` and -:cpp:`FabArray>>`. For local FFT, the users can -use FFTW, cuFFT, rocFFT, or oneMKL directly. +:cpp:`FabArray>>`. Other than using column-majored order, AMReX follows the convention of FFTW. Applying the forward transform followed by the backward transform @@ -68,6 +67,32 @@ object. Therefore, one should cache it for reuse if possible. Although :cpp:`std::unique_ptr>` to store an object in one's class. +.. _sec:FFT:localr2c: + +FFT::LocalR2C Class +=================== + +Class template `FFT::LocalR2C` supports local discrete Fourier transforms +between real and complex data. The name R2C indicates that the forward +transform converts real data to complex data, while the backward transform +converts complex data to real data. It should be noted that both directions +of transformation are supported, not just from real to complex. + +Below is an example of using :cpp:`FFT::LocalR2C`. + +.. highlight:: c++ + +:: + + MultiFab mf(...); + BaseFab> cfab; + for (MFIter mfi(mf); mfi.isValid(); ++mfi) { + FFT::LocalR2C fft(mfi.fabbox().length()); + cfab.resize(IntVect(0), fft.spectralSize()-1); + fft.forward(mf[mfi].dataPtr(), cfab.dataPtr()); + } + + Poisson Solver ============== diff --git a/Src/Base/AMReX.cpp b/Src/Base/AMReX.cpp index d3629c5fd3..9d9edeaeba 100644 --- a/Src/Base/AMReX.cpp +++ b/Src/Base/AMReX.cpp @@ -14,6 +14,10 @@ #include #include +#ifdef AMREX_USE_FFT +#include +#endif + #ifdef AMREX_USE_HYPRE #include <_hypre_utilities.h> #ifdef AMREX_USE_CUDA @@ -655,6 +659,10 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse, AsyncOut::Initialize(); VectorGrowthStrategy::Initialize(); +#ifdef AMREX_USE_FFT + FFT::Initialize(); +#endif + #ifdef AMREX_USE_EB EB2::Initialize(); #endif diff --git a/Src/Base/AMReX_GpuDevice.cpp b/Src/Base/AMReX_GpuDevice.cpp index 155cdcd4dd..d911349a61 100644 --- a/Src/Base/AMReX_GpuDevice.cpp +++ 
b/Src/Base/AMReX_GpuDevice.cpp @@ -35,13 +35,6 @@ #include #endif #endif -#if defined(AMREX_USE_FFT) -# if __has_include() // ROCm 5.3+ -# include -# else -# include -# endif -#endif #endif #ifdef AMREX_USE_ACC @@ -317,10 +310,6 @@ Device::Initialize () } #endif /* AMREX_USE_MPI */ -#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT) - AMREX_ROCFFT_SAFE_CALL(rocfft_setup()); -#endif - if (amrex::Verbose()) { #if defined(AMREX_USE_CUDA) amrex::Print() << "CUDA" @@ -360,10 +349,6 @@ Device::Finalize () #ifdef AMREX_USE_GPU Device::profilerStop(); -#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT) - AMREX_ROCFFT_SAFE_CALL(rocfft_cleanup()); -#endif - #ifdef AMREX_USE_SYCL for (auto& s : gpu_stream_pool) { delete s.queue; diff --git a/Src/FFT/AMReX_FFT.H b/Src/FFT/AMReX_FFT.H index cd7c0984c5..11bf4f4cc8 100644 --- a/Src/FFT/AMReX_FFT.H +++ b/Src/FFT/AMReX_FFT.H @@ -2,8 +2,16 @@ #define AMREX_FFT_H_ #include +#include #include #include #include +namespace amrex::FFT +{ + void Initialize (); + void Finalize (); + void Clear (); +} + #endif diff --git a/Src/FFT/AMReX_FFT.cpp b/Src/FFT/AMReX_FFT.cpp index 8bf7b5fd9f..91ac1a7a92 100644 --- a/Src/FFT/AMReX_FFT.cpp +++ b/Src/FFT/AMReX_FFT.cpp @@ -1,5 +1,87 @@ +#include #include +#include + +namespace amrex::FFT +{ + +namespace +{ + bool s_initialized = false; + std::map s_plans_d; + std::map s_plans_f; +} + +void Initialize () +{ + if (!s_initialized) + { + s_initialized = true; + +#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT) + AMREX_ROCFFT_SAFE_CALL(rocfft_setup()); +#endif + } + + amrex::ExecOnFinalize(amrex::FFT::Finalize); +} + +void Finalize () +{ + if (s_initialized) + { + s_initialized = false; + + Clear(); + +#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT) + AMREX_ROCFFT_SAFE_CALL(rocfft_cleanup()); +#endif + } +} + +void Clear () +{ + for (auto& [k, p] : s_plans_d) { + Plan::destroy_vendor_plan(p); + } + + for (auto& [k, p] : s_plans_f) { + Plan::destroy_vendor_plan(p); + } +} + +PlanD* 
get_vendor_plan_d (Key const& key) +{ + if (auto found = s_plans_d.find(key); found != s_plans_d.end()) { + return &(found->second); + } else { + return nullptr; + } +} + +PlanF* get_vendor_plan_f (Key const& key) +{ + if (auto found = s_plans_f.find(key); found != s_plans_f.end()) { + return &(found->second); + } else { + return nullptr; + } +} + +void add_vendor_plan_d (Key const& key, PlanD plan) +{ + s_plans_d[key] = plan; +} + +void add_vendor_plan_f (Key const& key, PlanF plan) +{ + s_plans_f[key] = plan; +} + +} + namespace amrex::FFT::detail { diff --git a/Src/FFT/AMReX_FFT_Helper.H b/Src/FFT/AMReX_FFT_Helper.H index 315e0641ac..efe7ab0b1e 100644 --- a/Src/FFT/AMReX_FFT_Helper.H +++ b/Src/FFT/AMReX_FFT_Helper.H @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -144,35 +145,23 @@ struct Plan VendorPlan plan2{}; void* pf = nullptr; void* pb = nullptr; -#ifdef AMREX_USE_CUDA - std::size_t work_size = 0; + +#ifdef AMREX_USE_GPU + void set_ptrs (void* p0, void* p1) { + pf = p0; + pb = p1; + } #endif void destroy () { if (defined) { -#if defined(AMREX_USE_CUDA) - AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan)); -#elif defined(AMREX_USE_HIP) - AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan)); -#elif defined(AMREX_USE_SYCL) - std::visit([](auto&& p) { delete p; }, plan); -#else - if constexpr (std::is_same_v) { - fftwf_destroy_plan(plan); - } else { - fftw_destroy_plan(plan); - } -#endif + destroy_vendor_plan(plan); defined = false; } #if !defined(AMREX_USE_GPU) if (defined2) { - if constexpr (std::is_same_v) { - fftwf_destroy_plan(plan2); - } else { - fftw_destroy_plan(plan2); - } + destroy_vendor_plan(plan2); defined2 = false; } #endif @@ -211,19 +200,20 @@ struct Plan amrex::ignore_unused(nc); #if defined(AMREX_USE_CUDA) + AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan)); AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0)); + std::size_t work_size; if constexpr (D == Direction::forward) { cufftType fwd_type = std::is_same_v ? 
CUFFT_R2C : CUFFT_D2Z; AMREX_CUFFT_SAFE_CALL (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size)); - AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); } else { cufftType bwd_type = std::is_same_v ? CUFFT_C2R : CUFFT_Z2D; AMREX_CUFFT_SAFE_CALL (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size)); - AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); } + #elif defined(AMREX_USE_HIP) auto prec = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; @@ -300,6 +290,9 @@ struct Plan #endif } + template + void init_r2c (IntVectND const& fft_size, void*, void*, bool cache); + template void init_c2c (Box const& box, VendorComplex* p) { @@ -318,9 +311,9 @@ struct Plan AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0)); cufftType t = std::is_same_v ? CUFFT_C2C : CUFFT_Z2Z; + std::size_t work_size; AMREX_CUFFT_SAFE_CALL (cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size)); - AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); #elif defined(AMREX_USE_HIP) @@ -479,9 +472,9 @@ struct Plan AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan)); AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0)); cufftType fwd_type = std::is_same_v ? CUFFT_R2C : CUFFT_D2Z; + std::size_t work_size; AMREX_CUFFT_SAFE_CALL (cufftMakePlanMany(plan, 1, &nex, nullptr, 1, nc*2, nullptr, 1, nc, fwd_type, howmany, &work_size)); - AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); #elif defined(AMREX_USE_HIP) @@ -589,8 +582,14 @@ struct Plan auto* po = (TO*)((D == Direction::forward) ? 
pb : pf); #if defined(AMREX_USE_CUDA) + AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); + + std::size_t work_size = 0; + AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size)); + auto* work_area = The_Arena()->alloc(work_size); AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area)); + if constexpr (D == Direction::forward) { if constexpr (std::is_same_v) { AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pi, po)); @@ -629,8 +628,14 @@ struct Plan auto* p = (VendorComplex*)pf; #if defined(AMREX_USE_CUDA) + AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); + + std::size_t work_size = 0; + AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size)); + auto* work_area = The_Arena()->alloc(work_size); AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area)); + auto dir = (D == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE; if constexpr (std::is_same_v) { AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, dir)); @@ -1065,8 +1070,14 @@ struct Plan #if defined(AMREX_USE_CUDA) + AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); + + std::size_t work_size = 0; + AMREX_CUFFT_SAFE_CALL(cufftGetSize(plan, &work_size)); + auto* work_area = The_Arena()->alloc(work_size); AMREX_CUFFT_SAFE_CALL(cufftSetWorkArea(plan, work_area)); + if constexpr (std::is_same_v) { AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, (T*)pscratch, (VendorComplex*)pscratch)); } else { @@ -1097,10 +1108,190 @@ struct Plan if (defined2) { fftw_execute(plan2); } } +#endif + } + + static void destroy_vendor_plan (VendorPlan plan) + { +#if defined(AMREX_USE_CUDA) + AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan)); +#elif defined(AMREX_USE_HIP) + AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan)); +#elif defined(AMREX_USE_SYCL) + std::visit([](auto&& p) { delete p; }, plan); +#else + if constexpr (std::is_same_v) { + fftwf_destroy_plan(plan); + } else { + fftw_destroy_plan(plan); + } #endif } }; +using Key = std::tuple,Direction,Kind>; +using PlanD = typename Plan::VendorPlan; +using PlanF = 
typename Plan::VendorPlan; + +PlanD* get_vendor_plan_d (Key const& key); +PlanF* get_vendor_plan_f (Key const& key); + +void add_vendor_plan_d (Key const& key, PlanD plan); +void add_vendor_plan_f (Key const& key, PlanF plan); + +template +template +void Plan::init_r2c (IntVectND const& fft_size, void* pbf, void* pbb, bool cache) +{ + static_assert(D == Direction::forward || D == Direction::backward); + + kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b; + defined = true; + pf = pbf; + pb = pbb; + + n = 1; + for (auto s : fft_size) { n *= s; } + howmany = 1; + +#if defined(AMREX_USE_GPU) + Key key = {fft_size.template expand<3>(), D, kind}; + if (cache) { + VendorPlan* cached_plan = nullptr; + if constexpr (std::is_same_v) { + cached_plan = get_vendor_plan_f(key); + } else { + cached_plan = get_vendor_plan_d(key); + } + if (cached_plan) { + plan = *cached_plan; + return; + } + } +#else + amrex::ignore_unused(cache); +#endif + +#if defined(AMREX_USE_CUDA) + + AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan)); + AMREX_CUFFT_SAFE_CALL(cufftSetAutoAllocation(plan, 0)); + cufftType type; + if constexpr (D == Direction::forward) { + type = std::is_same_v ? CUFFT_R2C : CUFFT_D2Z; + } else { + type = std::is_same_v ? CUFFT_C2R : CUFFT_Z2D; + } + std::size_t work_size; + if constexpr (M == 1) { + AMREX_CUFFT_SAFE_CALL + (cufftMakePlan1d(plan, fft_size[0], type, howmany, &work_size)); + } else if constexpr (M == 2) { + AMREX_CUFFT_SAFE_CALL + (cufftMakePlan2d(plan, fft_size[1], fft_size[0], type, &work_size)); + } else if constexpr (M == 3) { + AMREX_CUFFT_SAFE_CALL + (cufftMakePlan3d(plan, fft_size[2], fft_size[1], fft_size[0], type, &work_size)); + } + +#elif defined(AMREX_USE_HIP) + + auto prec = std::is_same_v ? 
rocfft_precision_single : rocfft_precision_double; + std::size_t length[M]; + for (int idim = 0; idim < M; ++idim) { length[idim] = fft_size[idim]; } + if constexpr (D == Direction::forward) { + AMREX_ROCFFT_SAFE_CALL + (rocfft_plan_create(&plan, rocfft_placement_notinplace, + rocfft_transform_type_real_forward, prec, M, + length, howmany, nullptr)); + } else { + AMREX_ROCFFT_SAFE_CALL + (rocfft_plan_create(&plan, rocfft_placement_notinplace, + rocfft_transform_type_real_inverse, prec, M, + length, howmany, nullptr)); + } + +#elif defined(AMREX_USE_SYCL) + + mkl_desc_r* pp; + if (M == 1) { + pp = new mkl_desc_r(fft_size[0]); + } else { + std::vector len(M); + for (int idim = 0; idim < M; ++idim) { + len[idim] = fft_size[M-1-idim]; + } + pp = new mkl_desc_r(len); + } +#ifndef AMREX_USE_MKL_DFTI_2024 + pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, + oneapi::mkl::dft::config_value::NOT_INPLACE); +#else + pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE); +#endif + + std::vector strides(M+1); + strides[0] = 0; + strides[M] = 1; + for (int i = M-1; i >= 1; --i) { + strides[i] = strides[i+1] * fft_size[M-1-i]; + } + +#ifndef AMREX_USE_MKL_DFTI_2024 + pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides); + pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides); +#else + pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data()); + // Do not set BWD_STRIDES +#endif + pp->set_value(oneapi::mkl::dft::config_param::WORKSPACE, + oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); + pp->commit(amrex::Gpu::Device::streamQueue()); + plan = pp; + +#else /* FFTW */ + + if (pf == nullptr || pb == nullptr) { return; } + + int size_for_row_major[M]; + for (int idim = 0; idim < M; ++idim) { + size_for_row_major[idim] = fft_size[M-1-idim]; + } + + if constexpr (std::is_same_v) { + if constexpr (D == Direction::forward) { + plan = fftwf_plan_dft_r2c + (M, size_for_row_major, (float*)pf, (fftwf_complex*)pb, + 
FFTW_ESTIMATE); + } else { + plan = fftwf_plan_dft_c2r + (M, size_for_row_major, (fftwf_complex*)pb, (float*)pf, + FFTW_ESTIMATE); + } + } else { + if constexpr (D == Direction::forward) { + plan = fftw_plan_dft_r2c + (M, size_for_row_major, (double*)pf, (fftw_complex*)pb, + FFTW_ESTIMATE); + } else { + plan = fftw_plan_dft_c2r + (M, size_for_row_major, (fftw_complex*)pb, (double*)pf, + FFTW_ESTIMATE); + } + } +#endif + +#if defined(AMREX_USE_GPU) + if (cache) { + if constexpr (std::is_same_v) { + add_vendor_plan_f(key, plan); + } else { + add_vendor_plan_d(key, plan); + } + } +#endif +} + namespace detail { DistributionMapping make_iota_distromap (Long n); diff --git a/Src/FFT/AMReX_FFT_LocalR2C.H b/Src/FFT/AMReX_FFT_LocalR2C.H new file mode 100644 index 0000000000..11b4be6149 --- /dev/null +++ b/Src/FFT/AMReX_FFT_LocalR2C.H @@ -0,0 +1,333 @@ +#ifndef AMREX_FFT_LOCAL_R2C_H_ +#define AMREX_FFT_LOCAL_R2C_H_ +#include + +#include +#include + +namespace amrex::FFT +{ + +/** + * \brief Local Discrete Fourier Transform + * + * This class supports Fourier transforms between real and complex data. The + * name R2C indicates that the forward transform converts real data to + * complex data, while the backward transform converts complex data to real + * data. It should be noted that both directions of transformation are + * supported, not just from real to complex. The scaling follows the FFTW + * convention, where applying the forward transform followed by the backward + * transform scales the original data by the size of the input array. + * + * For more details, we refer the users to + * https://amrex-codes.github.io/amrex/docs_html/FFT_Chapter.html. + */ +template +class LocalR2C +{ +public: + /** + * \brief Constructor + * + * Given the diverse interfaces of FFT libraries we use, this constructor + * has a number of optional arguments. + * + * The user can provide the data pointers to the constructor.
They are + * only needed by FFTW because its plan creation requires the input and + * output arrays. If they are null, we will delay the plan creation for + * FFTW until the forward or backward function is called. + * + * The cache_plan option is only used when we use cufft, rocfft and + * onemkl, but not FFTW. + * + * \param fft_size The forward domain size (i.e., the domain of the real data) + * \param p_fwd Forward domain data pointer (optional) + * \param p_bwd Backward domain data pointer (optional) + * \param cache_plan Try to cache the plan or not (optional) + */ + explicit LocalR2C (IntVectND const& fft_size, + T* p_fwd = nullptr, + GpuComplex* p_bwd = nullptr, +#ifdef AMREX_USE_GPU + bool cache_plan = true); +#else + bool cache_plan = false); +#endif + + ~LocalR2C (); + + LocalR2C () = default; + LocalR2C (LocalR2C &&) noexcept; + LocalR2C& operator= (LocalR2C &&) noexcept; + + LocalR2C (LocalR2C const&) = delete; + LocalR2C& operator= (LocalR2C const&) = delete; + + /** + * \brief Forward transform + * + * This function is not available when this class template is + * instantiated for backward-only transform. For GPUs, this function is + * synchronous on the host. + * + * \param indata input data + * \param outdata output data + */ + template = 0> + void forward (T const* indata, GpuComplex* outdata); + + void clear (); + + /** + * \brief Backward transform + * + * This function is not available when this class template is + * instantiated for forward-only transform. For GPUs, this function is + * synchronous on the host. + * + * \param indata input data + * \param outdata output data + */ + template = 0> + void backward (GpuComplex const* indata, T* outdata); + + //! Scaling factor. If the data goes through forward and then backward, + //! the result multiplied by the scaling factor is equal to the original + //! data. + [[nodiscard]] T scalingFactor () const; + + //!
Spectral domain size + [[nodiscard]] IntVectND const& spectralSize () const { + return m_spectral_size; + } + +private: + + Plan m_fft_fwd; + Plan m_fft_bwd; + + T* m_p_fwd = nullptr; + GpuComplex* m_p_bwd = nullptr; + +#if defined(AMREX_USE_SYCL) + gpuStream_t m_gpu_stream{}; +#endif + + IntVectND m_real_size; + IntVectND m_spectral_size; + + bool m_cache_plan = false; +}; + +template +LocalR2C::LocalR2C (IntVectND const& fft_size, T* p_fwd, + GpuComplex* p_bwd, bool cache_plan) + : m_p_fwd(p_fwd), + m_p_bwd(p_bwd), + m_real_size(fft_size), + m_spectral_size(fft_size) +#if defined(AMREX_USE_GPU) + , m_cache_plan(cache_plan) +#endif +{ +#if !defined(AMREX_USE_GPU) + amrex::ignore_unused(cache_plan); +#endif + + BL_PROFILE("FFT::LocalR2C"); + m_spectral_size[0] = m_real_size[0]/2 + 1; + +#if defined(AMREX_USE_SYCL) + + auto current_stream = Gpu::gpuStream(); + Gpu::Device::resetStreamIndex(); + m_gpu_stream = Gpu::gpuStream(); + +#endif + + auto* pf = (void*)m_p_fwd; + auto* pb = (void*)m_p_bwd; + +#ifdef AMREX_USE_SYCL + m_fft_fwd.template init_r2c(m_real_size, pf, pb, m_cache_plan); + m_fft_bwd = m_fft_fwd; +#else + if constexpr (D == Direction::both || D == Direction::forward) { + m_fft_fwd.template init_r2c(m_real_size, pf, pb, m_cache_plan); + } + if constexpr (D == Direction::both || D == Direction::backward) { + m_fft_bwd.template init_r2c(m_real_size, pf, pb, m_cache_plan); + } +#endif + +#if defined(AMREX_USE_SYCL) + Gpu::Device::setStream(current_stream); +#endif +} + +template +void LocalR2C::clear () +{ + if (!m_cache_plan) { + if (m_fft_bwd.plan != m_fft_fwd.plan) { + m_fft_bwd.destroy(); + } + m_fft_fwd.destroy(); + } + + m_fft_fwd = Plan{}; + m_fft_bwd = Plan{}; +} + +template +LocalR2C::~LocalR2C () +{ + static_assert(M >= 1 && M <= 3); + clear(); +} + +template +LocalR2C::LocalR2C (LocalR2C && rhs) noexcept + : m_p_fwd(rhs.m_p_fwd), + m_p_bwd(rhs.m_p_bwd), + m_fft_fwd(rhs.m_fft_fwd), + m_fft_bwd(rhs.m_fft_bwd), +#if defined(AMREX_USE_SYCL) + 
m_gpu_stream(rhs.m_gpu_stream), +#endif + m_real_size(rhs.m_real_size), + m_spectral_size(rhs.m_spectral_size), + m_cache_plan(rhs.m_cache_plan) +{ + rhs.m_cache_plan = true; // So that plans in rhs are not destroyed. +} + +template +LocalR2C& LocalR2C::operator= (LocalR2C && rhs) noexcept +{ + if (this == &rhs) { return *this; } + + this->clear(); + + m_p_fwd = rhs.m_p_fwd; + m_p_bwd = rhs.m_p_bwd; + m_fft_fwd = rhs.m_fft_fwd; + m_fft_bwd = rhs.m_fft_bwd; +#if defined(AMREX_USE_SYCL) + m_gpu_stream = rhs.m_gpu_stream; +#endif + m_real_size = rhs.m_real_size; + m_spectral_size = rhs.m_spectral_size; + m_cache_plan = rhs.m_cache_plan; + + rhs.m_cache_plan = true; // So that plans in rhs are not destroyed. + + return *this; +} + +template +template > +void LocalR2C::forward (T const* indata, GpuComplex* outdata) +{ + BL_PROFILE("FFT::LocalR2C::forward"); + +#if defined(AMREX_USE_GPU) + + m_fft_fwd.set_ptrs((void*)indata, (void*)outdata); + +#if defined(AMREX_USE_SYCL) + auto current_stream = Gpu::gpuStream(); + if (current_stream != m_gpu_stream) { + Gpu::streamSynchronize(); + Gpu::Device::setStream(m_gpu_stream); + } +#endif + +#else /* FFTW */ + + if (((T*)indata != m_p_fwd) || (outdata != m_p_bwd)) { + m_p_fwd = (T*)indata; + m_p_bwd = outdata; + auto* pf = (void*)m_p_fwd; + auto* pb = (void*)m_p_bwd; + m_fft_fwd.destroy(); + m_fft_fwd.template init_r2c(m_real_size, pf, pb, false); + if constexpr (D == Direction::both) { + m_fft_bwd.destroy(); + m_fft_bwd.template init_r2c(m_real_size, pf, pb, false); + } + } + +#endif + + m_fft_fwd.template compute_r2c(); + +#if defined(AMREX_USE_SYCL) + if (current_stream != m_gpu_stream) { + Gpu::Device::setStream(current_stream); + } +#endif +} + +template +template > +void LocalR2C::backward (GpuComplex const* indata, T* outdata) +{ + BL_PROFILE("FFT::LocalR2C::backward"); + +#if defined(AMREX_USE_GPU) + + m_fft_bwd.set_ptrs((void*)outdata, (void*)indata); + +#if defined(AMREX_USE_SYCL) + auto current_stream = 
Gpu::gpuStream(); + if (current_stream != m_gpu_stream) { + Gpu::streamSynchronize(); + Gpu::Device::setStream(m_gpu_stream); + } +#endif + +#else /* FFTW */ + + if (((GpuComplex*)indata != m_p_bwd) || (outdata != m_p_fwd)) { + m_p_fwd = outdata; + m_p_bwd = (GpuComplex*)indata; + auto* pf = (void*)m_p_fwd; + auto* pb = (void*)m_p_bwd; + m_fft_bwd.destroy(); + m_fft_bwd.template init_r2c(m_real_size, pf, pb, false); + if constexpr (D == Direction::both) { + m_fft_fwd.destroy(); + m_fft_fwd.template init_r2c(m_real_size, pf, pb, false); + } + } + +#endif + + m_fft_bwd.template compute_r2c(); + +#if defined(AMREX_USE_SYCL) + if (current_stream != m_gpu_stream) { + Gpu::Device::setStream(current_stream); + } +#endif +} + +template +T LocalR2C::scalingFactor () const +{ + T r = 1; + for (auto s : m_real_size) { + r *= T(s); + } + return T(1)/r; +} + +} + +#endif diff --git a/Src/FFT/AMReX_FFT_R2C.H b/Src/FFT/AMReX_FFT_R2C.H index aaa5fac4c3..456a3ddf7d 100644 --- a/Src/FFT/AMReX_FFT_R2C.H +++ b/Src/FFT/AMReX_FFT_R2C.H @@ -14,7 +14,7 @@ namespace amrex::FFT template class OpenBCSolver; /** - * \brief Discrete Fourier Transform + * \brief Parallel Discrete Fourier Transform * * This class supports Fourier transforms between real and complex data. 
The * name R2C indicates that the forward transform converts real data to diff --git a/Src/FFT/CMakeLists.txt b/Src/FFT/CMakeLists.txt index cbb205dd2e..6dd8150711 100644 --- a/Src/FFT/CMakeLists.txt +++ b/Src/FFT/CMakeLists.txt @@ -7,6 +7,7 @@ foreach(D IN LISTS AMReX_SPACEDIM) PRIVATE AMReX_FFT.H AMReX_FFT.cpp + AMReX_FFT_LocalR2C.H AMReX_FFT_OpenBCSolver.H AMReX_FFT_R2C.H AMReX_FFT_R2X.H diff --git a/Src/FFT/Make.package b/Src/FFT/Make.package index 82cc4803ea..fb369b7caf 100644 --- a/Src/FFT/Make.package +++ b/Src/FFT/Make.package @@ -3,6 +3,7 @@ ifndef AMREX_FFT_MAKE CEXE_headers += AMReX_FFT.H AMReX_FFT_Helper.H AMReX_FFT_Poisson.H CEXE_headers += AMReX_FFT_OpenBCSolver.H AMReX_FFT_R2C.H AMReX_FFT_R2X.H +CEXE_headers += AMReX_FFT_LocalR2C.H CEXE_sources += AMReX_FFT.cpp VPATH_LOCATIONS += $(AMREX_HOME)/Src/FFT diff --git a/Tests/FFT/R2C/main.cpp b/Tests/FFT/R2C/main.cpp index 594a9ec760..ee70b43b7b 100644 --- a/Tests/FFT/R2C/main.cpp +++ b/Tests/FFT/R2C/main.cpp @@ -17,7 +17,7 @@ int main (int argc, char* argv[]) int n_cell_y = 32;, int n_cell_z = 64); - AMREX_D_TERM(int max_grid_size_x = 32;, + AMREX_D_TERM(int max_grid_size_x = 64;, int max_grid_size_y = 32;, int max_grid_size_z = 32); @@ -120,6 +120,49 @@ int main (int argc, char* argv[]) auto eps = 1.e-6f; #else auto eps = 1.e-13; +#endif + AMREX_ALWAYS_ASSERT(error < eps); + } + + { + Real error = 0; + BaseFab> cfab; + for (MFIter mfi(mf); mfi.isValid(); ++mfi) + { + auto& fab = mf[mfi]; + auto& fab2 = mf2[mfi]; + Box const& box = fab.box(); + { + FFT::LocalR2C fft(box.length()); + Box cbox(IntVect(0), fft.spectralSize() - 1); + cfab.resize(cbox); + fft.forward(fab.dataPtr(), cfab.dataPtr()); + fft.backward(cfab.dataPtr(), fab2.dataPtr()); + auto fac = fft.scalingFactor(); + fab2.template xpay(-fac, fab, box, box, 0, 0, 1); + auto e = fab2.template norm(0); + error = std::max(e,error); + } + { + FFT::LocalR2C fft(box.length()); + fft.forward(fab.dataPtr(), cfab.dataPtr()); + } + { + FFT::LocalR2C 
fft(box.length()); + fft.backward(cfab.dataPtr(), fab2.dataPtr()); + auto fac = fft.scalingFactor(); + fab2.template xpay(-fac, fab, box, box, 0, 0, 1); + auto e = fab2.template norm(0); + error = std::max(e,error); + } + } + + ParallelDescriptor::ReduceRealMax(error); + amrex::Print() << " Expected to be close to zero: " << error << "\n"; +#ifdef AMREX_USE_FLOAT + auto eps = 1.e-6f; +#else + auto eps = 1.e-13; #endif AMREX_ALWAYS_ASSERT(error < eps); }