Skip to content

Commit 5f6de2c

Browse files
committed
improving zonal fast, adding robustness testing
1 parent b4b5bc7 commit 5f6de2c

13 files changed

Lines changed: 566 additions & 32 deletions

README.md

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,27 @@ kl_gen.save("my_kl_basis.npz")
110110

111111
## Zonal Fast Basis
112112

113-
`ZonalFastBasisGenerator` groups actuators into binary poke patterns such that no two actuators in the same mode are closer than a user-defined distance `D`. This is useful when you want a compact calibration basis that reduces the number of measurements compared with pure zonal pokes.
113+
`ZonalFastBasisGenerator` groups actuators into binary poke patterns such that no two actuators in the same mode are closer than a user-defined distance `D`. This is useful when you want a compact calibration basis that reduces the number of measurements compared with pure zonal pokes. For square-grid actuator layouts it uses a modulo lattice grouping directly, and for exotic layouts it falls back to a greedy graph-coloring approach.
114114

115115
```python
116-
from aobasis import ZonalFastBasisGenerator, make_circular_actuator_grid
116+
import numpy as np
117117

118-
positions = make_circular_actuator_grid(telescope_diameter=10.0, grid_size=20)
119-
120-
# Distance threshold in the same units as the actuator coordinates.
121-
zonal_fast_gen = ZonalFastBasisGenerator(positions, min_distance=0.8)
118+
from aobasis import ZonalFastBasisGenerator, make_circular_actuator_grid, make_concentric_actuator_grid
122119

123-
# Omitting n_modes returns the full grouped basis.
124-
zonal_fast_modes = zonal_fast_gen.generate()
125-
print(zonal_fast_modes.shape)
126-
127-
zonal_fast_gen.plot(count=min(12, zonal_fast_modes.shape[1]), title_prefix="Zonal Fast")
120+
# Example 1: grid-like actuator positions clipped by a circular pupil.
121+
positions = make_circular_actuator_grid(telescope_diameter=10.0, grid_size=20)
122+
grid_gen = ZonalFastBasisGenerator(positions, min_distance=0.8)
123+
grid_modes = grid_gen.generate()
124+
print("Grid layout:", grid_modes.shape)
125+
grid_gen.plot(count=min(12, grid_modes.shape[1]), title_prefix="Zonal Fast Grid")
126+
127+
# Example 2: non-grid actuator positions.
128+
exotic_positions = make_concentric_actuator_grid(telescope_diameter=10.0, n_rings=5)
129+
exotic_positions = exotic_positions + 0.03 * np.sin(exotic_positions)
130+
exotic_gen = ZonalFastBasisGenerator(exotic_positions, min_distance=1.0)
131+
exotic_modes = exotic_gen.generate()
132+
print("Exotic layout:", exotic_modes.shape)
133+
exotic_gen.plot(count=min(12, exotic_modes.shape[1]), title_prefix="Zonal Fast Exotic")
128134
```
129135

130136
The returned matrix still has the standard `(n_actuators, n_modes)` layout, but each column is now a sparse binary pattern rather than a single-actuator poke. Every actuator appears in exactly one column of the full basis.

