From 3ea8bd9f6aaf6db18652c24ecfa4874bb8076c13 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Sat, 23 Nov 2024 15:36:08 -0800 Subject: [PATCH] FFT Poisson Solver: Fill ghost cells in solve() function --- Src/FFT/AMReX_FFT_Poisson.H | 114 +++++++++++++++++++++++++++++++++++- Src/FFT/AMReX_FFT_R2C.H | 31 +++++++--- Src/FFT/AMReX_FFT_R2X.H | 24 +++++++- Tests/FFT/Poisson/main.cpp | 53 +++-------------- 4 files changed, 167 insertions(+), 55 deletions(-) diff --git a/Src/FFT/AMReX_FFT_Poisson.H b/Src/FFT/AMReX_FFT_Poisson.H index 776d252dd8..835ac0c2d4 100644 --- a/Src/FFT/AMReX_FFT_Poisson.H +++ b/Src/FFT/AMReX_FFT_Poisson.H @@ -7,6 +7,12 @@ namespace amrex::FFT { +namespace detail { +template +void fill_physbc (MF& mf, Geometry const& geom, + Array,AMREX_SPACEDIM> const& bc); +} + /** * \brief Poisson solver for periodic, Dirichlet & Neumann boundaries using * FFT. @@ -48,6 +54,13 @@ public: } } + /* + * \brief Solve del dot grad soln = rhs + * + * If soln has ghost cells, one layer of ghost cells will be filled + * except for the corners of the physical domain where they are not used + * in the cross stencil of the operator. The two MFs could be the MF. + */ void solve (MF& soln, MF const& rhs); private: @@ -104,6 +117,13 @@ public: #endif } + /* + * \brief Solve del dot grad soln = rhs + * + * If soln has ghost cells, one layer of ghost cells will be filled + * except for the corners of the physical domain where they are not used + * in the cross stencil of the operator. The two MFs could be the MF. + */ void solve (MF& soln, MF const& rhs); void solve (MF& soln, MF const& rhs, Vector const& dz); void solve (MF& soln, MF const& rhs, Gpu::DeviceVector const& dz); @@ -121,6 +141,8 @@ void Poisson::solve (MF& soln, MF const& rhs) { BL_PROFILE("FFT::Poisson::solve"); + AMREX_ASSERT(soln.is_cell_centered() && rhs.is_cell_centered()); + using T = typename MF::value_type; GpuArray fac @@ -170,10 +192,15 @@ void Poisson::solve (MF& soln, MF const& rhs) spectral_data *= scale; }; + IntVect const& ng = amrex::elemwiseMin(soln.nGrowVect(), IntVect(1)); + if (m_r2x) { - m_r2x->forwardThenBackward(rhs, soln, f); + m_r2x->forwardThenBackward_doit(rhs, soln, f, ng, m_geom.periodicity()); + detail::fill_physbc(soln, m_geom, m_bc); } else { - m_r2c->forwardThenBackward(rhs, soln, f); + m_r2c->forward(rhs); + m_r2c->post_forward_doit(f); + m_r2c->backward_doit(soln, ng, m_geom.periodicity()); } } @@ -254,6 +281,8 @@ void PoissonHybrid::solve (MF& soln, MF const& rhs, Gpu::DeviceVector con template void PoissonHybrid::solve (MF& soln, MF const& rhs, Vector const& dz) { + AMREX_ASSERT(soln.is_cell_centered() && rhs.is_cell_centered()); + #ifdef AMREX_USE_GPU Gpu::DeviceVector d_dz(dz.size()); Gpu::htod_memcpy_async(d_dz.data(), dz.data(), dz.size()*sizeof(T)); @@ -414,9 +443,88 @@ void PoissonHybrid::solve_doit (MF& soln, MF const& rhs, DZ const& dz) #endif } - m_r2c.backward(spmf, soln); + IntVect const& ng = amrex::elemwiseMin(soln.nGrowVect(), IntVect(1)); + m_r2c.backward_doit(spmf, soln, ng, m_geom.periodicity()); + + Array,AMREX_SPACEDIM> bc + {AMREX_D_DECL(std::make_pair(Boundary::periodic,Boundary::periodic), + std::make_pair(Boundary::periodic,Boundary::periodic), + std::make_pair(Boundary::even,Boundary::even))}; + detail::fill_physbc(soln, m_geom, bc); +#endif +} + +namespace detail { + +template +struct FFTPhysBCTag { + Array4 dfab; + Box dbox; + Boundary bc; + Orientation face; + + [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE + Box const& box () const noexcept { return dbox; } +}; + +template +void fill_physbc (MF& mf, Geometry const& geom, + Array,AMREX_SPACEDIM> const& bc) +{ + using T = typename MF::value_type; + using Tag = FFTPhysBCTag; + Vector tags; + + for (MFIter mfi(mf, MFItInfo{}.DisableDeviceSync()); mfi.isValid(); ++mfi) + { + auto const& box = mfi.fabbox(); + auto const& arr = mf.array(mfi); + for (OrientationIter oit; oit; ++oit) { + Orientation face = oit(); + int idim = face.coordDir(); + Box b = geom.Domain(); + Boundary fbc; + if (face.isLow()) { + b.setRange(idim,geom.Domain().smallEnd(idim)-1); + fbc = bc[idim].first; + } else { + b.setRange(idim,geom.Domain().bigEnd(idim)+1); + fbc = bc[idim].second; + } + b &= box; + if (b.ok() && fbc != Boundary::periodic) { + tags.push_back({arr, b, fbc, face}); + } + } + } + +#if defined(AMREX_USE_GPU) + amrex::ParallelFor(tags, [=] AMREX_GPU_DEVICE (int i, int j, int k, + Tag const& tag) noexcept +#else + auto ntags = int(tags.size()); +#ifdef AMREX_USE_OMP +#pragma omp parallel for +#endif + for (int itag = 0; itag < ntags; ++itag) { + Tag& tag = tags[itag]; + amrex::LoopOnCpu(tag.dbox, [&] (int i, int j, int k) +#endif + { + int sgn = tag.face.isLow() ? 1 : -1; + IntVect siv = IntVect(AMREX_D_DECL(i,j,k)) + + sgn * IntVect::TheDimensionVector(tag.face.coordDir()); + if (tag.bc == Boundary::odd) { + tag.dfab(i,j,k) = -tag.dfab(siv); + } else { // even + tag.dfab(i,j,k) = tag.dfab(siv); + } + }); +#if !defined(AMREX_USE_GPU) + } #endif } +} } diff --git a/Src/FFT/AMReX_FFT_R2C.H b/Src/FFT/AMReX_FFT_R2C.H index 3d46bc47a5..4a2ceab3fd 100644 --- a/Src/FFT/AMReX_FFT_R2C.H +++ b/Src/FFT/AMReX_FFT_R2C.H @@ -12,6 +12,8 @@ namespace amrex::FFT { template class OpenBCSolver; +template class Poisson; +template class PoissonHybrid; /** * \brief Parallel Discrete Fourier Transform @@ -38,6 +40,8 @@ public: using cMF = FabArray > >; template friend class OpenBCSolver; + template friend class Poisson; + template friend class PoissonHybrid; /** * \brief Constructor @@ -160,15 +164,20 @@ public: */ [[nodiscard]] std::pair getSpectralDataLayout () const; - // public for cuda + // This is a private function, but it's public for cuda. template void post_forward_doit (F const& post_forward); +private: + void prepare_openbc (); - void backward_doit (MF& outmf, IntVect const& ngout = IntVect(0)); + void backward_doit (MF& outmf, IntVect const& ngout = IntVect(0), + Periodicity const& period = Periodicity::NonPeriodic()); -private: + void backward_doit (cMF const& inmf, MF& outmf, + IntVect const& ngout = IntVect(0), + Periodicity const& period = Periodicity::NonPeriodic()); static std::pair,Plan> make_c2c_plans (cMF& inout); @@ -539,13 +548,14 @@ void R2C::backward (MF& outmf) } template -void R2C::backward_doit (MF& outmf, IntVect const& ngout) +void R2C::backward_doit (MF& outmf, IntVect const& ngout, + Periodicity const& period) { BL_PROFILE("FFT::R2C::backward(out)"); if (m_do_alld_fft) { m_fft_bwd_x.template compute_r2c(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout); + outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period); return; } @@ -567,7 +577,7 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout) auto& fft_x = m_openbc_half ? m_fft_bwd_x_half : m_fft_bwd_x; fft_x.template compute_r2c(); - outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout); + outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period); } template @@ -694,6 +704,13 @@ template template > void R2C::backward (cMF const& inmf, MF& outmf) +{ + backward_doit(inmf, outmf); +} + +template +void R2C::backward_doit (cMF const& inmf, MF& outmf, IntVect const& ngout, + Periodicity const& period) { BL_PROFILE("FFT::R2C::backward(inout)"); @@ -709,7 +726,7 @@ void R2C::backward (cMF const& inmf, MF& outmf) } else { m_cx.ParallelCopy(inmf, 0, 0, 1); } - backward_doit(outmf); + backward_doit(outmf, ngout, period); } template diff --git a/Src/FFT/AMReX_FFT_R2X.H b/Src/FFT/AMReX_FFT_R2X.H index 5d916ada3c..f7fa256bba 100644 --- a/Src/FFT/AMReX_FFT_R2X.H +++ b/Src/FFT/AMReX_FFT_R2X.H @@ -11,6 +11,9 @@ namespace amrex::FFT { +template class Poisson; +template class PoissonHybrid; + /** * \brief Discrete Fourier Transform * @@ -25,6 +28,9 @@ public: MultiFab, FabArray > >; using cMF = FabArray > >; + template friend class Poisson; + template friend class PoissonHybrid; + R2X (Box const& domain, Array,AMREX_SPACEDIM> const& bc, Info const& info = Info{}); @@ -46,6 +52,12 @@ public: void post_forward_doit (FAB* fab, F const& f); private: + + template + void forwardThenBackward_doit (MF const& inmf, MF& outmf, F const& post_forward, + IntVect const& ngout = IntVect(0), + Periodicity const& period = Periodicity::NonPeriodic()); + Box m_dom_0; Array,AMREX_SPACEDIM> m_bc; @@ -501,6 +513,16 @@ T R2X::scalingFactor () const template template void R2X::forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward) +{ + forwardThenBackward_doit(inmf, outmf, post_forward); +} + +template +template +void R2X::forwardThenBackward_doit (MF const& inmf, MF& outmf, + F const& post_forward, + IntVect const& ngout, + Periodicity const& period) { BL_PROFILE("FFT::R2X::forwardbackward"); @@ -638,7 +660,7 @@ void R2X::forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forwa } else { m_fft_bwd_x.template compute_r2r(); } - outmf.ParallelCopy(m_rx, 0, 0, 1); + outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout, period); } template diff --git a/Tests/FFT/Poisson/main.cpp b/Tests/FFT/Poisson/main.cpp index 634a03154a..aaaebe4185 100644 --- a/Tests/FFT/Poisson/main.cpp +++ b/Tests/FFT/Poisson/main.cpp @@ -69,13 +69,9 @@ void make_rhs (MultiFab& rhs, Geometry const& geom, } std::pair check_convergence - (MultiFab const& soln, MultiFab const& rhs, Geometry const& geom, - Array,AMREX_SPACEDIM> const& fft_bc) + (MultiFab const& phi, MultiFab const& rhs, Geometry const& geom) { - MultiFab phi(soln.boxArray(), soln.DistributionMap(), 1, 1); - MultiFab res(soln.boxArray(), soln.DistributionMap(), 1, 0); - MultiFab::Copy(phi, soln, 0, 0, 1, 0); - phi.FillBoundary(geom.periodicity()); + MultiFab res(phi.boxArray(), phi.DistributionMap(), 1, 0); auto const& res_ma = res.arrays(); auto const& phi_ma = phi.const_arrays(); auto const& rhs_ma = rhs.const_arrays(); @@ -90,43 +86,12 @@ std::pair check_convergence ParallelFor(res, [=] AMREX_GPU_DEVICE (int b, int i, int j, int k) { auto const& phia = phi_ma[b]; - Real lap = 0; - if (i == 0 && fft_bc[0].first == FFT::Boundary::odd) { - lap += (-3._rt*phia(i,j,k)+phia(i+1,j,k)) * lapfac[0]; - } else if (i == 0 && fft_bc[0].first == FFT::Boundary::even) { - lap += (-phia(i,j,k)+phia(i+1,j,k)) * lapfac[0]; - } else if (i == n_cell_x-1 && fft_bc[0].second == FFT::Boundary::odd) { - lap += (phia(i-1,j,k)-3._rt*phia(i,j,k)) * lapfac[0]; - } else if (i == n_cell_x-1 && fft_bc[0].second == FFT::Boundary::even) { - lap += (phia(i-1,j,k)-phia(i,j,k)) * lapfac[0]; - } else { - lap += (phia(i-1,j,k)-2._rt*phia(i,j,k)+phia(i+1,j,k)) * lapfac[0]; - } + Real lap = (phia(i-1,j,k)-2._rt*phia(i,j,k)+phia(i+1,j,k)) * lapfac[0]; #if (AMREX_SPACEDIM >= 2) - if (j == 0 && fft_bc[1].first == FFT::Boundary::odd) { - lap += (-3._rt*phia(i,j,k)+phia(i,j+1,k)) * lapfac[1]; - } else if (j == 0 && fft_bc[1].first == FFT::Boundary::even) { - lap += (-phia(i,j,k)+phia(i,j+1,k)) * lapfac[1]; - } else if (j == n_cell_y-1 && fft_bc[1].second == FFT::Boundary::odd) { - lap += (phia(i,j-1,k)-3._rt*phia(i,j,k)) * lapfac[1]; - } else if (j == n_cell_y-1 && fft_bc[1].second == FFT::Boundary::even) { - lap += (phia(i,j-1,k)-phia(i,j,k)) * lapfac[1]; - } else { - lap += (phia(i,j-1,k)-2._rt*phia(i,j,k)+phia(i,j+1,k)) * lapfac[1]; - } + lap += (phia(i,j-1,k)-2._rt*phia(i,j,k)+phia(i,j+1,k)) * lapfac[1]; #endif #if (AMREX_SPACEDIM == 3) - if (k == 0 && fft_bc[2].first == FFT::Boundary::odd) { - lap += (-3._rt*phia(i,j,k)+phia(i,j,k+1)) * lapfac[2]; - } else if (k == 0 && fft_bc[2].first == FFT::Boundary::even) { - lap += (-phia(i,j,k)+phia(i,j,k+1)) * lapfac[2]; - } else if (k == n_cell_z-1 && fft_bc[2].second == FFT::Boundary::odd) { - lap += (phia(i,j,k-1)-3._rt*phia(i,j,k)) * lapfac[2]; - } else if (k == n_cell_z-1 && fft_bc[2].second == FFT::Boundary::even) { - lap += (phia(i,j,k-1)-phia(i,j,k)) * lapfac[2]; - } else { - lap += (phia(i,j,k-1)-2._rt*phia(i,j,k)+phia(i,j,k+1)) * lapfac[2]; - } + lap += (phia(i,j,k-1)-2._rt*phia(i,j,k)+phia(i,j,k+1)) * lapfac[2]; #endif res_ma[b](i,j,k) = rhs_ma[b](i,j,k) - lap; }); @@ -214,14 +179,14 @@ int main (int argc, char* argv[]) amrex::Print() << ")\n"; MultiFab rhs(ba,dm,1,0); - MultiFab soln(ba,dm,1,0); + MultiFab soln(ba,dm,1,1); soln.setVal(std::numeric_limits::max()); make_rhs(rhs, geom, fft_bc); FFT::Poisson fft_poisson(geom, fft_bc); fft_poisson.solve(soln, rhs); - auto [bnorm, rnorm] = check_convergence(soln, rhs, geom, fft_bc); + auto [bnorm, rnorm] = check_convergence(soln, rhs, geom); amrex::Print() << " rhs inf norm " << bnorm << "\n" << " res inf norm " << rnorm << "\n"; #ifdef AMREX_USE_FLOAT @@ -242,7 +207,7 @@ int main (int argc, char* argv[]) std::make_pair(FFT::Boundary::even,FFT::Boundary::even)}; MultiFab rhs(ba,dm,1,0); - MultiFab soln(ba,dm,1,0); + MultiFab soln(ba,dm,1,1); soln.setVal(std::numeric_limits::max()); make_rhs(rhs, geom, fft_bc); @@ -252,7 +217,7 @@ int main (int argc, char* argv[]) FFT::PoissonHybrid fft_poisson(geom); fft_poisson.solve(soln, rhs, dz); - auto [bnorm, rnorm] = check_convergence(soln, rhs, geom, fft_bc); + auto [bnorm, rnorm] = check_convergence(soln, rhs, geom); amrex::Print() << " rhs inf norm " << bnorm << "\n" << " res inf norm " << rnorm << "\n"; #ifdef AMREX_USE_FLOAT