Skip to content

Commit

Permalink
FFT Poisson Solver: Fill ghost cells in solve() function
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Nov 23, 2024
1 parent 12d6af2 commit 3ea8bd9
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 55 deletions.
114 changes: 111 additions & 3 deletions Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
namespace amrex::FFT
{

namespace detail {
// Forward declaration; the definition appears further down in this header.
// The solve routines call it after the backward transform to fill the ghost
// cells lying just outside non-periodic domain faces. bc[d].first is the
// boundary type on the low face of dimension d, bc[d].second on the high face.
template <typename MF>
void fill_physbc (MF& mf, Geometry const& geom,
Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> const& bc);
}

/**
* \brief Poisson solver for periodic, Dirichlet & Neumann boundaries using
* FFT.
Expand Down Expand Up @@ -48,6 +54,13 @@ public:
}
}

/**
* \brief Solve del dot grad soln = rhs
*
* If soln has ghost cells, one layer of ghost cells will be filled,
* except at the corners of the physical domain, which are not used
* in the cross stencil of the operator. The two MFs could be the same MF.
*/
void solve (MF& soln, MF const& rhs);

private:
Expand Down Expand Up @@ -104,6 +117,13 @@ public:
#endif
}

/**
* \brief Solve del dot grad soln = rhs
*
* If soln has ghost cells, one layer of ghost cells will be filled,
* except at the corners of the physical domain, which are not used
* in the cross stencil of the operator. The two MFs could be the same MF.
*/
void solve (MF& soln, MF const& rhs);
void solve (MF& soln, MF const& rhs, Vector<T> const& dz);
void solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> const& dz);
Expand All @@ -121,6 +141,8 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
{
BL_PROFILE("FFT::Poisson::solve");

AMREX_ASSERT(soln.is_cell_centered() && rhs.is_cell_centered());

using T = typename MF::value_type;

GpuArray<T,AMREX_SPACEDIM> fac
Expand Down Expand Up @@ -170,10 +192,15 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
spectral_data *= scale;
};

IntVect const& ng = amrex::elemwiseMin(soln.nGrowVect(), IntVect(1));

if (m_r2x) {
m_r2x->forwardThenBackward(rhs, soln, f);
m_r2x->forwardThenBackward_doit(rhs, soln, f, ng, m_geom.periodicity());
detail::fill_physbc(soln, m_geom, m_bc);
} else {
m_r2c->forwardThenBackward(rhs, soln, f);
m_r2c->forward(rhs);
m_r2c->post_forward_doit(f);
m_r2c->backward_doit(soln, ng, m_geom.periodicity());
}
}

Expand Down Expand Up @@ -254,6 +281,8 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> con
template <typename MF>
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
{
AMREX_ASSERT(soln.is_cell_centered() && rhs.is_cell_centered());

#ifdef AMREX_USE_GPU
Gpu::DeviceVector<T> d_dz(dz.size());
Gpu::htod_memcpy_async(d_dz.data(), dz.data(), dz.size()*sizeof(T));
Expand Down Expand Up @@ -414,9 +443,88 @@ void PoissonHybrid<MF>::solve_doit (MF& soln, MF const& rhs, DZ const& dz)
#endif
}

m_r2c.backward(spmf, soln);
IntVect const& ng = amrex::elemwiseMin(soln.nGrowVect(), IntVect(1));
m_r2c.backward_doit(spmf, soln, ng, m_geom.periodicity());

Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> bc
{AMREX_D_DECL(std::make_pair(Boundary::periodic,Boundary::periodic),
std::make_pair(Boundary::periodic,Boundary::periodic),
std::make_pair(Boundary::even,Boundary::even))};
detail::fill_physbc(soln, m_geom, bc);
#endif
}

