Skip to content

Commit f2573dd

Browse files
committed
feat: add 2D analytical solution, num_samples config, n param, 4 new tests
- sim_wave.py: add analytical_solution_2d (2D FFT exact solution for wave/KG) - sim_wave.py: add n param to WaveSimulator for API compat with other simulators - wave.yaml: add num_samples (was hardcoded 1000 in gen_wave.py) and n: 1 - gen_wave.py: use config.num_samples instead of hardcoded 1000 - test_sim_wave.py: add seed reproducibility, 2D leapfrog vs analytical, 2D analytical at t=0, and periodicity test (single mode returns to IC after T=1/c) Tests: 14/14 pass
1 parent e703ff0 commit f2573dd

4 files changed

Lines changed: 128 additions & 2 deletions

File tree

pdebench/data_gen/configs/wave.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ output_path: 1D_Wave_c1.0
1212

1313
name: 1d_wave
1414

15+
num_samples: 1000
16+
1517
sim:
1618
c: 1.0
1719
chi: 0.0 # 0 = wave equation, >0 = Klein-Gordon
@@ -20,6 +22,7 @@ sim:
2022
xdim: 1024
2123
ndim: 1
2224
n_modes: 5
25+
n: 1
2326
seed: "???"
2427

2528
plot:

pdebench/data_gen/gen_wave.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def main(config: DictConfig):
164164
base_name = f"{config.sim.ndim}D_Wave_c{config.sim.c}{chi_str}"
165165
config.output_path = str((output_dir / base_name).with_suffix(".h5"))
166166

167-
num_samples = 1000
167+
num_samples = config.num_samples
168168

169169
log.info(f"Generating {num_samples} samples -> {config.output_path}")
170170
log.info(f"PDE: d2u/dt2 = {config.sim.c}^2 * Lap(u) - {config.sim.chi}^2 * u")

