Skip to content

Commit d2bc9d4

Browse files
fix(tidy3d): FXC-4641-fix-gradients-in-custom-medium
1 parent ea3a8ad commit d2bc9d4

3 files changed

Lines changed: 574 additions & 75 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
### Fixed
1717
- Fixed `AutoImpedanceSpec` validation to check path intersections against all conductors, not just filtered ones, as well as the mode plane bounds.
1818
- Fixed adjoint gradients being treated as zero due to scale-dependent `np.allclose(..., atol=1e-8)` checks, which could skip adjoint simulations and return zero gradients.
19+
- Fixed interpolation handling for permittivity and conductivity gradients in CustomMedium.
1920

2021
## [2.10.0] - 2025-12-18
2122

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
5+
import autograd.numpy as anp
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import pytest
9+
from autograd import value_and_grad
10+
11+
import tidy3d as td
12+
import tidy3d.web as web
13+
from tidy3d.components.autograd import get_static
14+
15+
td.config.local_cache.enabled = True
16+
17+
SIM_SIZE_SCALE = (4, 3, 4)
18+
BOX_SIZE_SCALE = (1, 1, 1)
19+
GRID_STEPS_PER_WVL = 30
20+
RUN_TIME = 2e-12
21+
ANGLE_TOL = 10.0
22+
FD_STEP = 5e-2
23+
24+
TEST_CASES = [
25+
{
26+
"name": "opt_flux_iso",
27+
"wavelength": 1.0,
28+
"permittivities": (2.2, 2.2, 2.2),
29+
"objective_kind": "flux",
30+
"monitor_size": (np.inf, np.inf, 0.0),
31+
"polarization": 0.0,
32+
"medium_type": "isotropic",
33+
},
34+
{
35+
"name": "mw_intensity_iso",
36+
"wavelength": 1.6,
37+
"permittivities": (1.8, 1.8, 1.8),
38+
"objective_kind": "intensity",
39+
"monitor_size": (0.4, 0.4, 0.0),
40+
"polarization": np.pi / 5,
41+
"medium_type": "isotropic",
42+
},
43+
{
44+
"name": "opt_flux_custom_iso",
45+
"wavelength": 1.3,
46+
"permittivities": (2.0, 2.0, 2.0),
47+
"objective_kind": "flux",
48+
"monitor_size": (np.inf, np.inf, 0.0),
49+
"polarization": 0.0,
50+
"medium_type": "custom",
51+
},
52+
{
53+
"name": "mw_int_custom_iso",
54+
"wavelength": 1.1,
55+
"permittivities": (1.6, 1.6, 1.6),
56+
"objective_kind": "intensity",
57+
"monitor_size": (0.3, 0.3, 0.0),
58+
"polarization": np.pi / 3,
59+
"medium_type": "custom",
60+
},
61+
]
62+
63+
64+
def _scale_monitor_dim(dim: float, wavelength: float) -> float:
65+
if np.isinf(dim):
66+
return np.inf
67+
return dim * wavelength
68+
69+
70+
def _box_geometry(case) -> td.Box:
71+
size = tuple(scale * case["wavelength"] for scale in BOX_SIZE_SCALE)
72+
return td.Box(size=size, center=(0.0, 0.0, 0.0))
73+
74+
75+
def _build_base_sim(case):
76+
wavelength = case["wavelength"]
77+
freq0 = td.C_0 / wavelength
78+
sim_size = tuple(scale * wavelength for scale in SIM_SIZE_SCALE)
79+
80+
plane_wave = td.PlaneWave(
81+
center=(0.0, 0.0, -0.75 * sim_size[2] / 2),
82+
size=(sim_size[0], sim_size[1], 0.0),
83+
source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0),
84+
direction="+",
85+
pol_angle=case.get("polarization", 0.0),
86+
)
87+
88+
monitor_center = (0.0, 0.0, sim_size[2] / 2 * 0.75)
89+
monitor_size = tuple(_scale_monitor_dim(dim, wavelength) for dim in case["monitor_size"])
90+
monitor_name = f"{case['name']}_monitor"
91+
monitor = td.FieldMonitor(
92+
center=monitor_center,
93+
size=monitor_size,
94+
freqs=[freq0],
95+
name=monitor_name,
96+
colocate=False,
97+
)
98+
99+
sim = td.Simulation(
100+
size=sim_size,
101+
center=(0.0, 0.0, 0.0),
102+
grid_spec=td.GridSpec.auto(min_steps_per_wvl=GRID_STEPS_PER_WVL, wavelength=wavelength),
103+
boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True),
104+
sources=[plane_wave],
105+
monitors=[monitor],
106+
structures=[],
107+
run_time=RUN_TIME,
108+
)
109+
return sim, monitor_name, freq0
110+
111+
112+
def _add_medium(case, base_sim: td.Simulation, box_geom: td.Box, eps_vals) -> td.Simulation:
113+
medium_type = case["medium_type"]
114+
115+
coords = None
116+
factor = None
117+
if medium_type in ("custom_anisotropic", "custom"):
118+
coords = {
119+
"x": np.linspace(-box_geom.size[0] / 2, box_geom.size[0] / 2, 4),
120+
"y": np.linspace(-box_geom.size[1] / 2, box_geom.size[1] / 2, 5),
121+
"z": np.linspace(-box_geom.size[2] / 2, box_geom.size[2] / 2, 3),
122+
}
123+
_cx, _cy, _cz = np.meshgrid(coords["x"], coords["y"], coords["z"], indexing="ij")
124+
factor = 1 + 0.2 * (_cx + _cy + _cz) / 3.0
125+
126+
if medium_type == "custom_anisotropic":
127+
128+
def _custom_medium(val):
129+
values = factor * val
130+
data = td.SpatialDataArray(values, coords=coords)
131+
return td.CustomMedium(permittivity=data)
132+
133+
medium = td.CustomAnisotropicMedium(
134+
xx=_custom_medium(eps_vals[0]),
135+
yy=_custom_medium(eps_vals[1]),
136+
zz=_custom_medium(eps_vals[2]),
137+
)
138+
elif medium_type == "custom":
139+
140+
def _custom_isotropic(val):
141+
values = factor * val
142+
data = td.SpatialDataArray(values, coords=coords)
143+
return td.CustomMedium(permittivity=data)
144+
145+
medium = _custom_isotropic(eps_vals[0])
146+
elif medium_type == "isotropic":
147+
# use first entry; others are identical by construction
148+
medium = td.Medium(permittivity=eps_vals[0])
149+
elif medium_type == "anisotropic":
150+
medium = td.AnisotropicMedium(
151+
xx=td.Medium(permittivity=eps_vals[0]),
152+
yy=td.Medium(permittivity=eps_vals[1]),
153+
zz=td.Medium(permittivity=eps_vals[2]),
154+
)
155+
else:
156+
raise ValueError(
157+
"Medium type has to be one of 'custom', 'isotropic', 'anisotropic' or 'custom_anisotropic'"
158+
)
159+
160+
structure = td.Structure(geometry=box_geom, medium=medium)
161+
return base_sim.updated_copy(structures=[structure])
162+
163+
164+
def _metric_value(case, dataset, freq0):
165+
if case["objective_kind"] == "flux":
166+
return dataset.flux.values
167+
ex_vals = dataset.Ex.values
168+
ey_vals = dataset.Ey.values
169+
ez_vals = dataset.Ez.values
170+
intensity = np.abs(ex_vals) ** 2 + np.abs(ey_vals) ** 2 + np.abs(ez_vals) ** 2
171+
return anp.real(anp.mean(intensity))
172+
173+
174+
def _angle_deg(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
175+
norm_a = np.linalg.norm(vec_a)
176+
norm_b = np.linalg.norm(vec_b)
177+
if norm_a == 0 or norm_b == 0:
178+
return np.nan
179+
cos_theta = np.clip(np.dot(vec_a, vec_b) / (norm_a * norm_b), -1.0, 1.0)
180+
return float(np.degrees(np.arccos(cos_theta)))
181+
182+
183+
def _run_simulation(
184+
case, base_sim, box_geom, eps_vals, label, tmp_path, monitor_name, freq0, gradient
185+
):
186+
sim = _add_medium(case, base_sim, box_geom, eps_vals)
187+
sim_data = web.run(
188+
sim,
189+
task_name=f"medium_grad_{case['name']}_{label}",
190+
local_gradient=gradient,
191+
verbose=False,
192+
path=str(tmp_path / f"{case['name']}_{label}.hdf5"),
193+
)
194+
return _metric_value(case, sim_data[monitor_name], freq0)
195+
196+
197+
@pytest.mark.numerical
198+
@pytest.mark.parametrize("case", TEST_CASES, ids=lambda c: c["name"])
199+
def test_medium_grads_match_fd(case, numerical_case_dir, tmp_path):
200+
base_sim, monitor_name, freq0 = _build_base_sim(case)
201+
box_geom = _box_geometry(case)
202+
params0 = anp.array(case["permittivities"])
203+
204+
def objective(eps_vals):
205+
return _run_simulation(
206+
case,
207+
base_sim,
208+
box_geom,
209+
eps_vals,
210+
label="adjoint",
211+
tmp_path=tmp_path,
212+
monitor_name=monitor_name,
213+
freq0=freq0,
214+
gradient=True,
215+
)
216+
217+
_, grad_adj = value_and_grad(objective)(params0)
218+
grad_adj = get_static(grad_adj)
219+
220+
fd_sims = {}
221+
base_params = get_static(params0)
222+
for axis in range(3):
223+
delta = np.zeros_like(base_params)
224+
delta[axis] = FD_STEP
225+
fd_sims[f"fd_plus_{axis}"] = _add_medium(case, base_sim, box_geom, base_params + delta)
226+
fd_sims[f"fd_minus_{axis}"] = _add_medium(case, base_sim, box_geom, base_params - delta)
227+
228+
fd_results = web.run_async(
229+
fd_sims,
230+
path_dir=str(numerical_case_dir / f"fd_batch_{case['name']}"),
231+
local_gradient=False,
232+
verbose=False,
233+
)
234+
235+
grad_fd = np.zeros_like(grad_adj)
236+
for axis in range(3):
237+
plus = _metric_value(case, fd_results[f"fd_plus_{axis}"][monitor_name], freq0)
238+
minus = _metric_value(case, fd_results[f"fd_minus_{axis}"][monitor_name], freq0)
239+
grad_fd[axis] = (plus - minus) / (2.0 * FD_STEP)
240+
241+
angle_deg = _angle_deg(grad_adj, grad_fd)
242+
243+
print(
244+
f"[medium-grad-test:{case['name']}] adjoint={grad_adj}, "
245+
f"finite-difference={grad_fd}, angle_deg={angle_deg:.3f}",
246+
file=sys.stderr,
247+
)
248+
249+
angle_tol = case.get("angle_tol_deg", ANGLE_TOL)
250+
assert angle_deg <= angle_tol or np.isnan(angle_deg), (
251+
f"Gradient angle deviation {angle_deg:.3f} deg exceeds tolerance ({angle_tol}). "
252+
f"adj={grad_adj}, fd={grad_fd}"
253+
)
254+
255+
256+
@pytest.mark.skip
257+
@pytest.mark.parametrize("case", TEST_CASES, ids=lambda c: c["name"])
258+
def test_medium_fd_step_sweep(case, numerical_case_dir, tmp_path):
259+
base_sim, monitor_name, freq0 = _build_base_sim(case)
260+
box_geom = _box_geometry(case)
261+
params0 = anp.array(case["permittivities"])
262+
263+
def objective(eps_vals):
264+
return _run_simulation(
265+
case,
266+
base_sim,
267+
box_geom,
268+
eps_vals,
269+
label="adjoint_sweep",
270+
tmp_path=tmp_path,
271+
monitor_name=monitor_name,
272+
freq0=freq0,
273+
gradient=True,
274+
)
275+
276+
_, grad_adj = value_and_grad(objective)(params0)
277+
grad_adj = get_static(grad_adj)
278+
base_params = get_static(params0)
279+
280+
sweep_steps = np.logspace(-4, -1, num=9)
281+
step_labels = [f"{step:.3e}" for step in sweep_steps]
282+
283+
sweep_runs: dict[str, td.Simulation] = {}
284+
for step_label, step in zip(step_labels, sweep_steps):
285+
for axis in range(base_params.size):
286+
delta = np.zeros_like(base_params)
287+
delta[axis] = step
288+
key_base = f"{case['name']}_axis{axis}_{step_label}"
289+
sweep_runs[f"{key_base}_plus"] = _add_medium(
290+
case,
291+
base_sim,
292+
box_geom,
293+
base_params + delta,
294+
)
295+
sweep_runs[f"{key_base}_minus"] = _add_medium(
296+
case,
297+
base_sim,
298+
box_geom,
299+
base_params - delta,
300+
)
301+
302+
sweep_results = web.run_async(
303+
sweep_runs,
304+
path_dir=str(numerical_case_dir / f"fd_sweep_{case['name']}"),
305+
local_gradient=False,
306+
verbose=False,
307+
)
308+
309+
fd_sweep_matrix = np.zeros((len(sweep_steps), base_params.size), dtype=float)
310+
for step_idx, (step_label, step) in enumerate(zip(step_labels, sweep_steps)):
311+
for axis in range(base_params.size):
312+
plus_key = f"{case['name']}_axis{axis}_{step_label}_plus"
313+
minus_key = f"{case['name']}_axis{axis}_{step_label}_minus"
314+
plus_val = _metric_value(case, sweep_results[plus_key][monitor_name], freq0)
315+
minus_val = _metric_value(case, sweep_results[minus_key][monitor_name], freq0)
316+
fd_sweep_matrix[step_idx, axis] = (plus_val - minus_val) / (2.0 * step)
317+
318+
labels = ["xx", "yy", "zz"]
319+
fig, ax = plt.subplots(figsize=(6, 4))
320+
for axis, label in enumerate(labels[: base_params.size]):
321+
ax.plot(sweep_steps, fd_sweep_matrix[:, axis], marker="o", label=f"{label} (FD)")
322+
color = ax.get_lines()[-1].get_color()
323+
ax.axhline(
324+
grad_adj[axis],
325+
color=color,
326+
linestyle="--",
327+
alpha=0.7,
328+
label=f"{label} (autograd)",
329+
)
330+
331+
ax.set_xscale("log")
332+
ax.set_xlabel("Finite difference step")
333+
ax.set_ylabel("Gradient value")
334+
ax.set_title(f"FD gradients vs. step size ({case['name']})")
335+
ax.grid(True, which="both", ls=":")
336+
ax.legend()
337+
338+
fig_path = numerical_case_dir / f"medium_fd_step_sweep_{case['name']}.png"
339+
fig.savefig(fig_path, dpi=200)
340+
plt.close(fig)
341+
342+
# FD gradient extrema per parameter (across all step sizes)
343+
fd_min_per_param = fd_sweep_matrix.min(axis=0)
344+
fd_max_per_param = fd_sweep_matrix.max(axis=0)
345+
346+
print(
347+
(
348+
f"[medium-fd-sweep:{case['name']}] "
349+
f"grad_adj={np.array2string(grad_adj, precision=6, separator=', ')} "
350+
f"fd_grad_per_param[min,max]="
351+
f"{[(f'({mn:.3e},{mx:.3e})') for mn, mx in zip(fd_min_per_param, fd_max_per_param)]}"
352+
),
353+
file=sys.stderr,
354+
)

0 commit comments

Comments
 (0)