namespace detail {

/**
* \brief Work item describing one ghost-cell region that fill_physbc fills.
*
* One tag is created per (fab, domain face) pair whose one-cell-thick layer
* outside the domain overlaps the fab and whose boundary is not periodic.
*/
template <class T>
struct FFTPhysBCTag {
Array4<T> dfab; //!< Array view of the fab; read from and written to
Box dbox; //!< Ghost cells to be filled for this face
Boundary bc; //!< Boundary type on this face; odd negates the copied value
Orientation face; //!< Domain face (direction and low/high side) of the region

//! Returns dbox. NOTE(review): presumably the accessor the tag-based
//! ParallelFor uses to obtain the iteration box -- confirm.
[[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
Box const& box () const noexcept { return dbox; }
};

template <typename MF>
void fill_physbc (MF& mf, Geometry const& geom,
Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> const& bc)
{
using T = typename MF::value_type;
using Tag = FFTPhysBCTag<T>;
Vector<Tag> tags;

for (MFIter mfi(mf, MFItInfo{}.DisableDeviceSync()); mfi.isValid(); ++mfi)
{
auto const& box = mfi.fabbox();
auto const& arr = mf.array(mfi);
for (OrientationIter oit; oit; ++oit) {
Orientation face = oit();
int idim = face.coordDir();
Box b = geom.Domain();
Boundary fbc;
if (face.isLow()) {
b.setRange(idim,geom.Domain().smallEnd(idim)-1);
fbc = bc[idim].first;
} else {
b.setRange(idim,geom.Domain().bigEnd(idim)+1);
fbc = bc[idim].second;
}
b &= box;
if (b.ok() && fbc != Boundary::periodic) {
tags.push_back({arr, b, fbc, face});
}
}
}

#if defined(AMREX_USE_GPU)
amrex::ParallelFor(tags, [=] AMREX_GPU_DEVICE (int i, int j, int k,
Tag const& tag) noexcept
#else
auto ntags = int(tags.size());
#ifdef AMREX_USE_OMP
#pragma omp parallel for
#endif
for (int itag = 0; itag < ntags; ++itag) {
Tag& tag = tags[itag];
amrex::LoopOnCpu(tag.dbox, [&] (int i, int j, int k)
#endif
{
int sgn = tag.face.isLow() ? 1 : -1;
IntVect siv = IntVect(AMREX_D_DECL(i,j,k))
+ sgn * IntVect::TheDimensionVector(tag.face.coordDir());
if (tag.bc == Boundary::odd) {
tag.dfab(i,j,k) = -tag.dfab(siv);
} else { // even
tag.dfab(i,j,k) = tag.dfab(siv);
}
});
#if !defined(AMREX_USE_GPU)
}
#endif
}
}

}

Expand Down
31 changes: 24 additions & 7 deletions Src/FFT/AMReX_FFT_R2C.H
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace amrex::FFT
{

template <typename T> class OpenBCSolver;
template <typename T> class Poisson;
template <typename T> class PoissonHybrid;

/**
* \brief Parallel Discrete Fourier Transform
Expand All @@ -38,6 +40,8 @@ public:
using cMF = FabArray<BaseFab<GpuComplex<T> > >;

template <typename U> friend class OpenBCSolver;
template <typename U> friend class Poisson;
template <typename U> friend class PoissonHybrid;

/**
* \brief Constructor
Expand Down Expand Up @@ -160,15 +164,20 @@ public:
*/
[[nodiscard]] std::pair<BoxArray,DistributionMapping> getSpectralDataLayout () const;

// public for cuda
// This is a private function, but it's public for cuda.
template <typename F>
void post_forward_doit (F const& post_forward);

private:

void prepare_openbc ();

void backward_doit (MF& outmf, IntVect const& ngout = IntVect(0));
void backward_doit (MF& outmf, IntVect const& ngout = IntVect(0),
Periodicity const& period = Periodicity::NonPeriodic());

private:
void backward_doit (cMF const& inmf, MF& outmf,
IntVect const& ngout = IntVect(0),
Periodicity const& period = Periodicity::NonPeriodic());

static std::pair<Plan<T>,Plan<T>> make_c2c_plans (cMF& inout);

Expand Down Expand Up @@ -539,13 +548,14 @@ void R2C<T,D,S>::backward (MF& outmf)
}

template <typename T, Direction D, DomainStrategy S>
void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout,
Periodicity const& period)
{
BL_PROFILE("FFT::R2C::backward(out)");

if (m_do_alld_fft) {
m_fft_bwd_x.template compute_r2c<Direction::backward>();
outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout);
outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period);
return;
}