pdebench/data_gen/src/sim_wave.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class WaveSimulator:
3737
:param xdim: number of spatial grid points per dimension
3838
:param ndim: spatial dimensionality (1 or 2)
3939
:param n_modes: number of Fourier modes in IC generation
40+
:param n: number of batches (unused; present for API compatibility)
4041
:param seed: random seed for IC generation
4142
"""
4243

@@ -49,6 +50,7 @@ def __init__(
4950
xdim: int = 1024,
5051
ndim: int = 1,
5152
n_modes: int = 5,
53+
n: int = 1, # noqa: ARG002
5254
seed: int = 0,
5355
):
5456
self.c = c
@@ -219,3 +221,42 @@ def analytical_solution_1d(
219221
result[i] = np.fft.ifft(u_hat_t).real
220222

221223
return result
224+
225+
226+
def analytical_solution_2d(
227+
x: np.ndarray,
228+
t: np.ndarray,
229+
u0: np.ndarray,
230+
c: float,
231+
chi: float = 0.0,
232+
) -> np.ndarray:
233+
"""
234+
Compute analytical solution for 2D wave/KG equation via 2D FFT.
235+
236+
Each Fourier mode (kx, ky) oscillates at frequency
237+
omega = sqrt(c^2 * (2*pi)^2 * (kx^2 + ky^2) + chi^2).
238+
239+
:param x: 1D spatial grid (Nx,), same for both dimensions
240+
:param t: time points (Nt,)
241+
:param u0: initial condition (Nx, Nx)
242+
:param c: wave speed
243+
:param chi: mass parameter
244+
:return: solution array (Nt, Nx, Nx)
245+
"""
246+
Nx = u0.shape[0]
247+
u0_hat = np.fft.fft2(u0)
248+
kx = np.fft.fftfreq(Nx, d=1.0 / Nx)
249+
ky = np.fft.fftfreq(Nx, d=1.0 / Nx)
250+
KX, KY = np.meshgrid(kx, ky, indexing="ij")
251+
252+
omega = np.sqrt(
253+
c**2 * (2 * np.pi) ** 2 * (KX**2 + KY**2) + chi**2 + 0j
254+
).real
255+
256+
result = np.zeros((len(t), Nx, Nx), dtype=np.float64)
257+
for i, ti in enumerate(t):
258+
# du/dt(0) = 0 => only cosine term
259+
u_hat_t = u0_hat * np.cos(omega * ti)
260+
result[i] = np.fft.ifft2(u_hat_t).real
261+
262+
return result

tests/test_sim_wave.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pytest
77

8-
from pdebench.data_gen.src.sim_wave import WaveSimulator, analytical_solution_1d
8+
from pdebench.data_gen.src.sim_wave import WaveSimulator, analytical_solution_1d, analytical_solution_2d
99

1010

1111
# ---------------------------------------------------------------------------
@@ -141,3 +141,85 @@ def test_analytical_solution_kg_at_t0():
141141
result = analytical_solution_1d(x, t, u0, c=1.0, chi=5.0)
142142

143143
np.testing.assert_allclose(result[0], u0, atol=1e-12)
144+
145+
146+
# ---------------------------------------------------------------------------
147+
# Seed reproducibility
148+
# ---------------------------------------------------------------------------
149+
150+
151+
def test_seed_reproducibility():
152+
"""Same seed must produce bit-identical output across two calls."""
153+
sim_a = WaveSimulator(c=1.0, chi=0.0, xdim=32, tdim=11, t=0.5, seed=42)
154+
sim_b = WaveSimulator(c=1.0, chi=0.0, xdim=32, tdim=11, t=0.5, seed=42)
155+
np.testing.assert_array_equal(sim_a.generate_sample(), sim_b.generate_sample())
156+
157+
158+
# ---------------------------------------------------------------------------
159+
# 2D analytical solution
160+
# ---------------------------------------------------------------------------
161+
162+
163+
def test_wave_2d_matches_analytical():
164+
"""
165+
2D leapfrog should match the 2D analytical solution to within 1% nRMSE.
166+
167+
IC: u0(x, y) = cos(2*pi*x) * cos(2*pi*y), du/dt=0.
168+
Exact: u(x, y, t) = cos(2*pi*x) * cos(2*pi*y) * cos(2*pi*sqrt(2)*c*t).
169+
"""
170+
Nx = 64
171+
sim = WaveSimulator(c=1.0, chi=0.0, xdim=Nx, tdim=11, t=0.3, ndim=2, seed=0)
172+
173+
x1d = np.linspace(0, 1, Nx, endpoint=False)
174+
X, Y = np.meshgrid(x1d, x1d, indexing="ij")
175+
u0_2d = np.cos(2 * np.pi * X) * np.cos(2 * np.pi * Y)
176+
177+
sim._random_fourier_ic_2d = lambda rng: u0_2d.copy() # noqa: ARG005
178+
179+
numerical = sim.generate_sample() # (11, Nx, Nx) float32
180+
exact = analytical_solution_2d(x1d, sim.t_save, u0_2d, c=1.0, chi=0.0)
181+
182+
u_range = exact.max() - exact.min()
183+
nrmse = np.sqrt(np.mean((numerical.astype(np.float64) - exact) ** 2)) / u_range
184+
assert nrmse < 0.01, f"2D nRMSE={nrmse:.4f} exceeds 1% tolerance"
185+
186+
187+
def test_analytical_solution_2d_at_t0():
188+
"""2D analytical solution at t=0 must equal u0 exactly."""
189+
Nx = 32
190+
x = np.linspace(0, 1, Nx, endpoint=False)
191+
X, Y = np.meshgrid(x, x, indexing="ij")
192+
u0 = np.sin(2 * np.pi * X) + 0.3 * np.cos(4 * np.pi * Y)
193+
t = np.array([0.0, 0.5, 1.0])
194+
195+
result = analytical_solution_2d(x, t, u0, c=1.0, chi=0.0)
196+
197+
np.testing.assert_allclose(result[0], u0, atol=1e-12)
198+
199+
200+
# ---------------------------------------------------------------------------
201+
# Periodicity: single mode returns to IC after one period
202+
# ---------------------------------------------------------------------------
203+
204+
205+
def test_wave_periodicity_single_mode():
206+
"""
207+
After one full period T = 1/c, a single k=1 cosine IC returns to itself.
208+
209+
u(x, t) = cos(2*pi*x) * cos(2*pi*c*t) has period T = 1/c.
210+
The leapfrog nRMSE at t=T should be < 0.1% (O(dt^2) per step).
211+
"""
212+
c = 1.0
213+
Nx = 256
214+
# tdim=3 saves t=0, t=T/2, t=T
215+
sim = WaveSimulator(c=c, chi=0.0, xdim=Nx, tdim=3, t=1.0 / c, seed=0)
216+
u0 = np.cos(2 * np.pi * sim.x)
217+
218+
sim._random_fourier_ic_1d = lambda rng: u0.copy() # noqa: ARG005
219+
220+
result = sim.generate_sample() # (3, Nx): t=0, t=T/2, t=T
221+
u_final = result[-1].astype(np.float64)
222+
223+
u_range = u0.max() - u0.min()
224+
nrmse = np.sqrt(np.mean((u_final - u0) ** 2)) / u_range
225+
assert nrmse < 0.001, f"Periodicity nRMSE={nrmse:.6f} exceeds 0.1%"

0 commit comments

Comments
 (0)