Skip to content

Commit 37d5d88

Browse files
committed
Add docstrings to enhance clarity across benchmarks and examples
1 parent c5e186c commit 37d5d88

22 files changed

Lines changed: 285 additions & 0 deletions

benchmarks/compare_mode_solver_fixtures.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""CLI and helpers for comparing committed mode-solver fixtures against MicroMode."""
2+
13
from __future__ import annotations
24

35
import argparse
@@ -28,6 +30,7 @@
2830

2931

3032
def parse_args() -> argparse.Namespace:
33+
"""Parse fixture-comparison CLI options."""
3134
parser = argparse.ArgumentParser(description="Inspect committed mode-solver reference fixtures.")
3235
parser.add_argument(
3336
"--suite",
@@ -72,6 +75,7 @@ def parse_args() -> argparse.Namespace:
7275

7376

7477
def main() -> None:
78+
"""Inspect fixtures and optionally run local comparisons."""
7579
args = parse_args()
7680
fixture_root = args.fixture_root or (DEFAULT_FIXTURE_ROOT / args.suite)
7781
manifest = read_json(manifest_path(fixture_root))
@@ -125,6 +129,7 @@ def main() -> None:
125129

126130

127131
def _compare_local_case(root: Path, entry: dict) -> dict:
132+
"""Run one reconstructable fixture recipe through MicroMode and compare outputs."""
128133
case_id = entry["case_id"]
129134
try:
130135
import micromode as sm
@@ -321,6 +326,7 @@ def _compare_local_case(root: Path, entry: dict) -> dict:
321326
def _solver_edges_from_field_coords(
322327
edges: tuple[np.ndarray, np.ndarray], recipe: dict
323328
) -> tuple[np.ndarray, np.ndarray]:
329+
"""Derive solver edge coordinates from fixture field coordinates."""
324330
dmin_pmc = tuple(bool(value) for value in recipe.get("dmin_pmc", (False, False)))
325331
trim_edges = tuple(recipe.get("trim_edges", ((0, 0), (0, 0))))
326332
out = []
@@ -354,6 +360,7 @@ def _solve_recipe(
354360
normal_dim: str,
355361
normal_coord: float,
356362
):
363+
"""Solve all frequencies described by a local fixture recipe."""
357364
freqs = tuple(float(freq) for freq in ref_n.coords["f"].values)
358365
if recipe.get("solve_each_frequency"):
359366
rows = []
@@ -424,6 +431,7 @@ def _solve_recipe_for_freq(
424431
normal_dim: str,
425432
normal_coord: float,
426433
):
434+
"""Solve one fixture recipe for one frequency or frequency tuple."""
427435
freqs = (float(freq),) if np.isscalar(freq) else tuple(float(value) for value in freq)
428436
centers = tuple((axis_edges[:-1] + axis_edges[1:]) / 2 for axis_edges in edges)
429437
eps_xx, eps_yy, eps_zz = _eps_components_from_recipe(recipe, edges, centers, tangent_dims, freqs[0], sm)
@@ -460,6 +468,7 @@ def _eps_components_from_recipe(
460468
freq: float,
461469
sm,
462470
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
471+
"""Rasterize diagonal epsilon components at Yee sample locations."""
463472
if recipe.get("yee_staggered", True):
464473
return (
465474
_eps_from_recipe(recipe, (coords[0], edges[1][:-1]), tangent_dims, freq, sm),
@@ -477,6 +486,7 @@ def _eps_from_recipe(
477486
freq: float,
478487
sm,
479488
) -> np.ndarray:
489+
"""Rasterize primitive boxes and circles onto one component grid."""
480490
grids = np.meshgrid(*coords, indexing="ij")
481491
eps = np.full(tuple(len(coord) for coord in coords), recipe.get("clad_eps", 1.0), dtype=np.complex128)
482492
for box in recipe.get("boxes", ()):
@@ -502,24 +512,28 @@ def _eps_from_recipe(
502512

503513

504514
def _reorder_modes(values: np.ndarray, recipe: dict) -> np.ndarray:
515+
"""Apply fixture-specific mode ordering before comparison."""
505516
if recipe.get("sort_order") != "ascending":
506517
return values
507518
order = np.argsort(values.real, axis=1)
508519
return np.take_along_axis(values, order, axis=1)
509520

510521

511522
def _reorder_field_modes(values: np.ndarray, recipe: dict) -> np.ndarray:
523+
"""Apply fixture-specific field ordering for diagnostic overlap checks."""
512524
if recipe.get("sort_order") != "ascending":
513525
return values
514526
# Field reordering is only used for a coarse overlap diagnostic; n sorting is authoritative.
515527
return values[..., ::-1]
516528

517529

518530
def _status(status: str, summary: str, **details) -> dict:
531+
"""Build a normalized status dictionary for report output."""
519532
return {"status": status, "failed": status == "fail", "summary": summary, **details}
520533

521534

522535
def _n_tolerance(entry: dict, recipe: dict | None = None) -> float:
536+
"""Resolve the effective-index tolerance for one fixture case."""
523537
if recipe is not None:
524538
recipe_tolerance = recipe.get("n_tolerance")
525539
if recipe_tolerance is not None:

benchmarks/compare_tidy3d_backends.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Benchmark MicroMode against equivalent Tidy3D mode-solver setups."""
2+
13
from __future__ import annotations
24