Expand All @@ -567,7 +577,7 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)

auto& fft_x = m_openbc_half ? m_fft_bwd_x_half : m_fft_bwd_x;
fft_x.template compute_r2c<Direction::backward>();
outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout);
outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period);
}

template <typename T, Direction D, DomainStrategy S>
Expand Down Expand Up @@ -694,6 +704,13 @@ template <typename T, Direction D, DomainStrategy S>
template <Direction DIR, std::enable_if_t<DIR == Direction::backward ||
DIR == Direction::both, int> >
void R2C<T,D,S>::backward (cMF const& inmf, MF& outmf)
{
backward_doit(inmf, outmf);
}

template <typename T, Direction D, DomainStrategy S>
void R2C<T,D,S>::backward_doit (cMF const& inmf, MF& outmf, IntVect const& ngout,
Periodicity const& period)
{
BL_PROFILE("FFT::R2C::backward(inout)");

Expand All @@ -709,7 +726,7 @@ void R2C<T,D,S>::backward (cMF const& inmf, MF& outmf)
} else {
m_cx.ParallelCopy(inmf, 0, 0, 1);
}
backward_doit(outmf);
backward_doit(outmf, ngout, period);
}

template <typename T, Direction D, DomainStrategy S>
Expand Down
24 changes: 23 additions & 1 deletion Src/FFT/AMReX_FFT_R2X.H
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
namespace amrex::FFT
{

template <typename T> class Poisson;
template <typename T> class PoissonHybrid;

/**
* \brief Discrete Fourier Transform
*
Expand All @@ -25,6 +28,9 @@ public:
MultiFab, FabArray<BaseFab<T> > >;
using cMF = FabArray<BaseFab<GpuComplex<T> > >;

template <typename U> friend class Poisson;
template <typename U> friend class PoissonHybrid;

R2X (Box const& domain,
Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> const& bc,
Info const& info = Info{});
Expand All @@ -46,6 +52,12 @@ public:
void post_forward_doit (FAB* fab, F const& f);

private:

template <typename F>
void forwardThenBackward_doit (MF const& inmf, MF& outmf, F const& post_forward,
IntVect const& ngout = IntVect(0),
Periodicity const& period = Periodicity::NonPeriodic());

Box m_dom_0;
Array<std::pair<Boundary,Boundary>,AMREX_SPACEDIM> m_bc;

Expand Down Expand Up @@ -501,6 +513,16 @@ T R2X<T>::scalingFactor () const
/**
* \brief Public entry point that forwards to forwardThenBackward_doit.
*
* Requests no ghost-cell filling (ngout = 0) and a non-periodic copy, which
* are the defaults of forwardThenBackward_doit.
*/
template <typename T>
template <typename F>
void R2X<T>::forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward)
{
    forwardThenBackward_doit(inmf, outmf, post_forward,
                             IntVect(0), Periodicity::NonPeriodic());
}

template <typename T>
template <typename F>
void R2X<T>::forwardThenBackward_doit (MF const& inmf, MF& outmf,
F const& post_forward,
IntVect const& ngout,
Periodicity const& period)
{
BL_PROFILE("FFT::R2X::forwardbackward");

Expand Down Expand Up @@ -638,7 +660,7 @@ void R2X<T>::forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forwa
} else {
m_fft_bwd_x.template compute_r2r<Direction::backward>();
}
outmf.ParallelCopy(m_rx, 0, 0, 1);
outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period);
}

template <typename T>
Expand Down
Loading

0 comments on commit 3ea8bd9

Please sign in to comment.