Skip to content

Commit ceb8c11

Browse files
Use DST-2 DST-3 in FFTW Poisson solver (#1331)
1 parent ac10226 commit ceb8c11

31 files changed

Lines changed: 389 additions & 332 deletions

docs/source/run/parameters.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,11 +348,16 @@ The default is to use the explicit solver. **We strongly recommend to use the ex
348348
Which solver to use.
349349
Possible values: ``explicit`` and ``predictor-corrector``.
350350

351-
* ``fields.poisson_solver`` (`string`) optional (default CPU: `FFTDirichletDirect`, GPU: `FFTDirichletQuick` or `FFTDirichletFast`)
351+
* ``fields.poisson_solver`` (`string`) optional (default CPU: `FFTDirichletDirectEven` or `FFTDirichletDirectOdd`, GPU: `FFTDirichletQuick` or `FFTDirichletFast`)
352352
Which Poisson solver to use for ``Psi``, ``Ez`` and ``Bz``. The ``predictor-corrector`` BxBy
353353
solver also uses this Poisson solver for ``Bx`` and ``By`` internally. Available solvers are:
354354

355-
* ``FFTDirichletDirect`` Use the discrete sine transformation that is directly implemented
355+
* ``FFTDirichletDirectEven`` Use the discrete sine transforms of type DST-2 and DST-3 that are directly implemented
356+
by FFTW to solve the Poisson equation with Dirichlet boundary conditions.
357+
This option is only available when compiling for CPUs with FFTW.
358+
Preferred resolution: :math:`2^N`.
359+
360+
* ``FFTDirichletDirectOdd`` Use the discrete sine transform of type DST-1 that is directly implemented
356361
by FFTW to solve the Poisson equation with Dirichlet boundary conditions.
357362
This option is only available when compiling for CPUs with FFTW.
358363
Preferred resolution: :math:`2^N-1`.

src/fields/Fields.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,23 +204,28 @@ Fields::AllocData (
204204
}
205205

206206
// set default Poisson solver based on the platform and resolution
207-
#ifdef AMREX_USE_GPU
208207
const bool is_even = std::max(slice_ba[0].length(0), slice_ba[0].length(1)) % 2 == 0;
208+
#ifdef AMREX_USE_GPU
209209
std::string poisson_solver_str = is_even ? "FFTDirichletQuick" : "FFTDirichletFast";
210210
#else
211-
std::string poisson_solver_str = "FFTDirichletDirect";
211+
std::string poisson_solver_str = is_even ? "FFTDirichletDirectEven" : "FFTDirichletDirectOdd";
212212
#endif
213213
amrex::ParmParse ppf("fields");
214214
queryWithParser(ppf, "poisson_solver", poisson_solver_str);
215215

216216
// The Poisson solver operates on transverse slices only.
217217
// The constructor takes the BoxArray and the DistributionMap of a slice,
218218
// so the FFTPlans are built on a slice.
219-
if (poisson_solver_str == "FFTDirichletDirect"){
219+
if (poisson_solver_str == "FFTDirichletDirectEven"){
220+
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletDirect>(
221+
new FFTPoissonSolverDirichletDirect(getSlices(lev).boxArray(),
222+
getSlices(lev).DistributionMap(),
223+
geom, true)));
224+
} else if (poisson_solver_str == "FFTDirichletDirectOdd"){
220225
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletDirect>(
221226
new FFTPoissonSolverDirichletDirect(getSlices(lev).boxArray(),
222227
getSlices(lev).DistributionMap(),
223-
geom)) );
228+
geom, false)));
224229
} else if (poisson_solver_str == "FFTDirichletExpanded"){
225230
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletExpanded>(
226231
new FFTPoissonSolverDirichletExpanded(getSlices(lev).boxArray(),
@@ -248,8 +253,8 @@ Fields::AllocData (
248253
geom)) );
249254
} else {
250255
amrex::Abort("Unknown poisson solver '" + poisson_solver_str +
251-
"', must be 'FFTDirichletDirect', 'FFTDirichletExpanded', 'FFTDirichletFast', " +
252-
"'FFTDirichletQuick', 'FFTPeriodic' or 'MGDirichlet'");
256+
"', must be 'FFTDirichletDirectEven', 'FFTDirichletDirectOdd', 'FFTDirichletExpanded', "
257+
"'FFTDirichletFast', 'FFTDirichletQuick', 'FFTPeriodic' or 'MGDirichlet'");
253258
}
254259

