Skip to content

Commit cec14c8

Browse files
authored
♻️ refactor(tests): align suites with metric/chart API updates and public JAX dtypes (#525)
Refactor tests and internal utilities to match the updated metric, chart, and manifold APIs while simplifying formatting and numeric literals. Highlights: - Update Hypothesis strategies and chart/vector/distance tests for the metric refactor and current manifold/chart APIs - Replace whole-number float literals with ints across manifold, representation, transform, vector, package, pt_map, strategy, and quantity_matrix tests where dtype semantics permit - Normalize assert_allclose(rtol=0.0) → rtol=0 - Preserve float literals where required for JAX grad/jvp behaviour, dtype tracing, and mixed-unit conversions - Collapse single-trailing-comma expressions onto one line across src, tests, config, and utility scripts for cleaner ruff-format output - Fix DenseMetric fixture dtype to avoid JIT/MLIR dtype mismatches - Replace private jax._src.dtypes usage with public jnp.issubdtype + jnp.result_type equivalents Also includes cosmetic formatting cleanup, updated fixtures, and minor test maintenance throughout the repository. Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent e3b9782 commit cec14c8

64 files changed

Lines changed: 700 additions & 1044 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,7 @@ class CodeCellParser:
271271
def __init__(self, doctest_optionflags: int = 0) -> None: # noqa: D107
272272
self.doctest_parser = DocTestStringParser(DocTestEvaluator(doctest_optionflags))
273273
self.codeblock_parser = myst.CodeBlockParser(
274-
language="ipython3",
275-
evaluator=PythonEvaluator(),
274+
language="ipython3", evaluator=PythonEvaluator()
276275
)
277276

278277
@staticmethod

noxfile.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,7 @@ def docs(s: nox.Session, /) -> None:
160160
parser = argparse.ArgumentParser()
161161
parser.add_argument("--serve", action="store_true", help="Serve after building")
162162
parser.add_argument(
163-
"-b",
164-
dest="builder",
165-
default="html",
166-
help="Build target (default: html)",
163+
"-b", dest="builder", default="html", help="Build target (default: html)"
167164
)
168165
parser.add_argument("--output-dir", dest="output_dir", default="_build")
169166
args, posargs = parser.parse_known_args(s.posargs)
@@ -184,14 +181,7 @@ def docs(s: nox.Session, /) -> None:
184181
)
185182

186183
if args.builder == "linkcheck":
187-
s.run(
188-
"sphinx-build",
189-
"-b",
190-
"linkcheck",
191-
".",
192-
"_build/linkcheck",
193-
*posargs,
194-
)
184+
s.run("sphinx-build", "-b", "linkcheck", ".", "_build/linkcheck", *posargs)
195185
return
196186

