Skip to content

Commit 9b4c2f6

Browse files
authored
♻️ refactor(pt_map): remove manifold parameter dispatch (#519)
Remove the manifold-based pt_map dispatch and consolidate to use chart-based routing via chart.M.atlas. This simplifies the API by eliminating a redundant dispatch path. Also apply formatting improvements to hypothesis manifold strategy dispatch signatures for consistency. Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent f4d0b50 commit 9b4c2f6

4 files changed

Lines changed: 42 additions & 77 deletions

File tree

packages/coordinax.hypothesis/src/coordinax/hypothesis/manifolds/_src/manifold.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -71,59 +71,45 @@ def manifold_classes(
7171

7272
@plum.dispatch
7373
def _manifold_class_supports_ndim(
74-
cls: type[cxm.EuclideanManifold],
75-
ndim: int,
76-
/,
74+
cls: type[cxm.EuclideanManifold], ndim: int, /
7775
) -> bool:
7876
"""EuclideanManifold supports any dimensionality."""
7977
return True
8078

8179

8280
@plum.dispatch
8381
def _manifold_class_supports_ndim(
84-
cls: type[cxm.HyperSphericalManifold],
85-
ndim: int,
86-
/,
82+
cls: type[cxm.HyperSphericalManifold], ndim: int, /
8783
) -> bool:
8884
"""HyperSphericalManifold is always 2-D."""
8985
return ndim == 2
9086

9187

9288
@plum.dispatch
9389
def _manifold_class_supports_ndim(
94-
cls: type[cxm.EmbeddedManifold],
95-
ndim: int,
96-
/,
90+
cls: type[cxm.EmbeddedManifold], ndim: int, /
9791
) -> bool:
9892
"""EmbeddedManifold: only the 2-D embedded two-sphere is currently generated."""
9993
return ndim == 2
10094

10195

10296
@plum.dispatch
10397
def _manifold_class_supports_ndim(
104-
cls: type[cxm.CartesianProductManifold],
105-
ndim: int,
106-
/,
98+
cls: type[cxm.CartesianProductManifold], ndim: int, /
10799
) -> bool:
108100
"""CartesianProductManifold requires at least 1 dimension."""
109101
return ndim >= 1
110102

111103

112104
@plum.dispatch
113-
def _manifold_class_supports_ndim(
114-
cls: type[cxm.CustomManifold],
115-
ndim: int,
116-
/,
117-
) -> bool:
105+
def _manifold_class_supports_ndim(cls: type[cxm.CustomManifold], ndim: int, /) -> bool:
118106
"""CustomManifold supports ndim when matching zero-arg charts exist."""
119107
return len(_matching_chart_classes_for_ndim(ndim)) > 0
120108

121109

122110
@plum.dispatch
123111
def _manifold_class_supports_ndim(
124-
cls: type[cxm.AbstractManifold],
125-
ndim: int,
126-
/,
112+
cls: type[cxm.AbstractManifold], ndim: int, /
127113
) -> bool:
128114
"""Fallback: unknown manifold types are assumed to support any ndim."""
129115
return True

src/coordinax/_src/base/manifold.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ def pt_map(self, x: Any, /, *args: Any, **kwargs: Any) -> Any:
268268
Atlas EuclideanAtlas(ndim=2) does not support chart SphericalTwoSphere(M=Sn(2))
269269
270270
"""
271-
return cxcapi.pt_map(x, self, *args, **kwargs)
271+
# TODO: check chart compatible with manifold
272+
return cxcapi.pt_map(x, self.atlas, *args, **kwargs)
272273

273274
# =====================================================
274275

src/coordinax/_src/charts/register_ptmap.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def pt_map(
209209
True
210210
211211
"""
212-
del from_chart, to_chart, usys # unused
212+
del usys # unused
213+
assert from_chart.M == to_chart.M # noqa: S101
213214
return p
214215

215216

@@ -283,7 +284,8 @@ def pt_map(
283284
{'x': 5.0}
284285
285286
"""
286-
del to_chart, from_chart, usys # unused
287+
del usys # unused
288+
assert from_chart.M == to_chart.M # noqa: S101
287289
return {"x": p["r"]}
288290

289291

@@ -313,7 +315,8 @@ def pt_map(
313315
{'r': 5.0}
314316
315317
"""
316-
del to_chart, from_chart, usys # unused
318+
del usys # unused
319+
assert from_chart.M == to_chart.M # noqa: S101
317320
return {"r": p["x"]}
318321

319322

@@ -346,7 +349,7 @@ def pt_map(
346349
'y': Array(5., dtype=float64, ...)}
347350
348351
"""
349-
del to_chart, from_chart # unused
352+
assert from_chart.M == to_chart.M # noqa: S101
350353
theta = uconvert_to_rad(p["theta"], usys)
351354
x = p["r"] * jnp.cos(theta)
352355
y = p["r"] * jnp.sin(theta)
@@ -377,7 +380,8 @@ def pt_map(
377380
'theta': Array(0.92729522, dtype=float64, ...)}
378381
379382
"""
380-
del to_chart, from_chart, usys # unused
383+
del usys # unused
384+
assert from_chart.M == to_chart.M # noqa: S101
381385
r_ = jnp.hypot(p["x"], p["y"])
382386
theta = jnp.arctan2(p["y"], p["x"])
383387
return {"r": r_, "theta": theta}
@@ -409,7 +413,7 @@ def pt_map(
409413
'z': 2.0}
410414
411415
"""
412-
del to_chart, from_chart # unused
416+
assert from_chart.M == to_chart.M # noqa: S101
413417
phi = uconvert_to_rad(p["phi"], usys)
414418
x = p["rho"] * jnp.cos(phi)
415419
y = p["rho"] * jnp.sin(phi)
@@ -444,7 +448,7 @@ def pt_map(
444448
'z': Array(1.2246468e-16, dtype=float64, ...)}
445449
446450
"""
447-
del to_chart, from_chart # unused
451+
assert from_chart.M == to_chart.M # noqa: S101
448452
r_ = p["r"]
449453
theta = uconvert_to_rad(p["theta"], usys)
450454
phi = uconvert_to_rad(p["phi"], usys)
@@ -485,7 +489,7 @@ def pt_map(
485489
'z': Array(0., dtype=float64, ...)}
486490
487491
"""
488-
del to_chart, from_chart # unused
492+
assert from_chart.M == to_chart.M # noqa: S101
489493
r_ = p["distance"]
490494
lon = uconvert_to_rad(p["lon"], usys)
491495
lat = uconvert_to_rad(p["lat"], usys)
@@ -529,7 +533,7 @@ def pt_map(
529533
{'x': Q(1.2246468e-16, 'm'), 'y': Q(0., 'm'), 'z': Q(2., 'm')}
530534
531535
"""
532-
del to_chart, from_chart # unused
536+
assert from_chart.M == to_chart.M # noqa: S101
533537
lon_coslat, r_ = p["lon_coslat"], p["distance"]
534538
lat = uconvert_to_rad(p["lat"], usys)
535539
# Handle the poles where cos(lat) == 0
@@ -570,7 +574,7 @@ def pt_map(
570574
{'x': Q(2., 'm'), 'y': Q(0., 'm'), 'z': Q(1.2246468e-16, 'm')}
571575
572576
"""
573-
del to_chart, from_chart # unused
577+
assert from_chart.M == to_chart.M # noqa: S101
574578
r_ = p["r"]
575579
theta = uconvert_to_rad(p["theta"], usys)
576580
phi = uconvert_to_rad(p["phi"], usys)
@@ -621,7 +625,7 @@ def pt_map(
621625
'z': Array(1.11803399, dtype=float64)}
622626
623627
"""
624-
del to_chart
628+
assert from_chart.M == to_chart.M # noqa: S101
625629
# Calculate cylindrical distance
626630
nu, mu = p["nu"], p["mu"]
627631
if not isinstance(nu, ABCQ) or not isinstance(mu, ABCQ):
@@ -667,7 +671,8 @@ def pt_map(
667671
'z': 5.0}
668672
669673
"""
670-
del to_chart, from_chart, usys # Unused
674+
del usys # Unused
675+
assert from_chart.M == to_chart.M # noqa: S101
671676
rho = jnp.hypot(p["x"], p["y"])
672677
phi = jnp.atan2(p["y"], p["x"])
673678
return {"rho": rho, "phi": phi, "z": p["z"]}
@@ -736,7 +741,7 @@ def pt_map(
736741
'phi': Array(0., dtype=float64, ...)}
737742
738743
"""
739-
del to_chart, from_chart # unused
744+
assert from_chart.M == to_chart.M # noqa: S101
740745
x, y, z = p["x"], p["y"], p["z"]
741746
r = jnp.sqrt(x**2 + y**2 + z**2)
742747
# Avoid division by zero: when r == 0, set theta = 0 by convention
@@ -776,7 +781,7 @@ def pt_map(
776781
'phi': 0}
777782
778783
"""
779-
del to_chart, from_chart # unused
784+
assert from_chart.M == to_chart.M # noqa: S101
780785
r_ = jnp.hypot(p["rho"], p["z"])
781786
# Avoid division by zero: when r == 0, set theta = 0 by convention
782787
theta = jnp.acos(jnp.where(r_ == 0, jnp.ones(r_.shape), p["z"] / r_))
@@ -814,7 +819,7 @@ def pt_map(
814819
'z': Array(1.2246468e-16, dtype=float64, ...)}
815820
816821
"""
817-
del to_chart, from_chart # unused
822+
assert from_chart.M == to_chart.M # noqa: S101
818823
theta = uconvert_to_rad(p["theta"], usys)
819824
rho = p["r"] * jnp.sin(theta)
820825
z = p["r"] * jnp.cos(theta)
@@ -849,7 +854,7 @@ def pt_map(
849854
{'lon': 0, 'lat': 1.5707963267948966, 'distance': 1.0}
850855
851856
"""
852-
del to_chart, from_chart # unused
857+
assert from_chart.M == to_chart.M # noqa: S101
853858
lat = (
854859
u.Q(90, "deg") if isinstance(p["theta"], ABCQ) else jnp.pi / 2
855860
) - uconvert_to_rad(p["theta"], usys)
@@ -886,7 +891,7 @@ def pt_map(
886891
'lat': 1.5707963267948966, 'distance': 1.0}
887892
888893
"""
889-
del to_chart, from_chart # unused
894+
assert from_chart.M == to_chart.M # noqa: S101
890895
lat = (
891896
u.Q(90, "deg") if isinstance(p["theta"], ABCQ) else jnp.pi / 2
892897
) - uconvert_to_rad(p["theta"], usys)
@@ -922,7 +927,8 @@ def pt_map(
922927
{'r': 1.0, 'theta': 60, 'phi': 30}
923928
924929
"""
925-
del to_chart, from_chart, usys # Unused
930+
del usys # Unused
931+
assert from_chart.M == to_chart.M # noqa: S101
926932
return {"r": p["r"], "theta": p["phi"], "phi": p["theta"]}
927933

928934

@@ -954,7 +960,8 @@ def pt_map(
954960
{'r': 1.0, 'theta': 30, 'phi': 60}
955961
956962
"""
957-
del to_chart, from_chart, usys # Unused
963+
del usys # Unused
964+
assert from_chart.M == to_chart.M # noqa: S101
958965
return {"r": p["r"], "theta": p["phi"], "phi": p["theta"]}
959966

960967

@@ -999,7 +1006,7 @@ def pt_map(
9991006
'z': Array(1.11803399, dtype=float64)}
10001007
10011008
"""
1002-
del to_chart # Unused
1009+
assert from_chart.M == to_chart.M # noqa: S101
10031010
nu, mu = p["nu"], p["mu"]
10041011
if not isinstance(nu, ABCQ) or not isinstance(mu, ABCQ):
10051012
if usys is None:
@@ -1068,7 +1075,7 @@ def pt_map(
10681075
'nu': Array(2.47920271, dtype=float64), 'phi': 0}
10691076
10701077
"""
1071-
del from_chart # Unused
1078+
assert from_chart.M == to_chart.M # noqa: S101
10721079
# Pre-compute common terms
10731080
R2 = p["rho"] ** 2
10741081
z2 = p["z"] ** 2
@@ -1156,6 +1163,7 @@ def pt_map(
11561163
'phi': Q(0., 'rad')}
11571164
11581165
"""
1166+
assert from_chart.M == to_chart.M # noqa: S101
11591167
# Cast to the result type
11601168
dtype = jnp.result_type(
11611169
to_chart.Delta, from_chart.Delta, *[v.dtype for v in p.values()]
@@ -1220,7 +1228,7 @@ def pt_map(
12201228
{'x': Q(1., 'm'), 'y': Q(2., 'm'), 'z': Q(3., 'm')}
12211229
12221230
"""
1223-
del from_chart # Unused
1231+
# assert from_chart.M == to_chart.M # TODO: CartND manifold
12241232

12251233
# If target is CartND, we can't convert (would be infinite recursion)
12261234
if isinstance(to_chart, CartND):
@@ -1304,7 +1312,7 @@ def pt_map(
13041312
{'q': Q([0., 0., 5.], 'm')}
13051313
13061314
"""
1307-
del to_chart # Unused
1315+
# assert from_chart.M == to_chart.M # TODO: CartND manifold
13081316

13091317
# If source is CartND, we can't convert (would be infinite recursion)
13101318
if isinstance(from_chart, CartND):
@@ -1392,7 +1400,8 @@ def pt_map(
13921400
True
13931401
13941402
"""
1395-
del to_chart, from_chart, usys # Unused
1403+
del usys # Unused
1404+
assert from_chart.M == to_chart.M # noqa: S101
13961405
return q
13971406

13981407

src/coordinax/_src/manifolds/register_charts.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import wadler_lindig as wl
1111

1212
import coordinax.api.charts as cxcapi
13-
from coordinax._src.base import AbstractAtlas, AbstractChart, AbstractManifold
13+
from coordinax._src.base import AbstractAtlas, AbstractChart
1414

1515
_ATLAS_MSG: Final[Callable[[AbstractAtlas, AbstractChart[Any, Any]], str]] = (
1616
lambda a, c: (
@@ -53,34 +53,3 @@ def pt_map(
5353

5454
# If charts are supported, delegate to ptm
5555
return cxcapi.pt_map(x, chart_from, chart_to, *args, **kwargs)
56-
57-
58-
# default route
59-
@plum.dispatch(precedence=-1) # ty: ignore[no-matching-overload]
60-
def pt_map(
61-
x: Any,
62-
M: AbstractManifold,
63-
chart_from: AbstractChart,
64-
chart_to: AbstractChart,
65-
*args: Any,
66-
**kwargs: Any,
67-
) -> Any:
68-
"""Transition map for points, checking the manifold's atlas.
69-
70-
>>> import coordinax.charts as cxc
71-
>>> import coordinax.manifolds as cxm
72-
73-
>>> M = cxm.EuclideanManifold(2)
74-
75-
>>> x = {"x": 1.0, "y": 1.0}
76-
>>> cxc.pt_map(x, M, cxc.cart2d, cxc.polar2d)
77-
{'r': Array(1.41421356, dtype=float64, ...),
78-
'theta': Array(0.78539816, dtype=float64, ...)}
79-
80-
>>> try: cxc.pt_map(x, M, cxc.cart2d, cxc.sph2)
81-
... except ValueError as e: print(e)
82-
Atlas EuclideanAtlas(ndim=2) does not support chart SphericalTwoSphere(M=Sn(2))
83-
84-
"""
85-
# Redispatch to the atlas
86-
return cxcapi.pt_map(x, M.atlas, chart_from, chart_to, *args, **kwargs)

0 commit comments

Comments
 (0)