compare_zonal_fast.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import argparse
2+
from typing import Iterable, List, Sequence, Tuple
3+
4+
import numpy as np
5+
6+
from aobasis import ZonalFastBasisGenerator
7+
8+
9+
def make_integer_circular_grid(grid_size: int, radius: float) -> Tuple[np.ndarray, np.ndarray]:
10+
"""Build a unit-pitch square lattice clipped by a circular aperture."""
11+
axis = np.arange(grid_size, dtype=float) - 0.5 * (grid_size - 1)
12+
xx, yy = np.meshgrid(axis, axis, indexing="xy")
13+
full_positions = np.column_stack((xx.ravel(), yy.ravel()))
14+
full_indices = np.column_stack(np.unravel_index(np.arange(grid_size * grid_size), (grid_size, grid_size)))
15+
16+
mask = np.sum(full_positions**2, axis=1) <= radius**2 + 1e-12
17+
return full_positions[mask], full_indices[mask]
18+
19+
20+
def build_corner_subgrid_basis(grid_indices: np.ndarray, spacing: int) -> np.ndarray:
21+
"""Group actuators by their row and column residue classes modulo spacing."""
22+
if spacing <= 0:
23+
raise ValueError("spacing must be a positive integer.")
24+
if grid_indices.ndim != 2 or grid_indices.shape[1] != 2:
25+
raise ValueError("grid_indices must have shape (n_actuators, 2).")
26+
27+
residues = np.mod(grid_indices, spacing)
28+
groups = {}
29+
for actuator_index, residue in enumerate(residues):
30+
key = (int(residue[0]), int(residue[1]))
31+
groups.setdefault(key, []).append(actuator_index)
32+
33+
ordered_groups = [groups[key] for key in sorted(groups)]
34+
basis = np.zeros((grid_indices.shape[0], len(ordered_groups)), dtype=float)
35+
for mode_index, actuator_group in enumerate(ordered_groups):
36+
basis[actuator_group, mode_index] = 1.0
37+
return basis
38+
39+
40+
def minimum_pairwise_distance(positions: np.ndarray, active_indices: np.ndarray) -> float:
41+
if active_indices.size < 2:
42+
return np.inf
43+
44+
active_positions = positions[active_indices]
45+
deltas = active_positions[:, None, :] - active_positions[None, :, :]
46+
distances = np.linalg.norm(deltas, axis=-1)
47+
upper_triangle = distances[np.triu_indices(active_positions.shape[0], k=1)]
48+
return float(upper_triangle.min())
49+
50+
51+
def validate_basis(positions: np.ndarray, basis: np.ndarray, min_distance: float) -> Tuple[bool, float]:
52+
if basis.shape[0] != positions.shape[0]:
53+
raise ValueError("basis row count must match the number of actuator positions.")
54+
55+
covered = np.allclose(basis.sum(axis=1), 1.0)
56+
mode_min_distances: List[float] = []
57+
58+
for mode_index in range(basis.shape[1]):
59+
active_indices = np.flatnonzero(basis[:, mode_index] > 0.5)
60+
mode_min_distances.append(minimum_pairwise_distance(positions, active_indices))
61+
62+
worst_case_distance = min(mode_min_distances) if mode_min_distances else np.inf
63+
spacing_ok = worst_case_distance >= min_distance - 1e-12
64+
return covered and spacing_ok, worst_case_distance
65+
66+
67+
def compare_for_spacing(positions: np.ndarray, grid_indices: np.ndarray, spacing: int) -> dict:
68+
naive_basis = build_corner_subgrid_basis(grid_indices, spacing)
69+
fast_basis = ZonalFastBasisGenerator(positions, min_distance=float(spacing)).generate()
70+
71+
naive_valid, naive_min_distance = validate_basis(positions, naive_basis, float(spacing))
72+
fast_valid, fast_min_distance = validate_basis(positions, fast_basis, float(spacing))
73+
74+
naive_modes = int(naive_basis.shape[1])
75+
fast_modes = int(fast_basis.shape[1])
76+
reduction = naive_modes - fast_modes
77+
reduction_fraction = reduction / naive_modes if naive_modes else 0.0
78+
79+
return {
80+
"spacing": spacing,
81+
"n_actuators": int(positions.shape[0]),
82+
"naive_modes": naive_modes,
83+
"fast_modes": fast_modes,
84+
"reduction": reduction,
85+
"reduction_fraction": reduction_fraction,
86+
"naive_valid": naive_valid,
87+
"fast_valid": fast_valid,
88+
"naive_min_distance": naive_min_distance,
89+
"fast_min_distance": fast_min_distance,
90+
}
91+
92+
93+
def parse_distances(values: Sequence[int]) -> List[int]:
94+
unique_values = sorted({int(value) for value in values})
95+
if not unique_values:
96+
raise ValueError("At least one spacing value must be provided.")
97+
if any(value <= 0 for value in unique_values):
98+
raise ValueError("Spacing values must all be positive integers.")
99+
return unique_values
100+
101+
102+
def print_report(results: Iterable[dict]) -> None:
103+
header = (
104+
f"{'D':>4} {'Actuators':>10} {'Naive':>8} {'Fast':>8} {'Saved':>8} {'Saved %':>9} "
105+
f"{'Naive OK':>9} {'Fast OK':>8} {'Naive min d':>12} {'Fast min d':>11}"
106+
)
107+
print(header)
108+
print("-" * len(header))
109+
110+
for result in results:
111+
print(
112+
f"{result['spacing']:>4d} "
113+
f"{result['n_actuators']:>10d} "
114+
f"{result['naive_modes']:>8d} "
115+
f"{result['fast_modes']:>8d} "
116+
f"{result['reduction']:>8d} "
117+
f"{100.0 * result['reduction_fraction']:>8.2f}% "
118+
f"{str(result['naive_valid']):>9} "
119+
f"{str(result['fast_valid']):>8} "
120+
f"{result['naive_min_distance']:>12.3f} "
121+
f"{result['fast_min_distance']:>11.3f}"
122+
)
123+
124+
125+
def main() -> None:
126+
parser = argparse.ArgumentParser(
127+
description=(
128+
"Compare a naive corner-anchored D x D subgrid zonal-fast basis against "
129+
"the graph-coloring zonal-fast basis on a circular aperture."
130+
)
131+
)
132+
parser.add_argument("--grid-size", type=int, default=60, help="Number of points along each side of the square lattice.")
133+
parser.add_argument("--radius", type=float, default=30.0, help="Circular aperture radius in lattice-pitch units.")
134+
parser.add_argument(
135+
"--distances",
136+
type=int,
137+
nargs="+",
138+
default=[2, 3, 4, 5, 6, 8, 10],
139+
help="Integer spacing values D to test, in lattice-pitch units.",
140+
)
141+
parser.add_argument(
142+
"--fail-if-not-better",
143+
action="store_true",
144+
help="Exit with status 1 if the zonal-fast basis is not strictly better for every tested spacing.",
145+
)
146+
args = parser.parse_args()
147+
148+
distances = parse_distances(args.distances)
149+
positions, grid_indices = make_integer_circular_grid(args.grid_size, args.radius)
150+
results = [compare_for_spacing(positions, grid_indices, spacing) for spacing in distances]
151+
152+
print(
153+
f"Circular aperture on a {args.grid_size}x{args.grid_size} unit-pitch lattice, "
154+
f"radius={args.radius:.1f}, actuators inside pupil={positions.shape[0]}"
155+
)
156+
print_report(results)
157+
158+
wins = [result for result in results if result["fast_modes"] < result["naive_modes"]]
159+
ties = [result for result in results if result["fast_modes"] == result["naive_modes"]]
160+
losses = [result for result in results if result["fast_modes"] > result["naive_modes"]]
161+
162+
print()
163+
print(f"Fast basis uses fewer modes for {len(wins)} / {len(results)} tested spacings.")
164+
if ties:
165+
tied_spacings = ", ".join(str(result["spacing"]) for result in ties)
166+
print(f"Tied spacings: {tied_spacings}")
167+
if losses:
168+
loss_spacings = ", ".join(str(result["spacing"]) for result in losses)
169+
print(f"Fast basis used more modes at: {loss_spacings}")
170+
171+
if not losses and not ties:
172+
print("Verdict: zonal-fast is strictly better than the naive corner-subgrid basis for every tested spacing.")
173+
elif losses:
174+
print("Verdict: this zonal-fast implementation does not beat the naive corner-subgrid baseline on this geometry.")
175+
print("The graph coloring is producing a valid basis, but not a minimal one for these spacings.")
176+
else:
177+
print("Verdict: zonal-fast matches the naive construction for some spacings and never improves on it in this run.")
178+
179+
if args.fail_if_not_better and (losses or ties):
180+
raise SystemExit(1)
181+
182+
183+
if __name__ == "__main__":
184+
main()

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "aobasis"
7-
version = "1.0.1"
7+
version = "1.0.2"
88
description = "A package for generating AO basis sets (KL, Zernike, Fourier)"
99
readme = "README.md"
10-
authors = [{ name = "Jacob Taylor", email = "jtaylor@keck.hawaii.edu" }]
10+
authors = [{ name = "Jacob Taylor", email = "jacobataylor7@gmail.com" }]
1111
license = { file = "LICENSE" }
1212
classifiers = [
1313
"Programming Language :: Python :: 3",

src/aobasis/base.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,20 @@
44
from typing import Tuple, Optional, Union
55
from .utils import plot_basis_modes
66

7+
8+
def _validate_positions_array(positions: np.ndarray) -> np.ndarray:
9+
try:
10+
array = np.asarray(positions, dtype=float)
11+
except (TypeError, ValueError) as exc:
12+
raise ValueError("positions must be a finite numeric array with shape (n_actuators, 2).") from exc
13+
14+
if array.ndim != 2 or array.shape[1] != 2:
15+
raise ValueError("positions must have shape (n_actuators, 2).")
16+
if not np.all(np.isfinite(array)):
17+
raise ValueError("positions must contain only finite values.")
18+
19+
return array
20+
721
class BasisGenerator(ABC):
822
"""
923
Abstract base class for AO basis generators.
@@ -14,9 +28,21 @@ def __init__(self, positions: np.ndarray):
1428
Args:
1529
positions: (N, 2) array of actuator coordinates (x, y) in meters.
1630
"""
17-
self.positions = np.array(positions)
31+
self.positions = _validate_positions_array(positions)
1832
self.n_actuators = self.positions.shape[0]
1933
self.modes: Optional[np.ndarray] = None
34+
35+
def _validate_n_modes(self, n_modes: int, max_modes: Optional[int] = None) -> int:
36+
if isinstance(n_modes, bool) or not isinstance(n_modes, (int, np.integer)):
37+
raise ValueError("n_modes must be an integer.")
38+
39+
n_modes = int(n_modes)
40+
if n_modes < 0:
41+
raise ValueError("n_modes must be non-negative.")
42+
if max_modes is not None and n_modes > max_modes:
43+
raise ValueError(f"Cannot generate {n_modes} modes; maximum available is {max_modes}.")
44+
45+
return n_modes
2046

2147
@abstractmethod
2248
def generate(self, n_modes: int, **kwargs) -> np.ndarray:
@@ -73,4 +99,5 @@ class ConcreteBasis(BasisGenerator):
7399
def generate(self, n_modes: int, **kwargs) -> np.ndarray:
74100
if self.modes is None:
75101
raise NotImplementedError("This is a loaded basis container.")
102+
n_modes = self._validate_n_modes(n_modes, max_modes=self.modes.shape[1])
76103
return self.modes[:, :n_modes]

src/aobasis/fourier.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,20 @@ class FourierBasisGenerator(BasisGenerator):
88

99
def __init__(self, positions: np.ndarray, pupil_diameter: float):
1010
super().__init__(positions)
11+
if not np.isscalar(pupil_diameter) or not np.isfinite(pupil_diameter) or pupil_diameter <= 0:
12+
raise ValueError("pupil_diameter must be a positive finite scalar.")
1113
self.pupil_diameter = pupil_diameter
1214

1315
def generate(self, n_modes: int, ignore_piston: bool = False, **kwargs) -> np.ndarray:
1416
"""
1517
Generate Fourier modes.
1618
We generate pairs of sin/cos for increasing spatial frequencies.
1719
"""
20+
n_modes = self._validate_n_modes(n_modes)
21+
if n_modes == 0:
22+
self.modes = np.zeros((self.n_actuators, 0), dtype=float)
23+
return self.modes
24+
1825
x = self.positions[:, 0]
1926
y = self.positions[:, 1]
2027

src/aobasis/hadamard.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ def generate(self, n_modes: int, **kwargs) -> np.ndarray:
1616
we find the next power of 2 >= n_actuators, generate the Hadamard matrix,
1717
and truncate it to the number of actuators (rows) and requested modes (columns).
1818
"""
19+
n_modes = self._validate_n_modes(n_modes)
20+
if n_modes == 0:
21+
self.modes = np.zeros((self.n_actuators, 0), dtype=int)
22+
return self.modes
23+
1924
# Find next power of 2 covering the number of actuators
2025
# We need at least n_actuators rows to define the pattern on the grid
2126
# And we need enough columns for n_modes

src/aobasis/kl.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ class KLBasisGenerator(BasisGenerator):
9999

100100
def __init__(self, positions: np.ndarray, fried_parameter: float = 0.16, outer_scale: float = 30.0, use_gpu: bool = False):
101101
super().__init__(positions)
102+
if not np.isscalar(fried_parameter) or not np.isfinite(fried_parameter) or fried_parameter <= 0:
103+
raise ValueError("fried_parameter must be a positive finite scalar.")
104+
if not np.isscalar(outer_scale) or not np.isfinite(outer_scale) or outer_scale <= 0:
105+
raise ValueError("outer_scale must be a positive finite scalar.")
102106
self.fried_parameter = fried_parameter
103107
self.outer_scale = outer_scale
104108
self.eigenvalues = None
@@ -177,6 +181,14 @@ def _von_karman_covariance_gpu(self):
177181
return cov
178182

179183
def generate(self, n_modes: int, ignore_piston: bool = False, **kwargs) -> np.ndarray:
184+
max_modes = self.n_actuators - (1 if ignore_piston else 0)
185+
n_modes = self._validate_n_modes(n_modes, max_modes=max_modes)
186+
187+
if n_modes == 0:
188+
self.eigenvalues = np.array([], dtype=float)
189+
self.modes = np.zeros((self.n_actuators, 0), dtype=float)
190+
return self.modes
191+
180192
cov = self._von_karman_covariance()
181193

182194
if self.use_gpu:

0 commit comments

Comments
 (0)