Skip to content

Commit 85db8c1

Browse files
cailmdaleyclaude
andcommitted
tests: hypothesis property tests for the PSF column refactor
Four property-test modules, each mutation-verified non-vacuous: - grammar: orig/reconv never crossed for any distinct inputs; every emitted column matches the NGMIX_<COMPONENT>[_ERR]_<OBJECT>[_<SHEAR>] grammar; every NGMIX token in final_cat.param is producible by the writer. - physics: T_reconv > T_orig and |g_reconv| <= |g_orig| over random elliptical PSFs (the dilation clause is the robust transposition catcher). - no-op: average_original_psf leaves gal_obs.psf pristine (no gmix leak), so the restored original-PSF fit cannot alter the galaxy/shear results. - averaging: epoch-weighted mean within per-epoch range, flagged epochs excluded, single survivor passes through, sentinel fills for absent objects. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent b690d3f commit 85db8c1

4 files changed

Lines changed: 978 additions & 0 deletions

File tree

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
"""PROPERTY-BASED TESTS FOR PSF EPOCH-AVERAGING + make_cat SENTINELS.
2+
3+
The two PSF families this module exports (the metacal reconvolution kernel via
4+
``average_multiepoch_psf`` and the original image PSF via
5+
``average_original_psf``) share a single averaging core,
6+
:func:`shapepipe.modules.ngmix_package.ngmix._average_psf_fits`. These
7+
hypothesis properties pin the contract of that core directly — a weighted mean
8+
over the surviving (``flags == 0``) epochs — and the companion sentinel-fill
9+
contract on the make_cat reader, which must leave an obj_id absent from the
10+
ngmix catalogue at its type-specific sentinel.
11+
12+
Strategies are kept in physically sensible ranges (positive epoch weights,
13+
positive sizes ``T``, ``|g| < 1``) so every generated input is a valid PSF-fit
14+
result the averaging core would actually be handed in production.
15+
"""
16+
17+
import os
18+
import tempfile
19+
from itertools import zip_longest
20+
21+
import numpy as np
22+
import numpy.testing as npt
23+
from astropy.io import fits
24+
from hypothesis import given, settings
25+
from hypothesis import strategies as st
26+
27+
from shapepipe.modules.make_cat_package.make_cat import SaveCatalogue
28+
from shapepipe.modules.ngmix_package.ngmix import _average_psf_fits
29+
30+
31+
# --------------------------------------------------------------------------- #
32+
# Strategies for one valid per-epoch PSF-fit result.
33+
# --------------------------------------------------------------------------- #
34+
35+
# Bounded, non-degenerate floats: positive sizes, |g| < 1, strictly positive
36+
# weights. min_magnitude keeps weights away from 0 so wsum can never vanish on
37+
# the surviving epochs (the core raises ZeroDivisionError there by contract).
38+
_g_component = st.floats(min_value=-0.9, max_value=0.9, allow_nan=False)
39+
_positive_T = st.floats(min_value=1e-3, max_value=10.0, allow_nan=False)
40+
_positive_err = st.floats(min_value=1e-6, max_value=1.0, allow_nan=False)
41+
_weight = st.floats(
42+
min_value=1e-3, max_value=1e3, allow_nan=False, allow_infinity=False
43+
)
44+
45+
46+
def _result(g1, g2, t, g1_err, g2_err, t_err, flags=0):
47+
"""One ngmix PSF-fit result dict, as the averaging core consumes it."""
48+
return {
49+
"flags": flags,
50+
"g": np.array([g1, g2]),
51+
"g_err": np.array([g1_err, g2_err]),
52+
"T": t,
53+
"T_err": t_err,
54+
}
55+
56+
57+
@st.composite
58+
def _good_epoch(draw):
59+
"""A surviving (flags == 0) epoch paired with its positive weight."""
60+
return (
61+
_result(
62+
draw(_g_component), draw(_g_component), draw(_positive_T),
63+
draw(_positive_err), draw(_positive_err), draw(_positive_err),
64+
),
65+
draw(_weight),
66+
)
67+
68+
69+
# --------------------------------------------------------------------------- #
70+
# (a) weighted average lies within [min, max] of the per-epoch values.
71+
# --------------------------------------------------------------------------- #
72+
73+
@settings(deadline=None, max_examples=50)
74+
@given(st.lists(_good_epoch(), min_size=1, max_size=8))
75+
def test_average_lies_within_per_epoch_range(epochs):
76+
"""A positive-weight weighted mean is a convex combination of its inputs.
77+
78+
For every averaged component (g1, g2, T) the result must sit within the
79+
[min, max] envelope of the per-epoch values. This is the defining property
80+
of a weighted mean with strictly positive weights; it fails immediately if
81+
the core ever divided by the wrong weight sum, summed unweighted, or let a
82+
value escape the convex hull.
83+
"""
84+
out = _average_psf_fits(epochs)
85+
86+
g1_vals = np.array([r["g"][0] for r, _ in epochs])
87+
g2_vals = np.array([r["g"][1] for r, _ in epochs])
88+
t_vals = np.array([r["T"] for r, _ in epochs])
89+
90+
# rtol absorbs only floating-point round-off, not a real bound violation.
91+
npt.assert_array_less(out["g_psf"][0], g1_vals.max() + 1e-9)
92+
npt.assert_array_less(g1_vals.min() - 1e-9, out["g_psf"][0])
93+
npt.assert_array_less(out["g_psf"][1], g2_vals.max() + 1e-9)
94+
npt.assert_array_less(g2_vals.min() - 1e-9, out["g_psf"][1])
95+
npt.assert_array_less(out["T_psf"], t_vals.max() + 1e-9)
96+
npt.assert_array_less(t_vals.min() - 1e-9, out["T_psf"])
97+
98+
assert out["n_epoch"] == len(epochs)
99+
100+
101+
@settings(deadline=None, max_examples=50)
102+
@given(st.lists(_good_epoch(), min_size=1, max_size=8))
103+
def test_average_matches_explicit_weighted_mean(epochs):
104+
"""The core's output equals the textbook weighted mean of the survivors.
105+
106+
Stronger than the [min, max] envelope: pins the exact value, so a swap of
107+
weighting factor or an off-by-one in the accumulation is caught.
108+
"""
109+
out = _average_psf_fits(epochs)
110+
w = np.array([weight for _, weight in epochs])
111+
112+
g = np.array([r["g"] for r, _ in epochs])
113+
t = np.array([r["T"] for r, _ in epochs])
114+
npt.assert_allclose(out["g_psf"], (g * w[:, None]).sum(0) / w.sum())
115+
npt.assert_allclose(out["T_psf"], (t * w).sum() / w.sum())
116+
117+
118+
# --------------------------------------------------------------------------- #
119+
# (b) epochs with flags != 0 are excluded from the average.
120+
# --------------------------------------------------------------------------- #
121+
122+
@settings(deadline=None, max_examples=50)
123+
@given(
124+
st.lists(_good_epoch(), min_size=1, max_size=6),
125+
st.lists(
126+
st.tuples(_weight, st.integers(min_value=1, max_value=255)),
127+
min_size=1,
128+
max_size=6,
129+
),
130+
)
131+
def test_flagged_epochs_are_excluded(good_epochs, flagged_specs):
132+
"""A failed-PSF epoch (flags != 0) must not enter the average at all.
133+
134+
Each flagged epoch carries poisoned NaN measurement fields and a huge,
135+
out-of-range T — values that would wreck the mean (NaN-poison it, or shove
136+
it far outside the good-epoch envelope) if they leaked in. The averaged
137+
result and n_epoch must match a clean average over the good epochs only,
138+
proving the flagged ones were dropped, not merely down-weighted.
139+
"""
140+
flagged = [
141+
(
142+
_result(
143+
np.nan, np.nan, 1e6, np.nan, np.nan, np.nan, flags=flags
144+
),
145+
weight,
146+
)
147+
for weight, flags in flagged_specs
148+
]
149+
# Interleave good and flagged epochs WITHOUT duplicating any good epoch, so
150+
# exclusion can't be a happy accident of ordering. itertools.zip_longest
151+
# threads them together; the leftover tail of the longer list follows.
152+
mixed = [
153+
e
154+
for pair in zip_longest(good_epochs, flagged)
155+
for e in pair
156+
if e is not None
157+
]
158+
159+
out = _average_psf_fits(mixed)
160+
expected = _average_psf_fits(good_epochs)
161+
162+
npt.assert_allclose(out["g_psf"], expected["g_psf"])
163+
npt.assert_allclose(out["T_psf"], expected["T_psf"])
164+
npt.assert_allclose(out["T_psf_err"], expected["T_psf_err"])
165+
assert out["n_epoch"] == len(good_epochs)
166+
assert np.isfinite(out["g_psf"]).all() and np.isfinite(out["T_psf"])
167+
168+
169+
# --------------------------------------------------------------------------- #
170+
# (c) a single surviving epoch returns that epoch's value.
171+
# --------------------------------------------------------------------------- #
172+
173+
@settings(deadline=None, max_examples=50)
174+
@given(
175+
_good_epoch(),
176+
st.lists(
177+
st.tuples(_weight, st.integers(min_value=1, max_value=255)),
178+
max_size=5,
179+
),
180+
)
181+
def test_single_survivor_returns_its_own_value(survivor, flagged_specs):
182+
"""When exactly one epoch survives, its values pass through untouched.
183+
184+
The lone survivor's weight cancels in mean = (v*w)/w, so the result must be
185+
the survivor's value exactly, regardless of how many flagged epochs (with
186+
arbitrary weights) surround it. n_epoch must be 1.
187+
"""
188+
result, weight = survivor
189+
flagged = [
190+
(_result(np.nan, np.nan, 1e6, np.nan, np.nan, np.nan, flags=f), w)
191+
for w, f in flagged_specs
192+
]
193+
out = _average_psf_fits(flagged + [survivor] + flagged)
194+
195+
npt.assert_allclose(out["g_psf"], result["g"])
196+
npt.assert_allclose(out["g_psf_err"], result["g_err"])
197+
npt.assert_allclose(out["T_psf"], result["T"])
198+
npt.assert_allclose(out["T_psf_err"], result["T_err"])
199+
assert out["n_epoch"] == 1
200+
201+
202+
# --------------------------------------------------------------------------- #
203+
# (d) make_cat: an obj_id absent from the ngmix cat keeps its sentinel fill.
204+
# --------------------------------------------------------------------------- #
205+
206+
# The sentinel value each column family is pre-filled with before the matched
207+
# rows are overwritten (mirrors make_cat._save_ngmix_data). The property: a row
208+
# whose obj_id never appears among the ngmix ids keeps exactly these.
209+
_SENTINELS = {
210+
"NGMIX_T_GAL_NOSHEAR": 0.0,
211+
"NGMIX_SNR_GAL_NOSHEAR": 0.0,
212+
"NGMIX_FLAGS_GAL_NOSHEAR": 0.0,
213+
"NGMIX_T_PSF_ORIG_NOSHEAR": 0.0,
214+
"NGMIX_T_PSF_RECONV_NOSHEAR": 0.0,
215+
"NGMIX_FLUX_ERR_GAL_NOSHEAR": -1.0,
216+
"NGMIX_MAG_ERR_GAL_NOSHEAR": -1.0,
217+
"NGMIX_G1_GAL_NOSHEAR": -10.0,
218+
"NGMIX_G2_GAL_NOSHEAR": -10.0,
219+
"NGMIX_G1_PSF_ORIG_NOSHEAR": -10.0,
220+
"NGMIX_G1_PSF_RECONV_NOSHEAR": -10.0,
221+
"NGMIX_T_ERR_GAL_NOSHEAR": 1e30,
222+
"NGMIX_T_ERR_PSF_ORIG_NOSHEAR": 1e30,
223+
"NGMIX_T_ERR_PSF_RECONV_NOSHEAR": 1e30,
224+
"NGMIX_N_EPOCH": 0.0,
225+
"NGMIX_MCAL_FLAGS": 0.0,
226+
"NGMIX_MCAL_TYPES_FAIL": 0.0,
227+
}
228+
229+
# Per-key write format and a measured value distinct from every sentinel, so a
230+
# matched row is unmistakably "overwritten" and an absent row unmistakably not.
231+
_NGMIX_KEYS = [
232+
"id", "n_epoch_model", "mcal_types_fail", "nfev_fit",
233+
"g1", "g1_err", "g2", "g2_err", "T", "T_err",
234+
"flux", "flux_err", "s2n", "mag", "mag_err", "flags", "mcal_flags",
235+
"g1_psf_orig", "g2_psf_orig", "g1_err_psf_orig", "g2_err_psf_orig",
236+
"T_psf_orig", "T_err_psf_orig",
237+
"g1_psf_reconv", "g2_psf_reconv", "g1_err_psf_reconv", "g2_err_psf_reconv",
238+
"T_psf_reconv", "T_err_psf_reconv",
239+
]
240+
_INT_KEYS = {
241+
"id", "n_epoch_model", "mcal_types_fail", "nfev_fit", "flags", "mcal_flags"
242+
}
243+
_SHEAR_EXTS = ["1M", "1P", "2M", "2P", "NOSHEAR"]
244+
245+
246+
class _NullLogger:
247+
def info(self, *_args, **_kwargs):
248+
pass
249+
250+
251+
def _measured_row(obj_id):
252+
"""One fit object whose every value is far from any sentinel (5 / 0.5)."""
253+
row = {key: (5 if key in _INT_KEYS else 0.5) for key in _NGMIX_KEYS}
254+
row["id"] = obj_id
255+
return row
256+
257+
258+
def _write_ngmix_cat(path, obj_ids):
259+
rows = [_measured_row(oid) for oid in obj_ids]
260+
hdus = [fits.PrimaryHDU()]
261+
for ext in _SHEAR_EXTS:
262+
cols = [
263+
fits.Column(
264+
name=key,
265+
format="K" if key in _INT_KEYS else "D",
266+
array=np.array([row[key] for row in rows]),
267+
)
268+
for key in _NGMIX_KEYS
269+
]
270+
hdus.append(fits.BinTableHDU.from_columns(cols, name=ext))
271+
fits.HDUList(hdus).writeto(path, overwrite=True)
272+
273+
274+
def _run_save_ngmix(ngmix_path, obj_id):
275+
inst = object.__new__(SaveCatalogue)
276+
inst._obj_id = np.asarray(obj_id)
277+
inst._output_dict = {}
278+
inst._cat_size_target = len(inst._obj_id)
279+
inst._w_log = _NullLogger()
280+
err_msg = inst._save_ngmix_data(str(ngmix_path))
281+
assert err_msg is None
282+
return inst._output_dict
283+
284+
285+
# Distinct positive integer obj_ids; split into "fit by ngmix" vs "absent".
286+
_distinct_ids = st.lists(
287+
st.integers(min_value=1, max_value=10_000),
288+
min_size=2,
289+
max_size=8,
290+
unique=True,
291+
)
292+
293+
294+
@settings(deadline=None, max_examples=25)
295+
@given(_distinct_ids, st.data())
296+
def test_absent_obj_id_keeps_sentinel_fill(all_ids, data):
297+
"""An obj_id SExtractor saw but ngmix never fit keeps every sentinel.
298+
299+
The final catalogue carries all of ``all_ids``; ngmix fit only a non-empty
300+
proper subset. The unfit rows must retain the exact per-column sentinel
301+
(0 / -10 / 1e30 / -1), while the fit rows must have been overwritten to the
302+
measured value — so this is not vacuously satisfied by an all-sentinel cat.
303+
"""
304+
fit_ids = data.draw(
305+
st.lists(st.sampled_from(all_ids), min_size=1, unique=True).filter(
306+
lambda s: 0 < len(s) < len(all_ids)
307+
)
308+
)
309+
absent_ids = [oid for oid in all_ids if oid not in fit_ids]
310+
311+
with tempfile.TemporaryDirectory() as tmp:
312+
ngmix_path = os.path.join(tmp, "ngmix.fits")
313+
_write_ngmix_cat(ngmix_path, fit_ids)
314+
out = _run_save_ngmix(ngmix_path, all_ids)
315+
316+
absent_idx = [all_ids.index(oid) for oid in absent_ids]
317+
fit_idx = [all_ids.index(oid) for oid in fit_ids]
318+
319+
for col, sentinel in _SENTINELS.items():
320+
arr = np.asarray(out[col])
321+
npt.assert_allclose(
322+
arr[absent_idx], sentinel,
323+
err_msg=f"{col}: absent rows lost their sentinel {sentinel}",
324+
)
325+
# The fit rows were overwritten — measured value 5 (int cols) or 0.5,
326+
# both distinct from every sentinel, so the fill wasn't global.
327+
assert not np.any(np.isclose(arr[fit_idx], sentinel)), (
328+
f"{col}: a fit row still carries the sentinel {sentinel}"
329+
)

0 commit comments

Comments
 (0)