35
import argparse
@@ -13,6 +15,8 @@
1315

1416
@dataclass(frozen=True)
1517
class BenchmarkCase:
18+
"""Configuration for one backend comparison problem."""
19+
1620
case_id: str
1721
description: str
1822
ny: int
@@ -53,6 +57,7 @@ class BenchmarkCase:
5357

5458

5559
def main() -> None:
60+
"""Run selected backend-comparison cases and print a markdown table."""
5661
args = parse_args()
5762
cases = list(PRESETS[args.preset])
5863
rows = [run_case(case, profile_source=args.profile_source) for case in cases]
@@ -64,6 +69,7 @@ def main() -> None:
6469

6570

6671
def parse_args() -> argparse.Namespace:
72+
"""Parse backend benchmark CLI options."""
6773
parser = argparse.ArgumentParser(description="Compare MicroMode SciPy and Tidy3D local solves.")
6874
parser.add_argument("--preset", choices=tuple(PRESETS), default="quick")
6975
parser.add_argument(
@@ -77,6 +83,7 @@ def parse_args() -> argparse.Namespace:
7783

7884

7985
def run_case(case: BenchmarkCase, *, profile_source: str) -> dict[str, object]:
86+
"""Execute one benchmark case for MicroMode and Tidy3D."""
8087
tidy3d_solver = make_tidy3d_solver(case)
8188
materials = (
8289
micromode_materials_from_tidy3d_solver(tidy3d_solver, case)
@@ -111,6 +118,7 @@ def run_case(case: BenchmarkCase, *, profile_source: str) -> dict[str, object]:
111118

112119

113120
def time_micromode(case: BenchmarkCase, materials: mm.Materials) -> tuple[float, np.ndarray]:
121+
"""Time the MicroMode solve for a prepared material grid."""
114122
start = time.perf_counter()
115123
data = mm.solve_modes(
116124
material_grid=materials,
@@ -124,12 +132,14 @@ def time_micromode(case: BenchmarkCase, materials: mm.Materials) -> tuple[float,
124132

125133

126134
def time_tidy3d(solver) -> tuple[float, np.ndarray]:
135+
"""Time a Tidy3D mode solve."""
127136
start = time.perf_counter()
128137
data = solver.solve()
129138
return time.perf_counter() - start, np.asarray(data.n_eff.values[0], dtype=float)
130139

131140

132141
def make_tidy3d_solver(case: BenchmarkCase):
142+
"""Construct a Tidy3D mode solver for one benchmark case."""
133143
try:
134144
import tidy3d as td
135145
from tidy3d.plugins.mode import ModeSolver
@@ -157,6 +167,7 @@ def make_tidy3d_solver(case: BenchmarkCase):
157167

158168

159169
def micromode_materials_from_tidy3d_solver(solver, case: BenchmarkCase) -> mm.Materials:
170+
"""Rasterize Tidy3D solver materials into a MicroMode grid."""
160171
eps = np.asarray(solver._solver_eps(tidy3d_frequency()), dtype=np.complex128)
161172
grid = solver._solver_grid
162173
return mm.Materials.from_components(
@@ -176,12 +187,14 @@ def micromode_materials_from_tidy3d_solver(solver, case: BenchmarkCase) -> mm.Ma
176187

177188

178189
def tidy3d_frequency() -> float:
190+
"""Return the benchmark frequency in Hz."""
179191
import tidy3d as td
180192

181193
return float(td.C_0 / WAVELENGTH_UM)
182194

183195

184196
def micromode_materials(case: BenchmarkCase) -> mm.Materials:
197+
"""Build a direct MicroMode material grid for one benchmark case."""
185198
y_edges = np.linspace(-0.5 * WIDTH_Y, 0.5 * WIDTH_Y, case.ny + 1)
186199
z_edges = np.linspace(-0.5 * WIDTH_Z, 0.5 * WIDTH_Z, case.nz + 1)
187200
y = 0.5 * (y_edges[:-1] + y_edges[1:])
@@ -200,6 +213,7 @@ def micromode_materials(case: BenchmarkCase) -> mm.Materials:
200213

201214

202215
def tidy3d_structures(td, problem: str):
216+
"""Return Tidy3D geometry structures for one benchmark problem."""
203217
structures = [
204218
td.Structure(
205219
geometry=td.Box(center=(0.0, 0.0, -0.5001), size=(td.inf, td.inf, 1.0002)),
@@ -227,13 +241,15 @@ def tidy3d_structures(td, problem: str):
227241

228242

229243
def max_abs_delta(left: np.ndarray, right: np.ndarray) -> float:
244+
"""Return the maximum absolute difference between sorted mode arrays."""
230245
count = min(left.size, right.size)
231246
if count == 0:
232247
return float("nan")
233248
return float(np.max(np.abs(left[:count] - right[:count])))
234249

235250

236251
def markdown_table(rows: list[dict[str, object]]) -> str:
252+
"""Format benchmark rows as a markdown table."""
237253
header = (
238254
"| Problem | Grid | MicroMode SciPy (s) | Tidy3D local (s) | max abs Δn_eff SciPy/Tidy3D |\n"
239255
"|---|---:|---:|---:|---:|"

benchmarks/micromode_solver_benchmark.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Time MicroMode solves and record sparse-operator diagnostics."""
2+
13
from __future__ import annotations
24

35
import argparse
@@ -11,6 +13,7 @@
1113

1214

1315
def main() -> None:
16+
"""Run timing cases and write a JSON benchmark report."""
1417
args = parse_args()
1518
rows = []
1619
grids = args.grid or ["20x14", "32x22", "48x32"]
@@ -55,6 +58,7 @@ def main() -> None:
5558

5659

5760
def parse_args() -> argparse.Namespace:
61+
"""Parse solver benchmark CLI options."""
5862
parser = argparse.ArgumentParser(description="Benchmark MicroMode sparse solves over grid sizes.")
5963
parser.add_argument(
6064
"--grid",
@@ -73,6 +77,7 @@ def parse_args() -> argparse.Namespace:
7377

7478

7579
def parse_grid_sizes(values: list[str]) -> list[tuple[int, int]]:
80+
"""Parse grid-size strings like 20x14 into integer pairs."""
7681
sizes = []
7782
for value in values:
7883
left, sep, right = value.lower().partition("x")
@@ -83,6 +88,7 @@ def parse_grid_sizes(values: list[str]) -> list[tuple[int, int]]:
8388

8489

8590
def strip_materials(*, nx: int, ny: int) -> mm.Materials:
91+
"""Build a simple strip-waveguide material grid for timing."""
8692
x_edges = np.linspace(-1.2, 1.2, nx + 1)
8793
y_edges = np.linspace(-0.8, 0.8, ny + 1)
8894
x = 0.5 * (x_edges[:-1] + x_edges[1:])

benchmarks/mode_solver/fixtures.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Shared helpers for reading committed mode-solver reference fixtures."""
2+
13
from __future__ import annotations
24

35
import hashlib
@@ -15,22 +17,27 @@
1517

1618

1719
def case_dir(root: Path, case_id: str) -> Path:
20+
"""Return the directory containing one fixture case."""
1821
return root / case_id
1922

2023

2124
def data_path(root: Path, case_id: str) -> Path:
25+
"""Return the HDF5 path for one fixture case."""
2226
return case_dir(root, case_id) / "mode_data.hdf5"
2327

2428

2529
def summary_path(root: Path, case_id: str) -> Path:
30+
"""Return the JSON summary path for one fixture case."""
2631
return case_dir(root, case_id) / "summary.json"
2732

2833

2934
def manifest_path(root: Path) -> Path:
35+
"""Return the manifest path for a fixture suite."""
3036
return root / "manifest.json"
3137

3238

3339
def sha256_file(path: Path) -> str:
40+
"""Return the SHA-256 digest for a file."""
3441
digest = hashlib.sha256()
3542
with path.open("rb") as f:
3643
for chunk in iter(lambda: f.read(1024 * 1024), b""):
@@ -39,14 +46,17 @@ def sha256_file(path: Path) -> str:
3946

4047

4148
def read_json(path: Path) -> dict[str, Any]:
49+
"""Read a JSON file into a dictionary."""
4250
return json.loads(path.read_text())
4351

4452

4553
def write_json(path: Path, payload: dict[str, Any]) -> None:
54+
"""Write a dictionary as stable, pretty JSON."""
4655
path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n")
4756

4857

4958
def iter_manifest_entries(root: Path) -> tuple[dict[str, Any], ...]:
59+
"""Return all case entries from a suite manifest."""
5060
return tuple(read_json(manifest_path(root))["cases"])
5161

5262

@@ -65,6 +75,7 @@ def load_data_array(path: Path, name: str) -> xr.DataArray:
6575

6676

6777
def _read_xarray_group(group: Any) -> xr.DataArray:
78+
"""Read a legacy xarray-style HDF5 group."""
6879
if _XARRAY_VALUE_NAME not in group:
6980
raise KeyError(f"HDF5 group {group.name!r} is missing {_XARRAY_VALUE_NAME!r}")
7081
values = group[_XARRAY_VALUE_NAME][()]
@@ -75,6 +86,7 @@ def _read_xarray_group(group: Any) -> xr.DataArray:
7586

7687

7788
def _infer_dims(shape: tuple[int, ...], coords: dict[str, np.ndarray]) -> tuple[str, ...]:
89+
"""Infer dimension names from HDF5 coordinate datasets."""
7890
dims = tuple(dim for dim in _PREFERRED_DIMS if dim in coords)
7991
if len(dims) == len(shape) and all(len(coords[dim]) == size for dim, size in zip(dims, shape, strict=True)):
8092
return dims
@@ -86,6 +98,7 @@ def _infer_dims(shape: tuple[int, ...], coords: dict[str, np.ndarray]) -> tuple[
8698

8799

88100
def phase_aligned_relative_error(golden: np.ndarray, actual: np.ndarray) -> tuple[float, float]:
101+
"""Compare complex arrays after removing a global phase offset."""
89102
g = np.asarray(golden).reshape(-1)
90103
a = np.asarray(actual).reshape(-1)
91104
norm_g = float(np.linalg.norm(g))

0 commit comments

Comments
 (0)