255260
if (lev == 0 && m_insitu_period.isNonZero()) {

src/fields/fft_poisson_solver/FFTPoissonSolverDirichletDirect.H

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ public:
2929
/** Constructor */
3030
FFTPoissonSolverDirichletDirect ( amrex::BoxArray const& a_realspace_ba,
3131
amrex::DistributionMapping const& dm,
32-
amrex::Geometry const& gm);
32+
amrex::Geometry const& gm,
33+
bool is_even);
3334

3435
/** virtual destructor */
3536
virtual ~FFTPoissonSolverDirichletDirect () override final {}
@@ -42,10 +43,12 @@ public:
4243
* \param[in] realspace_ba BoxArray on which the FFT is executed.
4344
* \param[in] dm DistributionMapping for the BoxArray.
4445
* \param[in] gm Geometry, contains the box dimensions.
46+
* \param[in] is_even True: use DST-2 / DST-3; False: use DST-1
4547
*/
4648
void define ( amrex::BoxArray const& realspace_ba,
4749
amrex::DistributionMapping const& dm,
48-
amrex::Geometry const& gm);
50+
amrex::Geometry const& gm,
51+
bool is_even);
4952

5053
/**
5154
* Solve Poisson equation. The source term must be stored in the staging area m_stagingArea prior to this call.
@@ -55,8 +58,8 @@ public:
5558
virtual void SolvePoissonEquation (amrex::MultiFab& lhs_mf) override final;
5659

5760
/** Position and relative factor used to apply inhomogeneous Dirichlet boundary conditions */
58-
virtual amrex::Real BoundaryOffset() override final { return 1.; }
59-
virtual amrex::Real BoundaryFactor() override final { return 1.; }
61+
virtual amrex::Real BoundaryOffset() override final { return m_is_even ? 0.5 : 1.; }
62+
virtual amrex::Real BoundaryFactor() override final { return m_is_even ? 2. : 1.; }
6063

6164
private:
6265
/** Spectral fields, contains (real) field in Fourier space */
@@ -69,6 +72,8 @@ private:
6972
AnyFFT m_backward_fft;
7073
/** work area for both DST plans */
7174
amrex::Gpu::DeviceVector<char> m_fft_work_area;
75+
/** True: use DST-2 / DST-3; False: use DST-1 */
76+
bool m_is_even;
7277
};
7378

7479
#endif

src/fields/fft_poisson_solver/FFTPoissonSolverDirichletDirect.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,23 @@
1616
FFTPoissonSolverDirichletDirect::FFTPoissonSolverDirichletDirect (
1717
amrex::BoxArray const& realspace_ba,
1818
amrex::DistributionMapping const& dm,
19-
amrex::Geometry const& gm )
19+
amrex::Geometry const& gm,
20+
bool is_even)
2021
{
21-
define(realspace_ba, dm, gm);
22+
define(realspace_ba, dm, gm, is_even);
2223
}
2324

