1+ import argparse
2+ from typing import Iterable , List , Sequence , Tuple
3+
4+ import numpy as np
5+
6+ from aobasis import ZonalFastBasisGenerator
7+
8+
9+ def make_integer_circular_grid (grid_size : int , radius : float ) -> Tuple [np .ndarray , np .ndarray ]:
10+ """Build a unit-pitch square lattice clipped by a circular aperture."""
11+ axis = np .arange (grid_size , dtype = float ) - 0.5 * (grid_size - 1 )
12+ xx , yy = np .meshgrid (axis , axis , indexing = "xy" )
13+ full_positions = np .column_stack ((xx .ravel (), yy .ravel ()))
14+ full_indices = np .column_stack (np .unravel_index (np .arange (grid_size * grid_size ), (grid_size , grid_size )))
15+
16+ mask = np .sum (full_positions ** 2 , axis = 1 ) <= radius ** 2 + 1e-12
17+ return full_positions [mask ], full_indices [mask ]
18+
19+
20+ def build_corner_subgrid_basis (grid_indices : np .ndarray , spacing : int ) -> np .ndarray :
21+ """Group actuators by their row and column residue classes modulo spacing."""
22+ if spacing <= 0 :
23+ raise ValueError ("spacing must be a positive integer." )
24+ if grid_indices .ndim != 2 or grid_indices .shape [1 ] != 2 :
25+ raise ValueError ("grid_indices must have shape (n_actuators, 2)." )
26+
27+ residues = np .mod (grid_indices , spacing )
28+ groups = {}
29+ for actuator_index , residue in enumerate (residues ):
30+ key = (int (residue [0 ]), int (residue [1 ]))
31+ groups .setdefault (key , []).append (actuator_index )
32+
33+ ordered_groups = [groups [key ] for key in sorted (groups )]
34+ basis = np .zeros ((grid_indices .shape [0 ], len (ordered_groups )), dtype = float )
35+ for mode_index , actuator_group in enumerate (ordered_groups ):
36+ basis [actuator_group , mode_index ] = 1.0
37+ return basis
38+
39+
40+ def minimum_pairwise_distance (positions : np .ndarray , active_indices : np .ndarray ) -> float :
41+ if active_indices .size < 2 :
42+ return np .inf
43+
44+ active_positions = positions [active_indices ]
45+ deltas = active_positions [:, None , :] - active_positions [None , :, :]
46+ distances = np .linalg .norm (deltas , axis = - 1 )
47+ upper_triangle = distances [np .triu_indices (active_positions .shape [0 ], k = 1 )]
48+ return float (upper_triangle .min ())
49+
50+
51+ def validate_basis (positions : np .ndarray , basis : np .ndarray , min_distance : float ) -> Tuple [bool , float ]:
52+ if basis .shape [0 ] != positions .shape [0 ]:
53+ raise ValueError ("basis row count must match the number of actuator positions." )
54+
55+ covered = np .allclose (basis .sum (axis = 1 ), 1.0 )
56+ mode_min_distances : List [float ] = []
57+
58+ for mode_index in range (basis .shape [1 ]):
59+ active_indices = np .flatnonzero (basis [:, mode_index ] > 0.5 )
60+ mode_min_distances .append (minimum_pairwise_distance (positions , active_indices ))
61+
62+ worst_case_distance = min (mode_min_distances ) if mode_min_distances else np .inf
63+ spacing_ok = worst_case_distance >= min_distance - 1e-12
64+ return covered and spacing_ok , worst_case_distance
65+
66+
67+ def compare_for_spacing (positions : np .ndarray , grid_indices : np .ndarray , spacing : int ) -> dict :
68+ naive_basis = build_corner_subgrid_basis (grid_indices , spacing )
69+ fast_basis = ZonalFastBasisGenerator (positions , min_distance = float (spacing )).generate ()
70+
71+ naive_valid , naive_min_distance = validate_basis (positions , naive_basis , float (spacing ))
72+ fast_valid , fast_min_distance = validate_basis (positions , fast_basis , float (spacing ))
73+
74+ naive_modes = int (naive_basis .shape [1 ])
75+ fast_modes = int (fast_basis .shape [1 ])
76+ reduction = naive_modes - fast_modes
77+ reduction_fraction = reduction / naive_modes if naive_modes else 0.0
78+
79+ return {
80+ "spacing" : spacing ,
81+ "n_actuators" : int (positions .shape [0 ]),
82+ "naive_modes" : naive_modes ,
83+ "fast_modes" : fast_modes ,
84+ "reduction" : reduction ,
85+ "reduction_fraction" : reduction_fraction ,
86+ "naive_valid" : naive_valid ,
87+ "fast_valid" : fast_valid ,
88+ "naive_min_distance" : naive_min_distance ,
89+ "fast_min_distance" : fast_min_distance ,
90+ }
91+
92+
93+ def parse_distances (values : Sequence [int ]) -> List [int ]:
94+ unique_values = sorted ({int (value ) for value in values })
95+ if not unique_values :
96+ raise ValueError ("At least one spacing value must be provided." )
97+ if any (value <= 0 for value in unique_values ):
98+ raise ValueError ("Spacing values must all be positive integers." )
99+ return unique_values
100+
101+
102+ def print_report (results : Iterable [dict ]) -> None :
103+ header = (
104+ f"{ 'D' :>4} { 'Actuators' :>10} { 'Naive' :>8} { 'Fast' :>8} { 'Saved' :>8} { 'Saved %' :>9} "
105+ f"{ 'Naive OK' :>9} { 'Fast OK' :>8} { 'Naive min d' :>12} { 'Fast min d' :>11} "
106+ )
107+ print (header )
108+ print ("-" * len (header ))
109+
110+ for result in results :
111+ print (
112+ f"{ result ['spacing' ]:>4d} "
113+ f"{ result ['n_actuators' ]:>10d} "
114+ f"{ result ['naive_modes' ]:>8d} "
115+ f"{ result ['fast_modes' ]:>8d} "
116+ f"{ result ['reduction' ]:>8d} "
117+ f"{ 100.0 * result ['reduction_fraction' ]:>8.2f} % "
118+ f"{ str (result ['naive_valid' ]):>9} "
119+ f"{ str (result ['fast_valid' ]):>8} "
120+ f"{ result ['naive_min_distance' ]:>12.3f} "
121+ f"{ result ['fast_min_distance' ]:>11.3f} "
122+ )
123+
124+
125+ def main () -> None :
126+ parser = argparse .ArgumentParser (
127+ description = (
128+ "Compare a naive corner-anchored D x D subgrid zonal-fast basis against "
129+ "the graph-coloring zonal-fast basis on a circular aperture."
130+ )
131+ )
132+ parser .add_argument ("--grid-size" , type = int , default = 60 , help = "Number of points along each side of the square lattice." )
133+ parser .add_argument ("--radius" , type = float , default = 30.0 , help = "Circular aperture radius in lattice-pitch units." )
134+ parser .add_argument (
135+ "--distances" ,
136+ type = int ,
137+ nargs = "+" ,
138+ default = [2 , 3 , 4 , 5 , 6 , 8 , 10 ],
139+ help = "Integer spacing values D to test, in lattice-pitch units." ,
140+ )
141+ parser .add_argument (
142+ "--fail-if-not-better" ,
143+ action = "store_true" ,
144+ help = "Exit with status 1 if the zonal-fast basis is not strictly better for every tested spacing." ,
145+ )
146+ args = parser .parse_args ()
147+
148+ distances = parse_distances (args .distances )
149+ positions , grid_indices = make_integer_circular_grid (args .grid_size , args .radius )
150+ results = [compare_for_spacing (positions , grid_indices , spacing ) for spacing in distances ]
151+
152+ print (
153+ f"Circular aperture on a { args .grid_size } x{ args .grid_size } unit-pitch lattice, "
154+ f"radius={ args .radius :.1f} , actuators inside pupil={ positions .shape [0 ]} "
155+ )
156+ print_report (results )
157+
158+ wins = [result for result in results if result ["fast_modes" ] < result ["naive_modes" ]]
159+ ties = [result for result in results if result ["fast_modes" ] == result ["naive_modes" ]]
160+ losses = [result for result in results if result ["fast_modes" ] > result ["naive_modes" ]]
161+
162+ print ()
163+ print (f"Fast basis uses fewer modes for { len (wins )} / { len (results )} tested spacings." )
164+ if ties :
165+ tied_spacings = ", " .join (str (result ["spacing" ]) for result in ties )
166+ print (f"Tied spacings: { tied_spacings } " )
167+ if losses :
168+ loss_spacings = ", " .join (str (result ["spacing" ]) for result in losses )
169+ print (f"Fast basis used more modes at: { loss_spacings } " )
170+
171+ if not losses and not ties :
172+ print ("Verdict: zonal-fast is strictly better than the naive corner-subgrid basis for every tested spacing." )
173+ elif losses :
174+ print ("Verdict: this zonal-fast implementation does not beat the naive corner-subgrid baseline on this geometry." )
175+ print ("The graph coloring is producing a valid basis, but not a minimal one for these spacings." )
176+ else :
177+ print ("Verdict: zonal-fast matches the naive construction for some spacings and never improves on it in this run." )
178+
179+ if args .fail_if_not_better and (losses or ties ):
180+ raise SystemExit (1 )
181+
182+
183+ if __name__ == "__main__" :
184+ main ()
0 commit comments