197187
shared_args = (

packages/coordinax.astro/src/coordinax/astro/_src/galactocentric.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,12 @@ class Galactocentric(AbstractSpaceFrame):
5252
#: ra, dec: https://ui.adsabs.harvard.edu/abs/2004ApJ...616..872R
5353
#: distance: https://ui.adsabs.harvard.edu/abs/2018A%26A...615L..15G
5454
galcen: cxv.Point[cxc.LonLatSpherical3D, Any] = eqx.field( # ty: ignore[invalid-assignment]
55-
converter=cxv.Point[cxc.LonLatSpherical3D, Any].from_,
56-
default=GALCEN_DEFAULT,
55+
converter=cxv.Point[cxc.LonLatSpherical3D, Any].from_, default=GALCEN_DEFAULT
5756
)
5857

5958
#: Rotation angle of the Galactic center from the ICRS x-axis.
6059
roll: ScalarAngle = eqx.field(
61-
converter=Unless(u.Angle, u.Angle.from_),
62-
default=u.Angle(jnp.array(0), "deg"),
60+
converter=Unless(u.Angle, u.Angle.from_), default=u.Angle(jnp.array(0), "deg")
6361
)
6462

6563
#: Distance from the Sun to the Galactic midplane.

packages/coordinax.astro/tests/unit/test_frame_transforms.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,7 @@ def _astropy_icrs_to_gcf_xyz_pc(xyz_pc: Iterable[float], frame: cxastro.Galactoc
5555
)
5656
out = sc.transform_to(_as_astropy_galactocentric(frame)).cartesian
5757
return np.array(
58-
[
59-
out.x.to_value(apyu.pc),
60-
out.y.to_value(apyu.pc),
61-
out.z.to_value(apyu.pc),
62-
],
58+
[out.x.to_value(apyu.pc), out.y.to_value(apyu.pc), out.z.to_value(apyu.pc)],
6359
dtype=float,
6460
)
6561

@@ -79,19 +75,12 @@ def _astropy_gcf_to_icrs_xyz_pc(xyz_pc: Iterable[float], frame: cxastro.Galactoc
7975
)
8076
out = sc.transform_to(apyc.ICRS()).cartesian
8177
return np.array(
82-
[
83-
out.x.to_value(apyu.pc),
84-
out.y.to_value(apyu.pc),
85-
out.z.to_value(apyu.pc),
86-
],
78+
[out.x.to_value(apyu.pc), out.y.to_value(apyu.pc), out.z.to_value(apyu.pc)],
8779
dtype=float,
8880
)
8981

9082

91-
@pytest.mark.parametrize(
92-
"xyz_pc",
93-
[(0, 0, 0), (100, -20, 50), (-5000, 3200, 1200)],
94-
)
83+
@pytest.mark.parametrize("xyz_pc", [(0, 0, 0), (100, -20, 50), (-5000, 3200, 1200)])
9584
def test_icrs_to_galactocentric_matches_astropy_positions(xyz_pc) -> None:
9685
"""ICRS->Galactocentric position transforms match Astropy."""
9786
gcf = cxastro.Galactocentric()
@@ -128,7 +117,7 @@ def test_icrs_galactocentric_transitions_are_inverse_for_positions() -> None:
128117
q = u.Q(jnp.asarray([450, -100, 220]), "pc")
129118
back = cxfm.act(bwd, None, cxfm.act(fwd, None, q))
130119

131-
np.testing.assert_allclose(_to_np(back, "pc"), _to_np(q, "pc"), rtol=0.0, atol=1e-6)
120+
np.testing.assert_allclose(_to_np(back, "pc"), _to_np(q, "pc"), rtol=0, atol=1e-6)
132121

133122

134123
# ===================================================================
@@ -140,9 +129,7 @@ class TestFrameTransformProperties:
140129

141130
@given(
142131
q=ust.quantities(
143-
"pc",
144-
shape=(3,),
145-
elements={"min_value": -5e4, "max_value": 5e4},
132+
"pc", shape=(3,), elements={"min_value": -5e4, "max_value": 5e4}
146133
)
147134
)
148135
@settings(deadline=None)
@@ -161,9 +148,7 @@ def test_icrs_gcf_icrs_roundtrip(self, q: u.AbstractQuantity) -> None:
161148

162149
@given(
163150
q=ust.quantities(
164-
"pc",
165-
shape=(3,),
166-
elements={"min_value": -5e4, "max_value": 5e4},
151+
"pc", shape=(3,), elements={"min_value": -5e4, "max_value": 5e4}
167152
)
168153
)
169154
@settings(deadline=None)
@@ -180,9 +165,7 @@ def test_gcf_icrs_gcf_roundtrip(self, q: u.AbstractQuantity) -> None:
180165

181166
@given(
182167
q=ust.quantities(
183-
"pc",
184-
shape=(3,),
185-
elements={"min_value": -5e4, "max_value": 5e4},
168+
"pc", shape=(3,), elements={"min_value": -5e4, "max_value": 5e4}
186169
)
187170
)
188171
@settings(deadline=None)
@@ -205,14 +188,12 @@ def test_inverse_is_frame_transition_in_reverse(
205188
via_bwd = cxfm.act(bwd, None, q_gcf)
206189

207190
np.testing.assert_allclose(
208-
via_inverse.ustrip("pc"), via_bwd.ustrip("pc"), rtol=0.0, atol=1e-6
191+
via_inverse.ustrip("pc"), via_bwd.ustrip("pc"), rtol=0, atol=1e-6
209192
)
210193

211194
@given(
212195
q=ust.quantities(
213-
"pc",
214-
shape=(3,),
215-
elements={"min_value": -5e4, "max_value": 5e4},
196+
"pc", shape=(3,), elements={"min_value": -5e4, "max_value": 5e4}
216197
)
217198
)
218199
@settings(deadline=None)
@@ -226,4 +207,4 @@ def test_icrs_to_gcf_matches_astropy_on_random_positions(
226207
xyz = q.ustrip("pc")
227208
got = cxfm.act(op, None, q).ustrip("pc")
228209
expected = _astropy_icrs_to_gcf_xyz_pc((xyz[0], xyz[1], xyz[2]), gcf)
229-
np.testing.assert_allclose(got, expected, rtol=0.0, atol=1e-6)
210+
np.testing.assert_allclose(got, expected, rtol=0, atol=1e-6)

packages/coordinax.curveframes/src/coordinax/curveframes/_src/bishop.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -605,9 +605,6 @@ def from_curve(
605605
606606
"""
607607
xop = BishopTransform.from_curve(
608-
curve,
609-
tau_unit=tau_unit,
610-
tau_0=tau_0,
611-
initial_normal=initial_normal,
608+
curve, tau_unit=tau_unit, tau_0=tau_0, initial_normal=initial_normal
612609
)
613610
return cls(base_frame=base_frame, xop=xop, xop_inv=xop.inverse)

packages/coordinax.curveframes/src/coordinax/curveframes/_src/frenetserret.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,7 @@ def binormal(self, tau: Any) -> u.Q:
210210

211211
@classmethod
212212
def from_curve(
213-
cls,
214-
curve: Callable[[Any], Any],
215-
/,
216-
tau_unit: u.AbstractUnit | str = "s",
213+
cls, curve: Callable[[Any], Any], /, tau_unit: u.AbstractUnit | str = "s"
217214
) -> "FrenetSerretTransform":
218215
r"""Construct a Frenet-Serret transform from a curve callable.
219216

packages/coordinax.curveframes/tests/test_bishop.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,7 @@ def test_roundtrip_forward_inverse(
318318
U2i = inv.normal2(tau)
319319
diff_inv = p_fwd - g_inv
320320
p_rec = qnp.stack(
321-
[
322-
qnp.sum(Ti * diff_inv),
323-
qnp.sum(U1i * diff_inv),
324-
qnp.sum(U2i * diff_inv),
325-
]
321+
[qnp.sum(Ti * diff_inv), qnp.sum(U1i * diff_inv), qnp.sum(U2i * diff_inv)]
326322
)
327323

328324
assert jnp.allclose(p_rec.value, p.value, atol=1e-3)
@@ -392,11 +388,7 @@ def test_roundtrip(self, helix_bishop: cxfc.BishopTransform):
392388
U2i = inv.normal2(tau)
393389
diff_inv = p_fwd - g_inv
394390
p_rec = qnp.stack(
395-
[
396-
qnp.sum(Ti * diff_inv),
397-
qnp.sum(U1i * diff_inv),
398-
qnp.sum(U2i * diff_inv),
399-
]
391+
[qnp.sum(Ti * diff_inv), qnp.sum(U1i * diff_inv), qnp.sum(U2i * diff_inv)]
400392
)
401393

402394
assert jnp.allclose(p_rec.value, p.value, atol=1e-3)

packages/coordinax.curveframes/tests/test_bishop_frame.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def line_bishop_frame() -> cxfc.BishopFrame:
5151
# Helpers
5252

5353

54-
def _as_array(x: object, unit: str) -> np.ndarray:
54+
def _as_arr(x: object, unit: str) -> np.ndarray:
5555
assert isinstance(x, u.AbstractQuantity)
5656
return np.asarray(u.ustrip(unit, x), dtype=float)
5757

@@ -126,7 +126,7 @@ def test_act_forward_at_tau_zero(self, circle_bishop_transform) -> None:
126126
tau = u.Q(0, "s")
127127
p = u.Q(jnp.array([1, 0, 0]), "km")
128128
result = cxfm.act(circle_bishop_transform, tau, p)
129-
np.testing.assert_allclose(_as_array(result, "km"), [0, 0, 0], atol=1e-5)
129+
np.testing.assert_allclose(_as_arr(result, "km"), [0, 0, 0], atol=1e-5)
130130

131131
def test_act_inverse_roundtrip(self, circle_bishop_transform) -> None:
132132
"""Forward then inverse recovers original point."""
@@ -136,9 +136,7 @@ def test_act_inverse_roundtrip(self, circle_bishop_transform) -> None:
136136
p_curve = cxfm.act(circle_bishop_transform, tau, p)
137137
p_back = cxfm.act(circle_bishop_transform.inverse, tau, p_curve)
138138

139-
np.testing.assert_allclose(
140-
_as_array(p_back, "km"), _as_array(p, "km"), atol=1e-3
141-
)
139+
np.testing.assert_allclose(_as_arr(p_back, "km"), _as_arr(p, "km"), atol=1e-3)
142140

143141
def test_act_at_different_tau_values(self, circle_bishop_transform) -> None:
144142
"""Different tau values give different results for same point."""
@@ -147,7 +145,7 @@ def test_act_at_different_tau_values(self, circle_bishop_transform) -> None:
147145
r1 = cxfm.act(circle_bishop_transform, u.Q(0, "s"), p)
148146
r2 = cxfm.act(circle_bishop_transform, u.Q(1, "s"), p)
149147

150-
assert not np.allclose(_as_array(r1, "km"), _as_array(r2, "km"), atol=1e-3)
148+
assert not np.allclose(_as_arr(r1, "km"), _as_arr(r2, "km"), atol=1e-3)
151149

152150

153151
# ===================================================================
@@ -176,9 +174,7 @@ def test_roundtrip_alice_to_bishop_and_back(self, circle_bishop_frame) -> None:
176174
p_bishop = cxfm.act(op_fwd, tau, p)
177175
p_back = cxfm.act(op_bwd, tau, p_bishop)
178176

179-
np.testing.assert_allclose(
180-
_as_array(p_back, "km"), _as_array(p, "km"), atol=1e-3
181-
)
177+
np.testing.assert_allclose(_as_arr(p_back, "km"), _as_arr(p, "km"), atol=1e-3)
182178

183179
def test_alice_bishop_alex_chain(self) -> None:
184180
"""Alice -> Bishop(tau) -> Alex chain and reverse."""
@@ -194,9 +190,7 @@ def test_alice_bishop_alex_chain(self) -> None:
194190
op_b_to_a = cxf.frame_transition(b_frame, cxf.Alice())
195191
p_back = cxfm.act(op_b_to_a, tau, p_bishop)
196192

197-
np.testing.assert_allclose(
198-
_as_array(p_back, "km"), _as_array(p, "km"), atol=1e-3
199-
)
193+
np.testing.assert_allclose(_as_arr(p_back, "km"), _as_arr(p, "km"), atol=1e-3)
200194

201195
# Bishop -> Alex
202196
op_b_to_alex = cxf.frame_transition(b_frame, cxf.Alex())
@@ -207,7 +201,7 @@ def test_alice_bishop_alex_chain(self) -> None:
207201
p_bishop2 = cxfm.act(op_alex_to_b, tau, p_alex)
208202

209203
np.testing.assert_allclose(
210-
_as_array(p_bishop2, "km"), _as_array(p_bishop, "km"), atol=1e-3
204+
_as_arr(p_bishop2, "km"), _as_arr(p_bishop, "km"), atol=1e-3
211205
)
212206

213207
def test_full_chain_alice_bishop_alex_roundtrip(self) -> None:
@@ -224,9 +218,7 @@ def test_full_chain_alice_bishop_alex_roundtrip(self) -> None:
224218
op4 = cxf.frame_transition(b_frame, cxf.Alice())
225219
p_back = cxfm.act(op4, tau, cxfm.act(op3, tau, p_alex))
226220

227-
np.testing.assert_allclose(
228-
_as_array(p_back, "km"), _as_array(p, "km"), atol=1e-2
229-
)
221+
np.testing.assert_allclose(_as_arr(p_back, "km"), _as_arr(p, "km"), atol=1e-2)
230222

231223
def test_straight_line_frame_transition(self, line_bishop_frame) -> None:
232224
"""Frame transition works on a straight line (kappa=0)."""
@@ -239,9 +231,7 @@ def test_straight_line_frame_transition(self, line_bishop_frame) -> None:
239231
op_inv = cxf.frame_transition(line_bishop_frame, cxf.Alice())
240232
p_back = cxfm.act(op_inv, tau, p_bishop)
241233

242-
np.testing.assert_allclose(
243-
_as_array(p_back, "km"), _as_array(p, "km"), atol=1e-3
244-
)
234+
np.testing.assert_allclose(_as_arr(p_back, "km"), _as_arr(p, "km"), atol=1e-3)
245235

246236

247237
# ===================================================================
@@ -262,9 +252,7 @@ def test_act_jit(self, circle_bishop_transform) -> None:
262252
)
263253

264254
np.testing.assert_allclose(
265-
_as_array(result_jit, "km"),
266-
_as_array(result_eager, "km"),
267-
atol=1e-5,
255+
_as_arr(result_jit, "km"), _as_arr(result_eager, "km"), atol=1e-5
268256
)
269257

270258
def test_act_vmap_over_tau(self, circle_bishop_transform) -> None:
@@ -290,15 +278,15 @@ def test_forward_moves_point_to_curve_frame(self, circle_bishop_transform) -> No
290278
p_on_curve = u.Q(jnp.array([1, 0, 0]), "km")
291279
result = cxfm.act(circle_bishop_transform, tau, p_on_curve)
292280

293-
np.testing.assert_allclose(_as_array(result, "km"), [0, 0, 0], atol=1e-5)
281+
np.testing.assert_allclose(_as_arr(result, "km"), [0, 0, 0], atol=1e-5)
294282

295283
def test_inverse_moves_point_back_to_ambient(self, circle_bishop_transform) -> None:
296284
"""Origin of curve frame at tau=0 maps back to gamma(0)."""
297285
tau = u.Q(0, "s")
298286
p_origin = u.Q(jnp.array([0, 0, 0]), "km")
299287
result = cxfm.act(circle_bishop_transform.inverse, tau, p_origin)
300288

301-
np.testing.assert_allclose(_as_array(result, "km"), [1, 0, 0], atol=1e-3)
289+
np.testing.assert_allclose(_as_arr(result, "km"), [1, 0, 0], atol=1e-3)
302290

303291
def test_frame_transition_matches_direct_transform(
304292
self, circle_bishop_frame, circle_bishop_transform
@@ -312,7 +300,5 @@ def test_frame_transition_matches_direct_transform(
312300
result_direct = cxfm.act(circle_bishop_transform, tau, p)
313301

314302
np.testing.assert_allclose(
315-
_as_array(result_ft, "km"),
316-
_as_array(result_direct, "km"),
317-
atol=1e-5,
303+
_as_arr(result_ft, "km"), _as_arr(result_direct, "km"), atol=1e-5
318304
)

packages/coordinax.hypothesis/src/coordinax/hypothesis/distances/_src/dist.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,7 @@ def distances(
7979
# Generate the Distance quantity
8080
return draw( # ty: ignore[invalid-return-type]
8181
ust.quantities(
82-
unit,
83-
quantity_cls=cxd.Distance,
84-
check_negative=check_negative,
85-
**kwargs,
82+
unit, quantity_cls=cxd.Distance, check_negative=check_negative, **kwargs
8683
)
8784
)
8885

packages/coordinax.hypothesis/src/coordinax/hypothesis/manifolds/_src/atlas.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -346,10 +346,7 @@ def atlases( # noqa: F811
346346
chosen_filter = draw_if_strategy(draw, filter)
347347

348348
all_classes = get_all_subclasses(
349-
cxm.AbstractAtlas,
350-
filter=chosen_filter,
351-
exclude_abstract=True,
352-
exclude=exclude,
349+
cxm.AbstractAtlas, filter=chosen_filter, exclude_abstract=True, exclude=exclude
353350
)
354351
classes = tuple(
355352
cls
@@ -408,9 +405,7 @@ def atlases( # noqa: F811
408405
selected_cls = draw(atlas_cls)
409406
return draw(
410407
cast(Any, atlases)(
411-
selected_cls,
412-
ndim=ndim,
413-
required_chart_classes=required_chart_classes,
408+
selected_cls, ndim=ndim, required_chart_classes=required_chart_classes
414409
)
415410
)
416411

0 commit comments

Comments
 (0)