2425
void
2526
FFTPoissonSolverDirichletDirect::define (amrex::BoxArray const& a_realspace_ba,
2627
amrex::DistributionMapping const& dm,
27-
amrex::Geometry const& gm )
28+
amrex::Geometry const& gm,
29+
bool is_even)
2830
{
2931
HIPACE_PROFILE("FFTPoissonSolverDirichletDirect::define()");
3032
using namespace amrex::literals;
3133

34+
m_is_even = is_even;
35+
3236
// If we are going to support parallel FFT, the constructor needs to take a communicator.
3337
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(a_realspace_ba.size() == 1, "Parallel FFT not supported yet");
3438

@@ -52,15 +56,17 @@ FFTPoissonSolverDirichletDirect::define (amrex::BoxArray const& a_realspace_ba,
5256
const amrex::IntVect fft_size = fft_box.length();
5357
const int nx = fft_size[0];
5458
const int ny = fft_size[1];
59+
const int logical_nx = is_even ? nx : nx + 1;
60+
const int logical_ny = is_even ? ny : ny + 1;
5561
const auto dx = gm.CellSizeArray();
5662
const amrex::Real dxsquared = dx[0]*dx[0];
5763
const amrex::Real dysquared = dx[1]*dx[1];
58-
const amrex::Real sine_x_factor = MathConst::pi / ( 2. * ( nx + 1 ));
59-
const amrex::Real sine_y_factor = MathConst::pi / ( 2. * ( ny + 1 ));
64+
const amrex::Real sine_x_factor = MathConst::pi / ( 2. * logical_nx);
65+
const amrex::Real sine_y_factor = MathConst::pi / ( 2. * logical_ny);
6066

6167
// Normalization of FFTW's discrete sine transforms: 'DST-I' (FFTW_RODFT00) if !is_even, 'DST-II'/'DST-III' (FFTW_RODFT10/FFTW_RODFT01) if is_even
6268
// This normalization is used regardless of the sine transform library
63-
const amrex::Real norm_fac = 0.5 / ( 2 * (( nx + 1 ) * ( ny + 1 )));
69+
const amrex::Real norm_fac = 0.5 / ( 2 * (logical_nx * logical_ny));
6470

6571
// Calculate the array of m_eigenvalue_matrix
6672
for (amrex::MFIter mfi(m_eigenvalue_matrix, DfltMfi); mfi.isValid(); ++mfi ){
@@ -84,8 +90,10 @@ FFTPoissonSolverDirichletDirect::define (amrex::BoxArray const& a_realspace_ba,
8490
}
8591

8692
// Allocate and initialize the FFT plans
87-
std::size_t fwd_area = m_forward_fft.Initialize(FFTType::R2R_2D, fft_size[0], fft_size[1]);
88-
std::size_t bkw_area = m_backward_fft.Initialize(FFTType::R2R_2D, fft_size[0], fft_size[1]);
93+
std::size_t fwd_area = m_forward_fft.Initialize(
94+
is_even ? FFTType::R2R_2D_DST2 : FFTType::R2R_2D_DST1, fft_size[0], fft_size[1]);
95+
std::size_t bkw_area = m_backward_fft.Initialize(
96+
is_even ? FFTType::R2R_2D_DST3 : FFTType::R2R_2D_DST1, fft_size[0], fft_size[1]);
8997

9098
// Allocate work area for both FFTs
9199
m_fft_work_area.resize(std::max(fwd_area, bkw_area));

src/fields/fft_poisson_solver/fft/AnyFFT.H

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ enum struct FFTType {
2424
C2C_2D_bkw,
2525
C2R_2D,
2626
R2C_2D,
27-
R2R_2D,
27+
R2R_2D_DST1,
28+
R2R_2D_DST2,
29+
R2R_2D_DST3,
2830
C2R_1D_batched,
2931
R2C_1D_batched
3032
};

src/fields/fft_poisson_solver/fft/WrapCuFFT.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ std::size_t AnyFFT::Initialize (FFTType type, int nx, int ny) {
102102
n[1] = nx;
103103
batch = 1;
104104
break;
105-
case FFTType::R2R_2D:
105+
case FFTType::R2R_2D_DST1:
106+
case FFTType::R2R_2D_DST2:
107+
case FFTType::R2R_2D_DST3:
106108
amrex::Abort("R2R FFT not supported by cufft");
107109
return 0;
108110
case FFTType::C2R_1D_batched:
@@ -189,7 +191,9 @@ void AnyFFT::Execute () {
189191
reinterpret_cast<cufftComplex*>(m_plan->m_out));
190192
assert_cufft_status("cufftExecR2C", result);
191193
break;
192-
case FFTType::R2R_2D:
194+
case FFTType::R2R_2D_DST1:
195+
case FFTType::R2R_2D_DST2:
196+
case FFTType::R2R_2D_DST3:
193197
amrex::Abort("R2R FFT not supported by cufft");
194198
break;
195199
case FFTType::C2R_1D_batched:
@@ -233,7 +237,9 @@ void AnyFFT::Execute () {
233237
reinterpret_cast<cufftDoubleComplex*>(m_plan->m_out));
234238
assert_cufft_status("cufftExecD2Z", result);
235239
break;
236-
case FFTType::R2R_2D:
240+
case FFTType::R2R_2D_DST1:
241+
case FFTType::R2R_2D_DST2:
242+
case FFTType::R2R_2D_DST3:
237243
amrex::Abort("R2R FFT not supported by cufft");
238244
break;
239245
case FFTType::C2R_1D_batched:

src/fields/fft_poisson_solver/fft/WrapFFTW.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,24 @@ void AnyFFT::SetBuffers (void* in, void* out, [[maybe_unused]] void* work_area)
7474
reinterpret_cast<float*>(in), reinterpret_cast<fftwf_complex*>(out),
7575
FFTW_MEASURE);
7676
break;
77-
case FFTType::R2R_2D:
77+
case FFTType::R2R_2D_DST1:
7878
m_plan->m_fftwf_plan = fftwf_plan_r2r_2d(
7979
m_plan->m_ny, m_plan->m_nx,
8080
reinterpret_cast<float*>(in), reinterpret_cast<float*>(out),
8181
FFTW_RODFT00, FFTW_RODFT00, FFTW_MEASURE);
8282
break;
83+
case FFTType::R2R_2D_DST2:
84+
m_plan->m_fftwf_plan = fftwf_plan_r2r_2d(
85+
m_plan->m_ny, m_plan->m_nx,
86+
reinterpret_cast<float*>(in), reinterpret_cast<float*>(out),
87+
FFTW_RODFT10, FFTW_RODFT10, FFTW_MEASURE);
88+
break;
89+
case FFTType::R2R_2D_DST3:
90+
m_plan->m_fftwf_plan = fftwf_plan_r2r_2d(
91+
m_plan->m_ny, m_plan->m_nx,
92+
reinterpret_cast<float*>(in), reinterpret_cast<float*>(out),
93+
FFTW_RODFT01, FFTW_RODFT01, FFTW_MEASURE);
94+
break;
8395
case FFTType::C2R_1D_batched:
8496
{
8597
int n[1] = {m_plan->m_nx};
@@ -127,12 +139,24 @@ void AnyFFT::SetBuffers (void* in, void* out, [[maybe_unused]] void* work_area)
127139
reinterpret_cast<double*>(in), reinterpret_cast<fftw_complex*>(out),
128140
FFTW_MEASURE);
129141
break;
130-
case FFTType::R2R_2D:
142+
case FFTType::R2R_2D_DST1:
131143
m_plan->m_fftw_plan = fftw_plan_r2r_2d(
132144
m_plan->m_ny, m_plan->m_nx,
133145
reinterpret_cast<double*>(in), reinterpret_cast<double*>(out),
134146
FFTW_RODFT00, FFTW_RODFT00, FFTW_MEASURE);
135147
break;
148+
case FFTType::R2R_2D_DST2:
149+
m_plan->m_fftw_plan = fftw_plan_r2r_2d(
150+
m_plan->m_ny, m_plan->m_nx,
151+
reinterpret_cast<double*>(in), reinterpret_cast<double*>(out),
152+
FFTW_RODFT10, FFTW_RODFT10, FFTW_MEASURE);
153+
break;
154+
case FFTType::R2R_2D_DST3:
155+
m_plan->m_fftw_plan = fftw_plan_r2r_2d(
156+
m_plan->m_ny, m_plan->m_nx,
157+
reinterpret_cast<double*>(in), reinterpret_cast<double*>(out),
158+
FFTW_RODFT01, FFTW_RODFT01, FFTW_MEASURE);
159+
break;
136160
case FFTType::C2R_1D_batched:
137161
{
138162
int n[1] = {m_plan->m_nx};

src/fields/fft_poisson_solver/fft/WrapRocFFT.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ std::size_t AnyFFT::Initialize (FFTType type, int nx, int ny) {
107107
lengths[1] = ny;
108108
number_of_transforms = 1;
109109
break;
110-
case FFTType::R2R_2D:
110+
case FFTType::R2R_2D_DST1:
111+
case FFTType::R2R_2D_DST2:
112+
case FFTType::R2R_2D_DST3:
111113
amrex::Abort("R2R FFT not supported by rocfft");
112114
return 0;
113115
case FFTType::C2R_1D_batched:

tests/Poisson_even.1Rank.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ HIPACE_TEST_DIR=${HIPACE_SOURCE_DIR}/tests
2626
RTOL=2e-3
2727

2828

29-
for solver_type in FFTDirichletDirect FFTDirichletExpanded FFTDirichletFast FFTDirichletQuick MGDirichlet
29+
for solver_type in FFTDirichletDirectOdd FFTDirichletDirectEven FFTDirichletExpanded FFTDirichletFast FFTDirichletQuick MGDirichlet
3030
do
3131

3232
echo "Testing $solver_type"

tests/Poisson_odd.1Rank.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ HIPACE_TEST_DIR=${HIPACE_SOURCE_DIR}/tests
2626
RTOL=2e-3
2727

2828

29-
for solver_type in FFTDirichletDirect FFTDirichletExpanded FFTDirichletFast FFTDirichletQuick MGDirichlet
29+
for solver_type in FFTDirichletDirectOdd FFTDirichletDirectEven FFTDirichletExpanded FFTDirichletFast FFTDirichletQuick MGDirichlet
3030
do
3131

3232
echo "Testing $solver_type"

0 commit comments

Comments
 (0)