Skip to content

Commit

Permalink
Add local FFT (#4224)
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang authored Nov 10, 2024
1 parent 57380f3 commit 6502398
Show file tree
Hide file tree
Showing 11 changed files with 724 additions and 47 deletions.
37 changes: 31 additions & 6 deletions Docs/sphinx_documentation/source/FFT.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@ FFT::R2C Class
==============

Class template `FFT::R2C` supports discrete Fourier transforms between real
and complex data. The name R2C indicates that the forward transform converts
real data to complex data, while the backward transform converts complex
data to real data. It should be noted that both directions of transformation
are supported, not just from real to complex.
and complex data across MPI processes. The name R2C indicates that the
forward transform converts real data to complex data, while the backward
transform converts complex data to real data. It should be noted that both
directions of transformation are supported, not just from real to complex.

The implementation utilizes cuFFT, rocFFT, oneMKL and FFTW, for CUDA, HIP,
SYCL and CPU builds, respectively. Because the parallel communication is
handled by AMReX, it does not need the parallel version of
FFTW. Furthermore, there is no constraint on the domain decomposition such
as one Box per process. This class performs parallel FFT on AMReX's parallel
data containers (e.g., :cpp:`MultiFab` and
:cpp:`FabArray<BaseFab<ComplexData<Real>>>`. For local FFT, the users can
use FFTW, cuFFT, rocFFT, or oneMKL directly.
:cpp:`FabArray<BaseFab<ComplexData<Real>>>`.

Other than using column-majored order, AMReX follows the convention of
FFTW. Applying the forward transform followed by the backward transform
Expand Down Expand Up @@ -68,6 +67,32 @@ object. Therefore, one should cache it for reuse if possible. Although
:cpp:`std::unique_ptr<FFT::R2C<Real>>` to store an object in one's class.


.. _sec:FFT:localr2c:

FFT::LocalR2C Class
===================

Class template `FFT::LocalR2C` supports local discrete Fourier transforms
between real and complex data. The name R2C indicates that the forward
transform converts real data to complex data, while the backward transform
converts complex data to real data. It should be noted that both directions
of transformation are supported, not just from real to complex.

Below is an example of using :cpp:`FFT::LocalR2C`.

.. highlight:: c++

::

MultiFab mf(...);
BaseFab<GpuComplex<T>> cfab;
for (MFIter mfi(mf); mfi.isValid(); ++mfi) {
FFT::LocalR2C fft(mfi.fabbox().length());
cfab.resize(IntVect(0), fft.spectralSize()-1);
fft.forward(mf[mfi].dataPtr(), cfab.dataPtr());
}


Poisson Solver
==============

Expand Down
8 changes: 8 additions & 0 deletions Src/Base/AMReX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
#include <AMReX_Geometry.H>
#include <AMReX_Gpu.H>

#ifdef AMREX_USE_FFT
#include <AMReX_FFT.H>
#endif

#ifdef AMREX_USE_HYPRE
#include <_hypre_utilities.h>
#ifdef AMREX_USE_CUDA
Expand Down Expand Up @@ -655,6 +659,10 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,
AsyncOut::Initialize();
VectorGrowthStrategy::Initialize();

#ifdef AMREX_USE_FFT
FFT::Initialize();
#endif

#ifdef AMREX_USE_EB
EB2::Initialize();
#endif
Expand Down
15 changes: 0 additions & 15 deletions Src/Base/AMReX_GpuDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,6 @@
#include <roctracer/roctx.h>
#endif
#endif
#if defined(AMREX_USE_FFT)
# if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
# include <rocfft/rocfft.h>
# else
# include <rocfft.h>
# endif
#endif
#endif

#ifdef AMREX_USE_ACC
Expand Down Expand Up @@ -317,10 +310,6 @@ Device::Initialize ()
}
#endif /* AMREX_USE_MPI */

#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT)
AMREX_ROCFFT_SAFE_CALL(rocfft_setup());
#endif

if (amrex::Verbose()) {
#if defined(AMREX_USE_CUDA)
amrex::Print() << "CUDA"
Expand Down Expand Up @@ -360,10 +349,6 @@ Device::Finalize ()
#ifdef AMREX_USE_GPU
Device::profilerStop();

#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT)
AMREX_ROCFFT_SAFE_CALL(rocfft_cleanup());
#endif

#ifdef AMREX_USE_SYCL
for (auto& s : gpu_stream_pool) {
delete s.queue;
Expand Down
8 changes: 8 additions & 0 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@
#define AMREX_FFT_H_
#include <AMReX_Config.H>

#include <AMReX_FFT_LocalR2C.H>
#include <AMReX_FFT_OpenBCSolver.H>
#include <AMReX_FFT_R2C.H>
#include <AMReX_FFT_R2X.H>

namespace amrex::FFT
{
void Initialize ();
void Finalize ();
void Clear ();
}

#endif
82 changes: 82 additions & 0 deletions Src/FFT/AMReX_FFT.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,87 @@
#include <AMReX_FFT.H>
#include <AMReX_FFT_Helper.H>

#include <map>

namespace amrex::FFT
{

namespace
{
bool s_initialized = false;
std::map<Key, PlanD> s_plans_d;
std::map<Key, PlanF> s_plans_f;
}

void Initialize ()
{
if (!s_initialized)
{
s_initialized = true;

#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT)
AMREX_ROCFFT_SAFE_CALL(rocfft_setup());
#endif
}

amrex::ExecOnFinalize(amrex::FFT::Finalize);
}

void Finalize ()
{
if (s_initialized)
{
s_initialized = false;

Clear();

#if defined(AMREX_USE_HIP) && defined(AMREX_USE_FFT)
AMREX_ROCFFT_SAFE_CALL(rocfft_cleanup());
#endif
}
}

void Clear ()
{
for (auto& [k, p] : s_plans_d) {
Plan<double>::destroy_vendor_plan(p);
}

for (auto& [k, p] : s_plans_f) {
Plan<float>::destroy_vendor_plan(p);
}
}

PlanD* get_vendor_plan_d (Key const& key)
{
if (auto found = s_plans_d.find(key); found != s_plans_d.end()) {
return &(found->second);
} else {
return nullptr;
}
}

PlanF* get_vendor_plan_f (Key const& key)
{
if (auto found = s_plans_f.find(key); found != s_plans_f.end()) {
return &(found->second);
} else {
return nullptr;
}
}

void add_vendor_plan_d (Key const& key, PlanD plan)
{
s_plans_d[key] = plan;
}

void add_vendor_plan_f (Key const& key, PlanF plan)
{
s_plans_f[key] = plan;
}

}

namespace amrex::FFT::detail
{

Expand Down
Loading

0 comments on commit 6502398

Please sign in to comment.