Skip to content

Commit eba5d6f

Browse files
add missing tests
1 parent dc3eb7e commit eba5d6f

7 files changed

Lines changed: 1315 additions & 0 deletions

File tree

Lines changed: 369 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,369 @@
1+
"""Numerical validation for multi-frequency custom dispersive medium gradients."""
2+
3+
from __future__ import annotations
4+
5+
import sys
6+
7+
import autograd.numpy as anp
8+
import matplotlib.pyplot as plt
9+
import numpy as np
10+
import pytest
11+
from autograd import value_and_grad
12+
13+
import tidy3d as td
14+
import tidy3d.web as web
15+
from tidy3d.components.autograd import get_static
16+
17+
18+
@pytest.fixture(autouse=True)
19+
def _enable_local_cache(monkeypatch):
20+
monkeypatch.setattr(td.config.local_cache, "enabled", True)
21+
22+
23+
SIM_SIZE_SCALE = (3.0, 2.5, 3.0)
24+
BOX_SIZE_SCALE = (0.8, 0.8, 0.8)
25+
GRID_STEPS_PER_WVL = 40
26+
RUN_TIME = 2e-13
27+
FD_STEP = 5e-3
28+
ANGLE_TOL = 5.0
29+
30+
FREQS = np.array([1.7e14, 2.4e14])
31+
FREQ_WEIGHTS = np.array([1.0, 0.6])
32+
33+
PARAM_SHAPE_2D = (2, 2)
34+
PARAM_SHAPE = (2, 2, 2)
35+
FD_SWEEP_STEPS = np.logspace(-3, -1, num=7)
36+
37+
SELLMEIER_C_VAL = 0.6 * (td.C_0 / np.max(FREQS)) ** 2
38+
39+
TEST_CASES = [
40+
{
41+
"name": "lo1", # keep names short, filenames get too long otherwise
42+
"kind": "lorentz",
43+
"eps_inf": 1.6,
44+
"param0": 0.5,
45+
"f0": 2.6e14,
46+
"delta": 0.2e14,
47+
},
48+
{
49+
"name": "lo2",
50+
"kind": "lorentz",
51+
"eps_inf": 2.3,
52+
"param0": 0.7,
53+
"f0": 2.3e14,
54+
"delta": 0.2e14,
55+
},
56+
{
57+
"name": "lo3",
58+
"kind": "lorentz",
59+
"eps_inf": 1.9,
60+
"param0": 0.35,
61+
"f0": 3.0e14,
62+
"delta": 0.2e14,
63+
},
64+
{
65+
"name": "sl",
66+
"kind": "sellmeier",
67+
"param0": 0.6,
68+
"c_val": SELLMEIER_C_VAL,
69+
},
70+
{
71+
"name": "dd",
72+
"kind": "drude",
73+
"eps_inf": 1.6,
74+
"param0": 0.5,
75+
"param_scale": 2.0e14,
76+
"delta": 0.3e14,
77+
},
78+
{
79+
"name": "db",
80+
"kind": "debye",
81+
"eps_inf": 2.5,
82+
"param0": 0.5,
83+
"tau": 0.4e-14,
84+
},
85+
{
86+
"name": "pr",
87+
"kind": "pole_residue",
88+
"eps_inf": 1.6,
89+
"param0": 0.5,
90+
"param_scale": 1.0e14,
91+
"a_val": -1.2e14,
92+
},
93+
]
94+
95+
96+
def _build_base_sim(freqs: np.ndarray) -> tuple[td.Simulation, str, float]:
97+
wavelength_min = td.C_0 / np.max(freqs)
98+
sim_size = tuple(scale * wavelength_min for scale in SIM_SIZE_SCALE)
99+
100+
freq0 = float(np.mean(freqs))
101+
fwidth = float(max(freqs.max() - freqs.min(), 0.4 * freq0))
102+
103+
src = td.PlaneWave(
104+
center=(0.0, 0.0, -0.75 * sim_size[2] / 2),
105+
size=(sim_size[0], sim_size[1], 0.0),
106+
source_time=td.GaussianPulse(freq0=freq0, fwidth=fwidth),
107+
direction="+",
108+
pol_angle=0.0,
109+
)
110+
111+
monitor_name = "field_monitor"
112+
monitor = td.FieldMonitor(
113+
center=(0.0, 0.0, sim_size[2] / 2 * 0.6),
114+
size=(sim_size[0], sim_size[1], 0.0),
115+
freqs=list(freqs),
116+
name=monitor_name,
117+
colocate=False,
118+
)
119+
120+
sim = td.Simulation(
121+
size=sim_size,
122+
center=(0.0, 0.0, 0.0),
123+
grid_spec=td.GridSpec.auto(
124+
min_steps_per_wvl=GRID_STEPS_PER_WVL,
125+
wavelength=wavelength_min,
126+
),
127+
boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True),
128+
sources=[src],
129+
monitors=[monitor],
130+
structures=[],
131+
run_time=RUN_TIME,
132+
)
133+
return sim, monitor_name, wavelength_min
134+
135+
136+
def _box_geometry(wavelength_min: float) -> td.Box:
137+
size = tuple(scale * wavelength_min for scale in BOX_SIZE_SCALE)
138+
return td.Box(size=size, center=(0.0, 0.0, 0.0))
139+
140+
141+
def _coords_for_bounds(bounds, shape):
142+
return {
143+
"x": np.linspace(bounds[0][0], bounds[1][0], shape[0]),
144+
"y": np.linspace(bounds[0][1], bounds[1][1], shape[1]),
145+
"z": np.linspace(bounds[0][2], bounds[1][2], shape[2]),
146+
}
147+
148+
149+
def _custom_medium(case, param_vals: anp.ndarray, box_geom: td.Box):
150+
bounds = box_geom.bounds
151+
coords = _coords_for_bounds(bounds, param_vals.shape)
152+
kind = case["kind"]
153+
param_scale = case.get("param_scale", 1.0)
154+
scaled = param_scale * param_vals
155+
156+
if kind == "lorentz":
157+
eps_inf = td.SpatialDataArray(np.full(param_vals.shape, case["eps_inf"]), coords=coords)
158+
de = td.SpatialDataArray(scaled, coords=coords)
159+
f0 = td.SpatialDataArray(np.full(param_vals.shape, case["f0"]), coords=coords)
160+
delta = td.SpatialDataArray(np.full(param_vals.shape, case["delta"]), coords=coords)
161+
return td.CustomLorentz(eps_inf=eps_inf, coeffs=[(de, f0, delta)])
162+
if kind == "sellmeier":
163+
b = td.SpatialDataArray(scaled, coords=coords)
164+
c = td.SpatialDataArray(np.full(param_vals.shape, case["c_val"]), coords=coords)
165+
return td.CustomSellmeier(coeffs=[(b, c)])
166+
if kind == "drude":
167+
eps_inf = td.SpatialDataArray(np.full(param_vals.shape, case["eps_inf"]), coords=coords)
168+
fp = td.SpatialDataArray(scaled, coords=coords)
169+
delta = td.SpatialDataArray(np.full(param_vals.shape, case["delta"]), coords=coords)
170+
return td.CustomDrude(eps_inf=eps_inf, coeffs=[(fp, delta)])
171+
if kind == "debye":
172+
eps_inf = td.SpatialDataArray(np.full(param_vals.shape, case["eps_inf"]), coords=coords)
173+
de = td.SpatialDataArray(scaled, coords=coords)
174+
tau = td.SpatialDataArray(np.full(param_vals.shape, case["tau"]), coords=coords)
175+
return td.CustomDebye(eps_inf=eps_inf, coeffs=[(de, tau)])
176+
if kind == "pole_residue":
177+
eps_inf = td.SpatialDataArray(np.full(param_vals.shape, case["eps_inf"]), coords=coords)
178+
a_val = td.SpatialDataArray(np.full(param_vals.shape, case["a_val"]), coords=coords)
179+
c_val = td.SpatialDataArray(scaled, coords=coords)
180+
return td.CustomPoleResidue(eps_inf=eps_inf, poles=[(a_val, c_val)])
181+
raise ValueError(f"Unsupported medium kind: {kind}")
182+
183+
184+
def _add_medium(
185+
sim: td.Simulation, box_geom: td.Box, case, param_vals: anp.ndarray
186+
) -> td.Simulation:
187+
medium = _custom_medium(case, param_vals, box_geom)
188+
structure = td.Structure(geometry=box_geom, medium=medium)
189+
return sim.updated_copy(structures=[structure])
190+
191+
192+
def _metric_value(dataset) -> float:
193+
ex_vals = dataset.Ex.values
194+
ey_vals = dataset.Ey.values
195+
ez_vals = dataset.Ez.values
196+
intensity = anp.abs(ex_vals) ** 2 + anp.abs(ey_vals) ** 2 + anp.abs(ez_vals) ** 2
197+
weighted = intensity * anp.asarray(FREQ_WEIGHTS)
198+
return anp.real(anp.mean(weighted))
199+
200+
201+
def _angle_deg(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
202+
norm_a = np.linalg.norm(vec_a)
203+
norm_b = np.linalg.norm(vec_b)
204+
if norm_a == 0 or norm_b == 0:
205+
return np.nan
206+
cos_theta = np.clip(np.dot(vec_a, vec_b) / (norm_a * norm_b), -1.0, 1.0)
207+
return float(np.degrees(np.arccos(cos_theta)))
208+
209+
210+
def _expand_params(params: anp.ndarray) -> anp.ndarray:
211+
vals_2d = anp.reshape(params, PARAM_SHAPE_2D)
212+
return anp.repeat(vals_2d[..., None], PARAM_SHAPE[2], axis=2)
213+
214+
215+
def _run_simulation(
216+
sim: td.Simulation,
217+
monitor_name: str,
218+
tmp_path,
219+
label: str,
220+
local_gradient: bool,
221+
) -> float:
222+
sim_data = web.run(
223+
sim,
224+
task_name=f"custom_disp_{label}",
225+
local_gradient=local_gradient,
226+
verbose=False,
227+
path=str(tmp_path / f"custom_disp_{label}.hdf5"),
228+
)
229+
return _metric_value(sim_data[monitor_name])
230+
231+
232+
@pytest.mark.numerical
233+
@pytest.mark.parametrize("case", TEST_CASES, ids=lambda c: c["name"])
234+
def test_custom_dispersive_multifreq_grad_matches_fd(
235+
case, numerical_case_dir, tmp_path, _enable_local_cache
236+
):
237+
base_sim, monitor_name, wavelength_min = _build_base_sim(FREQS)
238+
box_geom = _box_geometry(wavelength_min)
239+
240+
params0 = anp.full(PARAM_SHAPE_2D, case["param0"]).reshape(-1)
241+
242+
def objective(param_vec):
243+
param_vals = _expand_params(param_vec)
244+
sim = _add_medium(base_sim, box_geom, case, param_vals)
245+
return _run_simulation(
246+
sim=sim,
247+
monitor_name=monitor_name,
248+
tmp_path=tmp_path,
249+
label="adjoint",
250+
local_gradient=True,
251+
)
252+
253+
_, grad_adj = value_and_grad(objective)(params0)
254+
grad_adj = np.asarray(get_static(grad_adj), dtype=float).reshape(-1)
255+
256+
fd_sims: dict[str, td.Simulation] = {}
257+
for idx in range(params0.size):
258+
delta = np.zeros_like(params0)
259+
delta[idx] = FD_STEP
260+
plus_vals = _expand_params(params0 + delta)
261+
minus_vals = _expand_params(params0 - delta)
262+
fd_sims[f"plus_{idx}"] = _add_medium(base_sim, box_geom, case, plus_vals)
263+
fd_sims[f"minus_{idx}"] = _add_medium(base_sim, box_geom, case, minus_vals)
264+
265+
fd_results = web.run_async(
266+
fd_sims,
267+
path_dir=str(numerical_case_dir / f"{case['name']}"),
268+
local_gradient=False,
269+
verbose=False,
270+
)
271+
272+
grad_fd = np.zeros_like(grad_adj)
273+
for idx in range(params0.size):
274+
val_plus = _metric_value(fd_results[f"plus_{idx}"][monitor_name])
275+
val_minus = _metric_value(fd_results[f"minus_{idx}"][monitor_name])
276+
grad_fd[idx] = (val_plus - val_minus) / (2.0 * FD_STEP)
277+
278+
angle_deg = _angle_deg(grad_adj, grad_fd)
279+
print(
280+
(
281+
f"[custom-dispersive-multifreq:{case['name']}] adjoint={grad_adj}, "
282+
f"finite-difference={grad_fd}, angle_deg={angle_deg:.3f}"
283+
),
284+
file=sys.stderr,
285+
)
286+
287+
assert angle_deg <= ANGLE_TOL or np.isnan(angle_deg), (
288+
f"Multi-frequency CustomDispersive gradient mismatch for {case['name']}. "
289+
f"angle_deg={angle_deg:.3f}, adj={grad_adj}, fd={grad_fd}"
290+
)
291+
292+
293+
@pytest.mark.numerical
294+
def test_custom_lorentz_fd_step_sweep(numerical_case_dir, tmp_path, _enable_local_cache):
295+
base_sim, monitor_name, wavelength_min = _build_base_sim(FREQS)
296+
box_geom = _box_geometry(wavelength_min)
297+
298+
case = TEST_CASES[0]
299+
params0 = anp.full(PARAM_SHAPE_2D, case["param0"]).reshape(-1)
300+
301+
def objective(de_params):
302+
de_vals = _expand_params(de_params)
303+
sim = _add_medium(base_sim, box_geom, case, de_vals)
304+
return _run_simulation(
305+
sim=sim,
306+
monitor_name=monitor_name,
307+
tmp_path=tmp_path,
308+
label="adjoint_sweep",
309+
local_gradient=True,
310+
)
311+
312+
_, grad_adj = value_and_grad(objective)(params0)
313+
grad_adj = np.asarray(get_static(grad_adj), dtype=float).reshape(-1)
314+
315+
sweep_runs: dict[str, td.Simulation] = {}
316+
step_labels = [f"{step:.3e}" for step in FD_SWEEP_STEPS]
317+
for step_label, step in zip(step_labels, FD_SWEEP_STEPS):
318+
plus_vals = _expand_params(params0 + step)
319+
minus_vals = _expand_params(params0 - step)
320+
sweep_runs[f"step_{step_label}_plus"] = _add_medium(base_sim, box_geom, case, plus_vals)
321+
sweep_runs[f"step_{step_label}_minus"] = _add_medium(base_sim, box_geom, case, minus_vals)
322+
323+
sweep_results = web.run_async(
324+
sweep_runs,
325+
path_dir=str(numerical_case_dir / f"fd_sweep_{case['name']}"),
326+
local_gradient=False,
327+
verbose=False,
328+
)
329+
330+
fd_sweep = []
331+
for step_label, step in zip(step_labels, FD_SWEEP_STEPS):
332+
plus_key = f"step_{step_label}_plus"
333+
minus_key = f"step_{step_label}_minus"
334+
plus_val = _metric_value(sweep_results[plus_key][monitor_name])
335+
minus_val = _metric_value(sweep_results[minus_key][monitor_name])
336+
fd_sweep.append((plus_val - minus_val) / (2.0 * step))
337+
338+
fd_sweep = np.array(fd_sweep, dtype=float)
339+
fd_min = float(np.min(fd_sweep))
340+
fd_max = float(np.max(fd_sweep))
341+
342+
fig, ax = plt.subplots(figsize=(6, 4))
343+
ax.plot(FD_SWEEP_STEPS, fd_sweep, marker="o", label="FD")
344+
ax.axhline(
345+
np.mean(grad_adj),
346+
color=ax.get_lines()[-1].get_color(),
347+
linestyle="--",
348+
alpha=0.7,
349+
label="Adjoint (mean)",
350+
)
351+
ax.set_xscale("log")
352+
ax.set_xlabel("Finite difference step")
353+
ax.set_ylabel("Gradient value")
354+
ax.set_title("CustomLorentz FD sweep")
355+
ax.grid(True, which="both", ls=":")
356+
ax.legend()
357+
358+
fig_path = numerical_case_dir / "custom_lorentz_fd_step_sweep.png"
359+
fig.savefig(fig_path, dpi=200)
360+
plt.close(fig)
361+
362+
print(
363+
(
364+
"[custom-dispersive-fd-sweep] "
365+
f"grad_adj={grad_adj} "
366+
f"fd_grad[min,max]=({fd_min:.6e},{fd_max:.6e})"
367+
),
368+
file=sys.stderr,
369+
)

0 commit comments

Comments
 (0)