From af07642c95f5bb6f80299ac918f29fa14161abe4 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:21:26 -0400 Subject: [PATCH 01/15] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(internal):?= =?UTF-8?q?=20rename=20QuantityMatrix=20=E2=86=92=20QMatrix=20and=20split?= =?UTF-8?q?=20into=20package?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit QuantityMatrix is renamed to QMatrix for brevity. The monolithic quantity_matrix.py module is supplemented by a new quantity_matrix/ package that splits the implementation into focused submodules (_quantity_matrix, _units_matrix, _det, _inv, _register_primitives, _utils). det and inv operations are added to the public surface. The .py file and the package coexist; Python resolves to the package at import time. --- src/coordinax/_src/internal/__init__.py | 2 +- .../_src/internal/quantity_matrix.py | 952 ++++++++++++++---- .../_src/internal/quantity_matrix/__init__.py | 45 + .../_src/internal/quantity_matrix/_det.py | 210 ++++ .../_src/internal/quantity_matrix/_inv.py | 193 ++++ .../quantity_matrix/_quantity_matrix.py | 423 ++++++++ .../quantity_matrix/_register_primitives.py | 808 +++++++++++++++ .../internal/quantity_matrix/_units_matrix.py | 380 +++++++ .../_src/internal/quantity_matrix/_utils.py | 31 + tests/unit/internal/test_quantity_matrix.py | 506 ++++++++-- 10 files changed, 3272 insertions(+), 278 deletions(-) create mode 100644 src/coordinax/_src/internal/quantity_matrix/__init__.py create mode 100644 src/coordinax/_src/internal/quantity_matrix/_det.py create mode 100644 src/coordinax/_src/internal/quantity_matrix/_inv.py create mode 100644 src/coordinax/_src/internal/quantity_matrix/_quantity_matrix.py create mode 100644 src/coordinax/_src/internal/quantity_matrix/_register_primitives.py create mode 100644 src/coordinax/_src/internal/quantity_matrix/_units_matrix.py create mode 100644 src/coordinax/_src/internal/quantity_matrix/_utils.py diff --git a/src/coordinax/_src/internal/__init__.py b/src/coordinax/_src/internal/__init__.py index d6d90b06..1009aea9 100644 --- a/src/coordinax/_src/internal/__init__.py +++ b/src/coordinax/_src/internal/__init__.py @@ -11,7 +11,7 @@ Contents: -- ``QuantityMatrix`` +- ``QMatrix`` An N-D quantity matrix/vector where every element carries its own unit. Supports both 1-D (vector) and 2-D (matrix) cases. Useful for Jacobians and metric tensors whose entries have diff --git a/src/coordinax/_src/internal/quantity_matrix.py b/src/coordinax/_src/internal/quantity_matrix.py index 196530dc..4aed3ca1 100644 --- a/src/coordinax/_src/internal/quantity_matrix.py +++ b/src/coordinax/_src/internal/quantity_matrix.py @@ -3,7 +3,7 @@ This module provides two closely related building blocks: - `UnitsMatrix`, an immutable nested tuple of units with indexing support -- `QuantityMatrix`, a quantity-like wrapper around one array plus a matching +- `QMatrix`, a quantity-like wrapper around one array plus a matching static `UnitsMatrix` The numeric payload is a single JAX array of shape ``(..., *shape)`` where the @@ -23,9 +23,20 @@ multiplicative scale factors. """ -__all__ = ("QuantityMatrix", "UnitsMatrix", "cdict_units") +__all__ = ( + "QMatrix", + "UnitsMatrix", + "cdict_units", + "det", + "det_p", + "inv", + "inv_p", +) +import functools as ft +import operator + from jaxtyping import Array, Shaped from typing import Any, NoReturn, TypeAlias, TypeVar, cast, final @@ -38,12 +49,13 @@ import plum import quax from jax import lax +from jax.extend import core as jexc +from jax.interpreters import ad as jax_ad, batching as jax_batching, mlir as jax_mlir import unxt as u from unxt.quantity import AllowValue -from .custom_types import CDict - +CDict: TypeAlias = dict[str, Any] _DMLS = u.unit("") @@ -64,10 +76,7 @@ def cdict_units(p: CDict, keys: tuple[str, ...], /) -> PackedUnitOutput: Non-quantity entries yield `None`, so the output tuple can be used for heterogeneous dictionaries containing both quantity and non-quantity data. - Examples - -------- >>> import unxt as u - >>> d = {'x': u.Q(1.0, 'm'), 'y': 2.0, 'z': u.Q(3.0, 'kg')} >>> cdict_units(d, ('x', 'y', 'z')) (Unit("m"), None, Unit("kg")) @@ -152,7 +161,7 @@ def _build_object_array(iterable: Any, /) -> np.ndarray: # noqa: C901 @final class UnitsMatrix: - """Immutable, hashable unit structure for `QuantityMatrix`. + """Immutable, hashable unit structure for `QMatrix`. `UnitsMatrix` wraps a numpy object array (``dtype=object``) of `~unxt.AbstractUnit` elements. Only 1-D and 2-D structures are accepted. @@ -243,6 +252,54 @@ def T(self) -> "UnitsMatrix": """ return UnitsMatrix(self._units.T) + def inverse(self) -> "UnitsMatrix": + r"""Inverse unit structure — each unit raised to the power -1. + + For a 1-D (diagonal) ``UnitsMatrix`` the inversion is done + entry-by-entry in *O(n)*, providing a speedup over the general 2-D + case. For a 2-D ``UnitsMatrix`` with a uniform unit (all entries + equal) the reciprocal is computed once and broadcast in *O(1)*; + mixed-unit 2-D structures fall back to an element-wise *O(nm)* loop. + + Examples + -------- + >>> from coordinax.internal import UnitsMatrix + + 1-D (diagonal) case — element-wise reciprocal: + + >>> UnitsMatrix(("m2", "s2")).inverse() + UnitsMatrix("(1 / m2, 1 / s2)") + + 2-D uniform-unit case: + + >>> UnitsMatrix((("m2", "m2"), ("m2", "m2"))).inverse() + UnitsMatrix("((1 / m2, 1 / m2), (1 / m2, 1 / m2))") + + 2-D mixed-unit case: + + >>> UnitsMatrix((("m2", "s2"), ("s2", "rad2"))).inverse() + UnitsMatrix("((1 / m2, 1 / s2), (1 / s2, 1 / rad2))") + + """ + inv_data = np.empty(self._units.shape, dtype=object) + if self._units.ndim == 1: + # Diagonal speedup: 1-D represents a diagonal metric's units. + for i in range(self._units.shape[0]): + inv_data[i] = self._units[i] ** (-1) + else: + # 2-D: fast path when all entries share the same unit. + flat = self._units.ravel() + first = flat[0] + if all(u == first for u in flat[1:]): + inv_unit = first ** (-1) + inv_data[:] = inv_unit + else: + n, m = self._units.shape + for i in range(n): + for j in range(m): + inv_data[i, j] = self._units[i, j] ** (-1) + return UnitsMatrix(inv_data) + # ── Serialization ───────────────────────────────────────────────── def to_tuple(self) -> UnitTree: @@ -351,12 +408,10 @@ def __getitem__(self, index: Any, /) -> Any: def unit(tuple_of_units: tuple[Any, ...], /) -> UnitsMatrix: """Convert a nested tuple of units into a ``UnitsMatrix``. - This allows users to specify units in a convenient nested tuple format - when constructing ``QuantityMatrix`` instances, and have them automatically + This allows users to specify units in a convenient nested tuple format when + constructing ``QMatrix`` instances, and have them automatically converted to the appropriate ``UnitsMatrix``. - Examples - -------- >>> import unxt as u 1D case: @@ -377,8 +432,6 @@ def unit(tuple_of_units: tuple[Any, ...], /) -> UnitsMatrix: def unit(arr: np.ndarray, /) -> UnitsMatrix: """Convert a numpy object array of units into a ``UnitsMatrix``. - Examples - -------- >>> import numpy as np >>> import unxt as u >>> from coordinax.internal import UnitsMatrix @@ -400,8 +453,6 @@ def unit(obj: UnitsMatrix, /) -> UnitsMatrix: def unit_of(obj: UnitsMatrix, /) -> UnitsMatrix: """Identity conversion for UnitsMatrix to itself. - Examples - -------- >>> import unxt as u >>> unit = u.unit(("m", "s", "kg")) >>> u.unit_of(unit) is unit @@ -412,13 +463,13 @@ def unit_of(obj: UnitsMatrix, /) -> UnitsMatrix: ############################################################################## -# QuantityMatrix +# QMatrix -class QuantityMatrix(u.AbstractQuantity): +class QMatrix(u.AbstractQuantity): """Quantity container whose elements may each carry different units. - `QuantityMatrix` stores one numeric array together with a static + `QMatrix` stores one numeric array together with a static `UnitsMatrix` describing the unit of each logical element. The shape of the unit structure determines whether the object behaves as a heterogeneous vector or matrix. @@ -441,39 +492,39 @@ class QuantityMatrix(u.AbstractQuantity): -------- >>> import jax.numpy as jnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix 1D case (vector): - >>> qv = QuantityMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "s", "kg")) + >>> qv = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "s", "kg")) >>> qv.value Array([1., 2., 3.], dtype=float64) >>> qv.unit.shape (3,) >>> 2 * qv - QuantityMatrix([2., 4., 6.], '(m, s, kg)') + QMatrix([2., 4., 6.], '(m, s, kg)') - >>> qv2 = QuantityMatrix(jnp.array([0.1, 200.0, 300.0]), unit=("km", "ms", "g")) + >>> qv2 = QMatrix(jnp.array([0.1, 200.0, 300.0]), unit=("km", "ms", "g")) >>> qv + qv2 - QuantityMatrix([101. , 2.2, 3.3], '(m, s, kg)') + QMatrix([101. , 2.2, 3.3], '(m, s, kg)') 2D case (matrix): - >>> qm = QuantityMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) + >>> qm = QMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) >>> qm.value.shape (2, 2) >>> qm.unit.shape (2, 2) >>> 2 * qm - QuantityMatrix([[2., 2.], + QMatrix([[2., 2.], [2., 2.]], '((m, s), (kg, rad))') - >>> qm2 = QuantityMatrix(jnp.array([[0.1, 200.0], [300.0, 0.5]]), + >>> qm2 = QMatrix(jnp.array([[0.1, 200.0], [300.0, 0.5]]), ... unit=(("km", "ms"), ("g", "deg"))) >>> qm + qm2 - QuantityMatrix([[101. , 1.2 ], + QMatrix([[101. , 1.2 ], [ 1.3 , 1.00872665]], '((m, s), (kg, rad))') Indexing: @@ -481,7 +532,7 @@ class QuantityMatrix(u.AbstractQuantity): >>> qv[0] Q(1., 'm') >>> qm[0] - QuantityMatrix([1., 1.], '(m, s)') + QMatrix([1., 1.], '(m, s)') >>> qm[1, 0] Q(1., 'kg') @@ -501,10 +552,8 @@ def shape(self) -> tuple[int, ...]: return self.value.shape @classmethod - def from_cdict( - cls, v: CDict, /, keys: tuple[str, ...] | None = None - ) -> "QuantityMatrix": - """Pack a component dictionary into a 1-D ``QuantityMatrix``. + def from_cdict(cls, v: CDict, /, keys: tuple[str, ...] | None = None) -> "QMatrix": + """Pack a component dictionary into a 1-D ``QMatrix``. Each value in *v* is stripped to its numeric value and stacked into a single JAX array. Values that carry units (``unxt.Quantity``) retain @@ -514,12 +563,12 @@ def from_cdict( Examples -------- >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix From a dictionary of quantities: >>> v = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "s"), "z": u.Q(3.0, "kg")} - >>> qv = QuantityMatrix.from_cdict(v) + >>> qv = QMatrix.from_cdict(v) >>> qv.unit.to_string() '(m, s, kg)' >>> qv.value @@ -527,7 +576,7 @@ def from_cdict( Selecting and reordering a subset of keys: - >>> qv2 = QuantityMatrix.from_cdict(v, keys=("z", "x")) + >>> qv2 = QMatrix.from_cdict(v, keys=("z", "x")) >>> qv2.unit.to_string() '(kg, m)' >>> qv2.value @@ -537,7 +586,7 @@ def from_cdict( >>> import jax.numpy as jnp >>> v2 = {"a": jnp.array(4.0), "b": u.Q(5.0, "m")} - >>> qv3 = QuantityMatrix.from_cdict(v2) + >>> qv3 = QMatrix.from_cdict(v2) >>> qv3.unit.to_string() '(, m)' @@ -548,33 +597,33 @@ def from_cdict( svs = jnp.stack([u.ustrip(AllowValue, unt, x) for x, unt in strict_zip(vs, us)]) return cls(svs, unit=UnitsMatrix(us)) - def __getitem__(self, index: Any, /) -> "u.Q | QuantityMatrix": # ty: ignore[invalid-method-override] - """Index into the QuantityMatrix to retrieve a specific element. + def __getitem__(self, index: Any, /) -> "u.Q | QMatrix": # ty: ignore[invalid-method-override] + """Index into the QMatrix to retrieve a specific element. Indexing a logical dimension returns a ``Quantity`` when the result is - a scalar unit, or a ``QuantityMatrix`` when the result still has + a scalar unit, or a ``QMatrix`` when the result still has structure. Examples -------- >>> import jax.numpy as jnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix **1-D vector** — indexing a single element returns a ``Quantity``: - >>> qv = QuantityMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "s", "kg")) + >>> qv = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "s", "kg")) >>> qv[0] Q(1., 'm') >>> qv[2] Q(3., 'kg') - **2-D matrix** — indexing a row returns a ``QuantityMatrix``: + **2-D matrix** — indexing a row returns a ``QMatrix``: - >>> qm = QuantityMatrix(jnp.ones((2, 3)), + >>> qm = QMatrix(jnp.ones((2, 3)), ... unit=(("m", "s", "kg"), ("rad", "deg", "m"))) >>> qm[0] - QuantityMatrix([1., 1., 1.], '(m, s, kg)') + QMatrix([1., 1., 1.], '(m, s, kg)') Indexing a specific element returns a ``Quantity``: @@ -585,7 +634,7 @@ def __getitem__(self, index: Any, /) -> "u.Q | QuantityMatrix": # ty: ignore[in value_item = self.value[index] unit_item = self.unit[index] if isinstance(unit_item, UnitsMatrix): - return QuantityMatrix(value=value_item, unit=unit_item) + return QMatrix(value=value_item, unit=unit_item) return u.Q(value_item, unit_item) # ── Quax API ───────────────────────────────────────────────────── @@ -594,33 +643,33 @@ def aval(self) -> jax.core.ShapedArray: return jax.core.ShapedArray(self.value.shape, self.value.dtype) def materialise(self) -> NoReturn: - msg = "Refusing to materialise `QuantityMatrix`." + msg = "Refusing to materialise `QMatrix`." raise RuntimeError(msg) - def diag(self) -> "QuantityMatrix": - """Return a 1-D ``QuantityMatrix`` containing the diagonal of this matrix. + def diag(self) -> "QMatrix": + """Return a 1-D ``QMatrix`` containing the diagonal of this matrix. Unlike ``qnp.diag``, this method operates directly on the static ``unit`` structure and the raw value array, so it works correctly under ``jax.jit`` and with heterogeneous-unit matrices. - Only supported for 2-D ``QuantityMatrix`` objects. + Only supported for 2-D ``QMatrix`` objects. Returns ------- - QuantityMatrix - 1-D ``QuantityMatrix`` of length ``min(n_rows, n_cols)`` whose + QMatrix + 1-D ``QMatrix`` of length ``min(n_rows, n_cols)`` whose ``unit[i]`` is ``self.unit[i, i]`` and whose ``value[..., i]`` is ``self.value[..., i, i]``. Examples -------- >>> import jax.numpy as jnp - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix Uniform units: - >>> A = QuantityMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), + >>> A = QMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), ... unit=(("m", "m", "m"), ("m", "m", "m"), ("m", "m", "m"))) >>> d = A.diag() >>> d.unit.shape @@ -630,7 +679,7 @@ def diag(self) -> "QuantityMatrix": Heterogeneous units — works under jit: - >>> B = QuantityMatrix(jnp.diag(jnp.array([1.0, 2.0, 3.0])), + >>> B = QMatrix(jnp.diag(jnp.array([1.0, 2.0, 3.0])), ... unit=(("m", "s", "kg"), ... ("m", "s", "kg"), ... ("m", "s", "kg"))) @@ -642,27 +691,27 @@ def diag(self) -> "QuantityMatrix": """ if self.ndim != 2: - msg = f"QuantityMatrix.diag() requires a 2D matrix, got ndim={self.ndim}" + msg = f"QMatrix.diag() requires a 2D matrix, got ndim={self.ndim}" raise ValueError(msg) n = min(self.shape[-2], self.shape[-1]) diag_value = jnp.stack([self.value[..., i, i] for i in range(n)], axis=-1) diag_unit = UnitsMatrix(self.unit._units.diagonal()) - return QuantityMatrix(value=diag_value, unit=diag_unit) + return QMatrix(value=diag_value, unit=diag_unit) @property - def T(self) -> "QuantityMatrix": - """Transpose a 2-D ``QuantityMatrix`` (swap rows/columns and units). + def T(self) -> "QMatrix": + """Transpose a 2-D ``QMatrix`` (swap rows/columns and units). - Returns a new ``QuantityMatrix`` whose value array and unit structure + Returns a new ``QMatrix`` whose value array and unit structure are both transposed. Only 2-D matrices are supported. Examples -------- >>> import jax.numpy as jnp >>> import quaxed.numpy as qnp - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix - >>> a = QuantityMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), + >>> a = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), ... unit=(("m", "s"), ("kg", "rad"))) >>> aT = a.T >>> aT.value @@ -682,9 +731,9 @@ def T(self) -> "QuantityMatrix": """ if self.ndim != 2: - msg = f"QuantityMatrix.T requires a 2-D matrix, got ndim={self.ndim}" + msg = f"QMatrix.T requires a 2-D matrix, got ndim={self.ndim}" raise ValueError(msg) - return QuantityMatrix(value=jnp.swapaxes(self.value, -2, -1), unit=self.unit.T) + return QMatrix(value=jnp.swapaxes(self.value, -2, -1), unit=self.unit.T) def _convert_value_vector( @@ -734,48 +783,44 @@ def _convert_value_matrix( ) -@plum.conversion_method(type_from=QuantityMatrix, type_to=u.Quantity) -def quantitymatrix_to_quantity(x: QuantityMatrix, /) -> u.Quantity: - """Convert a ``QuantityMatrix`` to a regular ``Quantity``. +@plum.conversion_method(type_from=QMatrix, type_to=u.Q) +def QMatrix_to_quantity(x: QMatrix, /) -> u.Q: + """Convert a ``QMatrix`` to a regular ``Quantity``. - Conversion is only valid when all elements of ``x`` share the same unit. - If units are heterogeneous, this conversion is ambiguous and raises + Conversion is only valid when all elements of ``x`` share the same unit. If + units are heterogeneous, this conversion is ambiguous and raises ``ValueError``. - Examples - -------- >>> import plum >>> import jax.numpy as jnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix Uniform units convert to a plain quantity: - >>> qmat = QuantityMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "m", "m")) - >>> plum.convert(qmat, u.Quantity) + >>> qmat = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "m", "m")) + >>> plum.convert(qmat, u.Q) Q([1., 2., 3.], 'm') Mixed units are rejected: - >>> bad = QuantityMatrix(jnp.array([1.0, 2.0]), unit=("m", "s")) - >>> plum.convert(bad, u.Quantity) + >>> bad = QMatrix(jnp.array([1.0, 2.0]), unit=("m", "s")) + >>> plum.convert(bad, u.Q) Traceback (most recent call last): ... - ValueError: Cannot convert QuantityMatrix to Quantity unless all units are + ValueError: Cannot convert QMatrix to Quantity unless all units are identical. """ units = jtu.tree_leaves(x.unit.to_tuple()) if not units: - msg = "Cannot convert QuantityMatrix with no unit entries." + msg = "Cannot convert QMatrix with no unit entries." raise ValueError(msg) first = units[0] if any(unit != first for unit in units[1:]): - msg = ( - "Cannot convert QuantityMatrix to Quantity unless all units are identical." - ) + msg = "Cannot convert QMatrix to Quantity unless all units are identical." raise ValueError(msg) return u.Q(x.value, first) @@ -803,22 +848,19 @@ def _convert_value( @plum.dispatch -def uconvert(to_units: UnitsMatrix, x: QuantityMatrix, /) -> QuantityMatrix: - """Convert a ``QuantityMatrix`` to different (but compatible) units. +def uconvert(to_units: UnitsMatrix, x: QMatrix, /) -> QMatrix: + """Convert a ``QMatrix`` to different (but compatible) units. - Unlike the generic astropy ``StructuredUnit.to()`` path, this dispatch - uses ``_convert_value`` directly so that the regular 2D JAX array in - ``x.value`` is converted element-by-element without requiring a numpy - structured array. + Unlike the generic astropy ``StructuredUnit.to()`` path, this dispatch uses + ``_convert_value`` directly so that the regular 2D JAX array in ``x.value`` + is converted element-by-element without requiring a numpy structured array. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix >>> x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) - >>> q = QuantityMatrix(x, (("m", "rad"), ("m", "rad"))) + >>> q = QMatrix(x, (("m", "rad"), ("m", "rad"))) >>> target = u.unit((("km", "deg"), ("km", "deg"))) >>> q.uconvert(target).unit.to_string() '((km, deg), (km, deg))' @@ -827,29 +869,27 @@ def uconvert(to_units: UnitsMatrix, x: QuantityMatrix, /) -> QuantityMatrix: if x.unit == to_units: return x value = _convert_value(x.value, x.unit, to_units) - return QuantityMatrix(value=value, unit=to_units) + return QMatrix(value=value, unit=to_units) @quax.register(lax.add_p) -def add_qm_qm(x: QuantityMatrix, y: QuantityMatrix, /) -> QuantityMatrix: - """Element-wise addition of two `QuantityMatrix` objects. +def add_qm_qm(x: QMatrix, y: QMatrix, /) -> QMatrix: + """Element-wise addition of two `QMatrix` objects. - The result adopts the units of *x*. Each element is converted - from ``y.unit`` → ``x.unit`` before the numeric add. + The result adopts the units of *x*. Each element is converted from + ``y.unit`` → ``x.unit`` before the numeric add. Works for both 1D (vector) and 2D (matrix) cases. - Examples - -------- >>> import jax.numpy as jnp >>> import quaxed.numpy as qnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix 2D case: - >>> a = QuantityMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) - >>> b = QuantityMatrix(jnp.ones((2, 2)), unit=(("km", "ms"), ("g", "deg"))) + >>> a = QMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) + >>> b = QMatrix(jnp.ones((2, 2)), unit=(("km", "ms"), ("g", "deg"))) >>> result = qnp.add(a, b) >>> result.unit.to_string() @@ -861,8 +901,8 @@ def add_qm_qm(x: QuantityMatrix, y: QuantityMatrix, /) -> QuantityMatrix: 1D case: - >>> a1d = QuantityMatrix(jnp.ones(3), unit=("m", "s", "kg")) - >>> b1d = QuantityMatrix(jnp.ones(3), unit=("km", "ms", "g")) + >>> a1d = QMatrix(jnp.ones(3), unit=("m", "s", "kg")) + >>> b1d = QMatrix(jnp.ones(3), unit=("km", "ms", "g")) >>> result1d = qnp.add(a1d, b1d) >>> result1d.unit.to_string() @@ -873,40 +913,38 @@ def add_qm_qm(x: QuantityMatrix, y: QuantityMatrix, /) -> QuantityMatrix: """ y_converted = _convert_value(y.value, y.unit, x.unit) - return QuantityMatrix(value=lax.add(x.value, y_converted), unit=x.unit) + return QMatrix(value=lax.add(x.value, y_converted), unit=x.unit) @quax.register(lax.dot_general_p) def dot_general_qm_qm( - lhs: QuantityMatrix, - rhs: QuantityMatrix, + lhs: QMatrix, + rhs: QMatrix, /, *, dimension_numbers: lax.DotDimensionNumbers, precision: Any = None, preferred_element_type: Any = None, **kw: Any, -) -> QuantityMatrix | u.Quantity: - """Dot product / matrix multiply two `QuantityMatrix` objects. +) -> QMatrix | u.Q: + """Dot product / matrix multiply two `QMatrix` objects. Delegates to specialized implementations based on the dimensionality: - 1D @ 1D → scalar (vector dot product) - 2D @ 2D → 2D (matrix-matrix multiply) For the standard matmul contraction: contracting_dims = ((-1,), (-2,)), - with no batch dims (batch is handled by leading dims in QuantityMatrix). + with no batch dims (batch is handled by leading dims in QMatrix). - Examples - -------- >>> import jax.numpy as jnp >>> import quaxed.numpy as qnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix 1D @ 1D (dot product): - >>> v1 = QuantityMatrix(jnp.array([1.0, 2.0]), unit=("m", "km")) - >>> v2 = QuantityMatrix(jnp.array([3.0, 4.0]), unit=("s", "s")) + >>> v1 = QMatrix(jnp.array([1.0, 2.0]), unit=("m", "km")) + >>> v2 = QMatrix(jnp.array([3.0, 4.0]), unit=("s", "s")) >>> result = qnp.dot(v1, v2) >>> result.value Array(8003., dtype=float64) @@ -915,9 +953,9 @@ def dot_general_qm_qm( 2D @ 2D (matrix multiply): - >>> a = QuantityMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), + >>> a = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), ... unit=(("m", "km"), ("m", "km"))) - >>> b = QuantityMatrix(jnp.array([[1.0, 0.0], [0.0, 1.0]]), + >>> b = QMatrix(jnp.array([[1.0, 0.0], [0.0, 1.0]]), ... unit=(("s", "s"), ("s", "s"))) >>> c = qnp.matmul(a, b) @@ -966,19 +1004,187 @@ def dot_general_qm_qm( raise NotImplementedError(msg) +@quax.register(lax.dot_general_p) +def dot_general_qm_arr( + lhs: QMatrix, + rhs: jax.Array, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> "QMatrix | u.Q": + """Dot product of a :class:`QMatrix` with a plain JAX array. + + The plain array is treated as dimensionless. Delegates to + :func:`dot_general_qm_qm` after wrapping ``rhs`` in a dimensionless + :class:`QMatrix`. + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> from coordinax.internal import QMatrix, UnitsMatrix + + 2D metric x 1D plain vector: + + >>> g = QMatrix( + ... jnp.array([[2.0, 0.0], [0.0, 3.0]]), + ... unit=UnitsMatrix((("m2", "m2"), ("m2", "m2"))), + ... ) + >>> v = jnp.array([1.0, 1.0]) + >>> w = qnp.matmul(g, v) + >>> w.unit.to_string() + '(m2, m2)' + >>> w.value + Array([2., 3.], dtype=float64) + + """ + if rhs.ndim == 1: + n = rhs.shape[0] + rhs_qm = QMatrix(rhs, unit=UnitsMatrix(tuple(_DMLS for _ in range(n)))) + else: + nr, nc = rhs.shape[-2], rhs.shape[-1] + rhs_qm = QMatrix( + rhs, + unit=UnitsMatrix(tuple(tuple(_DMLS for _ in range(nc)) for _ in range(nr))), + ) + return dot_general_qm_qm( + lhs, + rhs_qm, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + + +@quax.register(lax.dot_general_p) +def dot_general_qm_qty( + lhs: QMatrix, + rhs: u.AbstractQuantity, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> "QMatrix | u.Q": + """Dot product of a :class:`QMatrix` with a :class:`~unxt.AbstractQuantity`. + + The Quantity carries a single scalar unit that applies uniformly to all + elements. The ``rhs`` is wrapped as a uniform-unit + :class:`QMatrix` and delegated to :func:`dot_general_qm_qm`. + + Note that :class:`QMatrix` is itself a subtype of + :class:`~unxt.AbstractQuantity`, so :func:`dot_general_qm_qm` takes + precedence when both sides are :class:`QMatrix`. + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import quaxed.numpy as qnp + >>> from coordinax.internal import QMatrix, UnitsMatrix + + 2D metric with units @ uniform-unit Quantity vector: + + >>> g = QMatrix( + ... jnp.array([[2.0, 0.0], [0.0, 3.0]]), + ... unit=UnitsMatrix((("m2", "m2"), ("m2", "m2"))), + ... ) + >>> v = u.Q(jnp.array([1.0, 1.0]), "m") + >>> w = qnp.matmul(g, v) + >>> w.unit.to_string() + '(m3, m3)' + >>> w.value + Array([2., 3.], dtype=float64) + + """ + rhs_unit = u.unit_of(rhs) + rhs_val = u.ustrip(AllowValue, rhs_unit, rhs) + if rhs_val.ndim == 1: + n = rhs_val.shape[0] + rhs_qm = QMatrix(rhs_val, unit=UnitsMatrix(tuple(rhs_unit for _ in range(n)))) + else: + nr, nc = rhs_val.shape[-2], rhs_val.shape[-1] + rhs_qm = QMatrix( + rhs_val, + unit=UnitsMatrix( + tuple(tuple(rhs_unit for _ in range(nc)) for _ in range(nr)) + ), + ) + return dot_general_qm_qm( + lhs, + rhs_qm, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + + +@quax.register(lax.dot_general_p) +def dot_general_arr_qm( + lhs: jax.Array, + rhs: QMatrix, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> "QMatrix | u.Q": + """Dot product of a plain JAX array with a :class:`QMatrix`. + + The plain array is treated as dimensionless. Delegates to + :func:`dot_general_qm_qm` after wrapping ``lhs`` in a dimensionless + :class:`QMatrix`. + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> from coordinax.internal import QMatrix + + Dimensionless identity @ QMatrix vector: + + >>> A = jnp.eye(2, dtype=jnp.float64) + >>> v = QMatrix(jnp.array([2.0, 3.0]), unit=("m / s", "m / s")) + >>> w = qnp.matmul(A, v) + >>> w.unit.to_string() + '(m / s, m / s)' + >>> w.value + Array([2., 3.], dtype=float64) + + """ + if lhs.ndim == 1: + n = lhs.shape[0] + lhs_qm = QMatrix(lhs, unit=UnitsMatrix(tuple(_DMLS for _ in range(n)))) + else: + nr, nc = lhs.shape[-2], lhs.shape[-1] + lhs_qm = QMatrix( + lhs, + unit=UnitsMatrix(tuple(tuple(_DMLS for _ in range(nc)) for _ in range(nr))), + ) + return dot_general_qm_qm( + lhs_qm, + rhs, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + + vec_uconvert_value = np.vectorize(u.uconvert_value) def _dot_general_1d_1d( - lhs: QuantityMatrix, - rhs: QuantityMatrix, + lhs: QMatrix, + rhs: QMatrix, /, *, dimension_numbers: lax.DotDimensionNumbers, precision: Any = None, preferred_element_type: Any = None, **kw: Any, -) -> u.Quantity: +) -> u.Q: """Vector dot product: (N,) @ (N,) → scalar. Result = Σ_i lhs[i] * rhs[i] @@ -999,51 +1205,49 @@ def _dot_general_1d_1d( # Compute dot product with rescaling result_value = jnp.sum(scales * lhs.value * rhs.value, axis=-1) - return u.Quantity(result_value, ref_unit) + return u.Q(result_value, ref_unit) def _dot_general_2d_1d( - lhs: QuantityMatrix, - rhs: QuantityMatrix, + lhs: QMatrix, + rhs: QMatrix, /, *, dimension_numbers: lax.DotDimensionNumbers, precision: Any = None, preferred_element_type: Any = None, **kw: Any, -) -> QuantityMatrix: +) -> QMatrix: """Matrix-vector multiply: (N, K) @ (K,) → (N,). For ``w = A @ v`` where ``A`` is ``(N, K)`` and ``v`` is ``(K,)``: ``w[i] = Σ_j A[i, j] * v[j]`` - Each product ``A[i,j] * v[j]`` has unit ``A.unit[i][j] * v.unit[j]``. - All ``K`` terms in the sum for output row ``i`` must be unit-compatible. - We convert every term to the unit of the *first* term (``j = 0``) for - each output row ``i``: ``ref[i] = A.unit[i][0] * v.unit[0]``. + Each product ``A[i,j] * v[j]`` has unit ``A.unit[i][j] * v.unit[j]``. All + ``K`` terms in the sum for output row ``i`` must be unit-compatible. We + convert every term to the unit of the *first* term (``j = 0``) for each + output row ``i``: ``ref[i] = A.unit[i][0] * v.unit[0]``. - Examples - -------- >>> import jax.numpy as jnp >>> import quaxed.numpy as qnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix Identity matrix times a vector: - >>> A = QuantityMatrix(jnp.eye(3, dtype=jnp.float64), + >>> A = QMatrix(jnp.eye(3, dtype=jnp.float64), ... unit=(("", "", ""), ("", "", ""), ("", "", ""))) - >>> v = QuantityMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "m", "m")) + >>> v = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "m", "m")) >>> w = qnp.matmul(A, v) >>> w.value Array([1., 2., 3.], dtype=float64) Mixed units on contraction axis (km column converted to m): - >>> A2 = QuantityMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), + >>> A2 = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), ... unit=(("m", "km"), ("m", "km"))) - >>> v2 = QuantityMatrix(jnp.array([1.0, 1.0]), unit=("s", "s")) + >>> v2 = QMatrix(jnp.array([1.0, 1.0]), unit=("s", "s")) >>> w2 = qnp.matmul(A2, v2) >>> w2.value Array([2001., 4003.], dtype=float64) @@ -1070,22 +1274,22 @@ def _dot_general_2d_1d( # w[..., i] = Σ_j scale[i, j] * A[..., i, j] * v[..., j] accum = jnp.einsum("ij,...ij,...j->...i", scale_2d, lhs.value, rhs.value) - return QuantityMatrix(value=accum, unit=out_unit) + return QMatrix(value=accum, unit=out_unit) vec_uconvert_value = np.vectorize(u.uconvert_value) def _dot_general_2d_2d( - lhs: QuantityMatrix, - rhs: QuantityMatrix, + lhs: QMatrix, + rhs: QMatrix, /, *, dimension_numbers: lax.DotDimensionNumbers, precision: Any = None, preferred_element_type: Any = None, **kw: Any, -) -> QuantityMatrix: +) -> QMatrix: """Matrix multiply: (N, K) @ (K, M) → (N, M). For ``C = A @ B`` where ``A`` is ``(N, K)`` and ``B`` is ``(K, M)``: @@ -1152,29 +1356,27 @@ def _dot_general_2d_2d( axis=-2, ) - return QuantityMatrix(value=accum, unit=out_unit) + return QMatrix(value=accum, unit=out_unit) @quax.register(lax.sub_p) -def sub_qm_qm(x: QuantityMatrix, y: QuantityMatrix, /) -> QuantityMatrix: - """Element-wise subtraction of two `QuantityMatrix` objects. +def sub_qm_qm(x: QMatrix, y: QMatrix, /) -> QMatrix: + """Element-wise subtraction of two `QMatrix` objects. - The result adopts the units of *x*. Each element is converted - from ``y.unit`` → ``x.unit`` before the numeric subtract. + The result adopts the units of *x*. Each element is converted from + ``y.unit`` → ``x.unit`` before the numeric subtract. Works for both 1D (vector) and 2D (matrix) cases. - Examples - -------- >>> import jax.numpy as jnp >>> import quaxed.numpy as qnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix 2D case: - >>> a = QuantityMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) - >>> b = QuantityMatrix( + >>> a = QMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) + >>> b = QMatrix( ... value=jnp.ones((2, 2)), ... unit=(("km", u.unit("ms")), (u.unit("g"), u.unit("deg")))) @@ -1188,9 +1390,9 @@ def sub_qm_qm(x: QuantityMatrix, y: QuantityMatrix, /) -> QuantityMatrix: 1D case: - >>> a1d = QuantityMatrix(value=jnp.ones(3), + >>> a1d = QMatrix(value=jnp.ones(3), ... unit=("m", "s", "kg")) - >>> b1d = QuantityMatrix(value=jnp.ones(3), + >>> b1d = QMatrix(value=jnp.ones(3), ... unit=("km", u.unit("ms"), u.unit("g"))) >>> result1d = qnp.subtract(a1d, b1d) @@ -1202,29 +1404,25 @@ def sub_qm_qm(x: QuantityMatrix, y: QuantityMatrix, /) -> QuantityMatrix: """ y_converted = _convert_value(y.value, y.unit, x.unit) - return QuantityMatrix(value=lax.sub(x.value, y_converted), unit=x.unit) + return QMatrix(value=lax.sub(x.value, y_converted), unit=x.unit) @quax.register(lax.transpose_p) -def transpose_qm( - x: QuantityMatrix, /, *, permutation: tuple[int, ...] -) -> QuantityMatrix: - """Transpose a ``QuantityMatrix``, swapping only the last two (matrix) axes. +def transpose_qm(x: QMatrix, /, *, permutation: tuple[int, ...]) -> QMatrix: + """Transpose a ``QMatrix``, swapping only the last two (matrix) axes. Leading batch dimensions must be preserved unchanged. Only permutations that swap the last two axes while keeping all batch axes in place are supported, because the unit structure is purely 2-D and cannot represent arbitrary axis re-orderings. - Examples - -------- >>> import jax.numpy as jnp >>> import quaxed.numpy as qnp - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix 2-D (no batch): - >>> a = QuantityMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), + >>> a = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), ... unit=(("m", "s"), ("kg", "rad"))) >>> aT = qnp.matrix_transpose(a) >>> aT.value @@ -1236,7 +1434,7 @@ def transpose_qm( Batched ``(B, N, M)`` — batch axis is preserved: >>> import jax - >>> b = QuantityMatrix(jnp.ones((3, 2, 2)), + >>> b = QMatrix(jnp.ones((3, 2, 2)), ... unit=(("m", "s"), ("kg", "rad"))) >>> bT = qnp.matrix_transpose(b) >>> bT.shape @@ -1256,7 +1454,7 @@ def transpose_qm( ) raise NotImplementedError(msg) transposed_value = lax.transpose(x.value, permutation) - return QuantityMatrix(value=transposed_value, unit=x.unit.T) + return QMatrix(value=transposed_value, unit=x.unit.T) def _jit_fallback_uniform_unit(units: UnitsMatrix, out_size: int) -> UnitsMatrix: @@ -1269,9 +1467,9 @@ def _jit_fallback_uniform_unit(units: UnitsMatrix, out_size: int) -> UnitsMatrix first = all_units[0] if any(u_i != first for u_i in all_units[1:]): msg = ( - "QuantityMatrix gather (e.g. jnp.diag) under jit requires all units " + "QMatrix gather (e.g. jnp.diag) under jit requires all units " "to be equal when indices cannot be concretized. " - "Call eagerly (outside jit) for heterogeneous-unit QuantityMatrix." + "Call eagerly (outside jit) for heterogeneous-unit QMatrix." ) raise ValueError(msg) return UnitsMatrix(np.full((out_size,), first, dtype=object)) @@ -1279,7 +1477,7 @@ def _jit_fallback_uniform_unit(units: UnitsMatrix, out_size: int) -> UnitsMatrix @quax.register(lax.gather_p) def gather_qm( - x: QuantityMatrix, + x: QMatrix, start_indices: jax.Array, /, *, @@ -1290,17 +1488,17 @@ def gather_qm( fill_value: Any = None, unique_indices: bool = False, **kwargs: Any, -) -> QuantityMatrix: - """Handle element-selection gathers (e.g. ``jnp.diag``) for ``QuantityMatrix``. +) -> QMatrix: + """Handle element-selection gathers (e.g. ``jnp.diag``) for ``QMatrix``. Supports only *element-selection* gathers where every input dimension is collapsed (``offset_dims == ()`` and all ``slice_sizes == 1``). This covers ``jnp.diag``, ``jnp.diagonal``, and integer-array fancy indexing on - ``QuantityMatrix`` objects. + ``QMatrix`` objects. Unit extraction: - ``QuantityMatrix.unit`` is declared ``static=True`` and is therefore always + ``QMatrix.unit`` is declared ``static=True`` and is therefore always a concrete Python object, even inside ``jax.jit``. The *indices*, however, are traced under JIT and cannot be read concretely. Because JAX's ``jnp.diag`` uses ``platform_dependent`` internally, quax always traces @@ -1308,14 +1506,12 @@ def gather_qm( unit resolution. Consequently, all units in the input must be equal; heterogeneous-unit inputs raise ``ValueError``. - Examples - -------- >>> import jax.numpy as jnp - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix Diagonal of a 3x3 dimensionless matrix: - >>> A = QuantityMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), + >>> A = QMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), ... unit=(("", "", ""), ("", "", ""), ("", "", ""))) >>> d = A.diag() >>> d.unit.shape @@ -1358,7 +1554,7 @@ def gather_qm( ) if not is_element_selection: msg = ( - "QuantityMatrix: only element-selection gathers (all input dims " + "QMatrix: only element-selection gathers (all input dims " "collapsed, all slice_sizes == 1) are supported. " f"Got offset_dims={dimension_numbers.offset_dims}, " f"collapsed_slice_dims={dimension_numbers.collapsed_slice_dims}, " @@ -1380,14 +1576,12 @@ def gather_qm( else: # x.unit.ndim == 2 out_unit = UnitsMatrix(x.unit._units[idx_np[:, 0], idx_np[:, 1]]) - return QuantityMatrix(value=result_value, unit=out_unit) + return QMatrix(value=result_value, unit=out_unit) @quax.register(lax.reduce_sum_p) -def reduce_sum_p_qm( - operand: QuantityMatrix, /, *, axes: Any, **kwargs: Any -) -> QuantityMatrix: - """Handle ``lax.reduce_sum`` for ``QuantityMatrix``. +def reduce_sum_p_qm(operand: QMatrix, /, *, axes: Any, **kwargs: Any) -> QMatrix: + """Handle ``lax.reduce_sum`` for ``QMatrix``. ``jnp.diag`` on a square 2-D matrix uses ``platform_dependent`` which traces *both* the default (gather-based) and Mosaic implementation. The Mosaic @@ -1399,22 +1593,20 @@ def reduce_sum_p_qm( Unit reduction rule: - When reducing a 2-D ``QuantityMatrix`` along ``axes=(0,)`` (rows): the + When reducing a 2-D ``QMatrix`` along ``axes=(0,)`` (rows): the output unit for column *j* is taken from ``operand.unit[0, j]`` (the first row). All elements being summed along a column must be unit-compatible for the sum to be physically meaningful. - Analogously for ``axes=(1,)`` (column reduction), the output unit - for row *i* is ``operand.unit[i, 0]``. + Analogously for ``axes=(1,)`` (column reduction), the output unit for row + *i* is ``operand.unit[i, 0]``. - Examples - -------- >>> import jax.numpy as jnp - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix - ``QuantityMatrix.diag()`` on a 3x3 uniform-unit matrix: + ``QMatrix.diag()`` on a 3x3 uniform-unit matrix: - >>> A = QuantityMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), + >>> A = QMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), ... unit=(("m", "m", "m"), ("m", "m", "m"), ("m", "m", "m"))) >>> d = A.diag() >>> d.unit.shape @@ -1435,13 +1627,393 @@ def reduce_sum_p_qm( # Column reduction → 1-D output; unit = first column's units. out_unit = UnitsMatrix(operand.unit._units[:, 0]) else: - msg = f"reduce_sum_p_qm: unsupported axes={axes} for 2-D QuantityMatrix." + msg = f"reduce_sum_p_qm: unsupported axes={axes} for 2-D QMatrix." raise NotImplementedError(msg) else: msg = ( - f"reduce_sum_p_qm: only 2-D QuantityMatrix is supported, " - f"got ndim={operand.ndim}." + f"reduce_sum_p_qm: only 2-D QMatrix is supported, got ndim={operand.ndim}." ) raise NotImplementedError(msg) - return QuantityMatrix(value=result_value, unit=out_unit) + return QMatrix(value=result_value, unit=out_unit) + + +############################################################################## +# Custom det_p JAX primitive +# +# JAX has no built-in `det` primitive. `jnp.linalg.det` decomposes into +# arithmetic (2×2 / 3×3) or `lu_p` + log/exp (larger matrices). We define +# `det_p` here so that Quax can intercept determinant calls on `QMatrix` +# objects. +# +# Supported transforms: +# - JIT via MLIR lowering that delegates to `jnp.linalg.det` +# - Forward-mode autodiff (JVP): d(det A)(dA) = det(A) · tr(A⁻¹ dA) +# - Reverse-mode autodiff (VJP): derived automatically from JVP via +# transposition — no explicit transpose rule needed because the JVP tangent +# only uses existing primitives (linalg.solve, trace, mul) that already +# carry transpose rules. +# - Batching (vmap): move the batch axis to the front and call det_p; the MLIR +# lowering handles any (*batch, n, n) shape natively. + +det_p = jexc.Primitive("det") +det_p.multiple_results = False + + +def det(x: Array, /) -> Array: + """Compute the determinant of a square matrix via the ``det_p`` primitive. + + Delegates to ``det_p``, a custom JAX primitive that supports JIT, + forward and reverse differentiation, and batching (vmap). + + For plain arrays the result is a bare :class:`~jaxtyping.Array`. + For :class:`~coordinax.internal.QMatrix` inputs the Quax + dispatch intercepts the call (see ``_det_p_QMatrix``) and + returns a :class:`~unxt.AbstractQuantity`. + + Parameters + ---------- + x : Array, shape ``(*batch, n, n)`` + Square matrix or batch of square matrices. + + Returns + ------- + Array, shape ``(*batch,)`` + Determinant of each matrix. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.internal.quantity_matrix import det + + Plain 2x2 diagonal matrix: + + >>> det(jnp.array([[2.0, 0.0], [0.0, 3.0]])) + Array(6., dtype=float64) + + Under JIT: + + >>> import jax + >>> jax.jit(det)(jnp.array([[2.0, 0.0], [0.0, 3.0]])) + Array(6., dtype=float64) + + Gradient (via reverse-mode autodiff): + + >>> jax.grad(det)(jnp.array([[2.0, 0.0], [0.0, 3.0]])) + Array([[3., 0.], + [0., 2.]], dtype=float64) + + Batched (vmap): + + >>> A = jnp.stack([jnp.diag(jnp.array([2.0, 3.0])), + ... jnp.diag(jnp.array([4.0, 5.0]))]) + >>> jax.vmap(det)(A) + Array([ 6., 20.], dtype=float64) + + """ + return det_p.bind(x) + + +# ── 1. Primal evaluation rule (eager / concrete values) ────────────────── + + +def _det_impl(x: Array, /) -> Array: + return jnp.linalg.det(x) + + +det_p.def_impl(_det_impl) + + +# ── 2. Abstract evaluation rule (shape / dtype inference for JIT) ──────── + + +def _det_abstract_eval(x: "jax.core.ShapedArray", /) -> "jax.core.ShapedArray": + if x.ndim < 2: + msg = f"det_p requires at least 2-D input, got ndim={x.ndim}" + raise ValueError(msg) + if x.shape[-1] != x.shape[-2]: + msg = ( + f"det_p requires a square matrix " + f"(shape[-2] == shape[-1]), got shape={x.shape}" + ) + raise ValueError(msg) + # (*batch, n, n) → (*batch,) + return x.update(shape=x.shape[:-2]) + + +det_p.def_abstract_eval(_det_abstract_eval) + + +# ── 3. MLIR / XLA lowering (JIT compilation) ───────────────────────────── + +jax_mlir.register_lowering( + det_p, + jax_mlir.lower_fun(_det_impl, multiple_results=False), +) + + +# ── 4. Forward-mode differentiation (JVP) ──────────────────────────────── +# +# Jacobi's formula: d(det A)(dA) = det(A) · tr(A⁻¹ dA) +# +# We use `jnp.linalg.solve(A, dA)` to compute A⁻¹ dA without explicitly +# forming the inverse — more numerically stable. The tangent uses only +# existing JAX primitives (solve, trace, scalar mul), so JAX derives the +# reverse-mode (VJP) rule automatically via transposition; no explicit +# `ad.primitive_transposes` registration is needed. + + +def _det_jvp(primals: tuple, tangents: tuple) -> tuple: + (x,) = primals + (dx,) = tangents + primal_out = det_p.bind(x) + if type(dx) is jax_ad.Zero: + tangent_out = lax.full_like(primal_out, 0.0) + else: + # tr(A⁻¹ dA) via solve — avoids explicit matrix inversion + tangent_out = primal_out * jnp.trace(jnp.linalg.solve(x, dx)) + return primal_out, tangent_out + + +jax_ad.primitive_jvps[det_p] = _det_jvp + + +# ── 5. Batching rule (vmap) ─────────────────────────────────────────────── +# +# `jnp.linalg.det` (used in the MLIR lowering) already operates correctly +# on batched (*batch, n, n) arrays. The batching rule moves the vmap batch +# axis to the front (just before the matrix dims) and calls det_p; the +# result carries the batch axis at position 0. + + +def _det_batch(args: tuple, batch_axes: tuple) -> tuple: + (x,) = args + (ax,) = batch_axes + x = jnp.moveaxis(x, ax, 0) + return det_p.bind(x), 0 + + +jax_batching.primitive_batchers[det_p] = _det_batch + + +# ── 6. Quax dispatch for QMatrix ────────────────────────────────── + + +@quax.register(det_p) +def _det_p_QMatrix(x: QMatrix, /) -> "u.AbstractQuantity": + """Compute the determinant of a 2-D :class:`~coordinax.internal.QMatrix`. + + The numeric value is computed via ``det_p.bind(x.value)``. The unit + is the product of the main-diagonal units — valid for diagonal metrics + and any matrix where all cofactor products share the same physical + dimension (e.g. coordinate metric tensors). + + Examples + -------- + >>> import jax + >>> import jax.numpy as jnp + >>> import quax + >>> import unxt as u + >>> from coordinax.internal import QMatrix + >>> from coordinax._src.internal.quantity_matrix import det + + >>> A = QMatrix( + ... jnp.array([[2.0, 0.0], [0.0, 3.0]]), + ... unit=(("m2", "m2"), ("m2", "m2")), + ... ) + >>> quax.quaxify(det)(A) + Q(6., 'm4') + + """ + if x.ndim != 2: + msg = f"det_p QMatrix dispatch requires a 2-D unit structure, got ndim={x.ndim}" + raise ValueError(msg) + det_val = det_p.bind(x.value) + n = x.unit.shape[0] + det_unit = ft.reduce(operator.mul, (x.unit[i, i] for i in range(n))) + return u.Q(det_val, det_unit) + + +############################################################################## +# Custom inv_p JAX primitive +# +# JAX has no standalone `inv` primitive. `jnp.linalg.inv` decomposes into +# `lu_p` + `triangular_solve_p`, and Quax has no dispatch for those on +# `QMatrix` (which raises on materialise). We define `inv_p` to give +# Quax a single interception point with full unit tracking. +# +# Supported transforms: +# - JIT via MLIR lowering that delegates to `jnp.linalg.inv` +# - Forward-mode autodiff (JVP): d(A^{-1}) = -A^{-1} dA A^{-1} +# - Reverse-mode autodiff (VJP): derived automatically from JVP +# - Batching (vmap): move batch axis to front, call inv_p + +inv_p = jexc.Primitive("inv") +inv_p.multiple_results = False + + +def inv(x: Array, /) -> Array: + """Compute the matrix inverse of a square matrix via the ``inv_p`` primitive. + + Delegates to ``inv_p``, a custom JAX primitive that supports JIT, + forward and reverse differentiation, and batching (vmap). + + For plain arrays the result is a bare :class:`~jaxtyping.Array`. + For :class:`~coordinax.internal.QMatrix` inputs the Quax + dispatch intercepts the call (see ``_inv_p_QMatrix``) and + returns a :class:`~coordinax.internal.QMatrix` with + reciprocal units. + + Parameters + ---------- + x : Array, shape ``(*batch, n, n)`` + Square matrix or batch of square matrices. + + Returns + ------- + Array, shape ``(*batch, n, n)`` + Matrix inverse of each square matrix. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.internal.quantity_matrix import inv + + Plain 2x2 diagonal matrix: + + >>> inv(jnp.array([[2.0, 0.0], [0.0, 4.0]])) + Array([[0.5 , 0. ], + [0. , 0.25]], dtype=float64) + + Under JIT: + + >>> import jax + >>> jax.jit(inv)(jnp.array([[2.0, 0.0], [0.0, 4.0]])) + Array([[0.5 , 0. ], + [0. , 0.25]], dtype=float64) + + Gradient (via reverse-mode autodiff) — returns a rank-4 Jacobian: + + >>> jac = jax.jacobian(inv)(jnp.array([[2.0, 0.0], [0.0, 4.0]])) + >>> jac.shape + (2, 2, 2, 2) + + Batched (vmap): + + >>> A = jnp.stack([jnp.diag(jnp.array([2.0, 4.0])), + ... jnp.diag(jnp.array([1.0, 2.0]))]) + >>> jax.vmap(inv)(A) + Array([[[0.5 , 0. ], + [0. , 0.25]], + + [[1. , 0. ], + [0. , 0.5 ]]], dtype=float64) + + """ + return inv_p.bind(x) + + +# ── 1. Primal evaluation rule ───────────────────────────────────────────── + + +def _inv_impl(x: Array, /) -> Array: + return jnp.linalg.inv(x) + + +inv_p.def_impl(_inv_impl) + + +# ── 2. Abstract evaluation rule ─────────────────────────────────────────── + + +def _inv_abstract_eval(x: "jax.core.ShapedArray", /) -> "jax.core.ShapedArray": + if x.ndim < 2: + msg = f"inv_p requires at least 2-D input, got ndim={x.ndim}" + raise ValueError(msg) + if x.shape[-1] != x.shape[-2]: + msg = ( + f"inv_p requires a square matrix " + f"(shape[-2] == shape[-1]), got shape={x.shape}" + ) + raise ValueError(msg) + return x.update(shape=x.shape) # same shape as input + + +inv_p.def_abstract_eval(_inv_abstract_eval) + + +# ── 3. MLIR / XLA lowering ──────────────────────────────────────────────── + +jax_mlir.register_lowering( + inv_p, + jax_mlir.lower_fun(_inv_impl, multiple_results=False), +) + + +# ── 4. Forward-mode differentiation (JVP) ──────────────────────────────── +# +# d(A^{-1})(dA) = -A^{-1} dA A^{-1} + + +def _inv_jvp(primals: tuple, tangents: tuple) -> tuple[Array, Array]: + (x,) = primals + (dx,) = tangents + primal_out = inv_p.bind(x) + if type(dx) is jax_ad.Zero: + tangent_out = jax_ad.Zero.from_primal_value(primal_out) # ty: ignore[unresolved-attribute] + else: + tangent_out = -primal_out @ dx @ primal_out + return primal_out, tangent_out + + +jax_ad.primitive_jvps[inv_p] = _inv_jvp + + +# ── 5. Batching rule (vmap) ─────────────────────────────────────────────── + + +def _inv_batch(args: tuple, batch_axes: tuple) -> tuple: + (x,) = args + (ax,) = batch_axes + x = jnp.moveaxis(x, ax, 0) + return inv_p.bind(x), 0 + + +jax_batching.primitive_batchers[inv_p] = _inv_batch + + +# ── 6. Quax dispatch for QMatrix ────────────────────────────────── + + +@quax.register(inv_p) +def _inv_p_QMatrix(x: QMatrix, /) -> QMatrix: + """Compute the inverse of a 2-D :class:`~coordinax.internal.QMatrix`. + + The numeric value is computed via ``inv_p.bind(x.value)``. + Units are assumed uniform (all entries share the same physical unit, + as is the case for metrics produced by the Cartesian-Jacobian pullback); + the inverse carries the reciprocal unit throughout. + + Examples + -------- + >>> import jax + >>> import jax.numpy as jnp + >>> import quax + >>> import unxt as u + >>> from coordinax.internal import QMatrix, UnitsMatrix + >>> from coordinax._src.internal.quantity_matrix import inv + + >>> A = QMatrix( + ... jnp.array([[4.0, 0.0], [0.0, 1.0]]), + ... unit=UnitsMatrix((('m2', 'm2'), ('m2', 'm2'))), + ... ) + >>> quax.quaxify(inv)(A) + QMatrix([[0.25, 0. ], + [0. , 1. ]], '((1 / m2, 1 / m2), (1 / m2, 1 / m2))') + + """ + if x.ndim != 2: + msg = f"inv_p QMatrix dispatch requires a 2-D unit structure, got ndim={x.ndim}" + raise ValueError(msg) + inv_val = inv_p.bind(x.value) + return QMatrix(inv_val, unit=x.unit.inverse()) diff --git a/src/coordinax/_src/internal/quantity_matrix/__init__.py b/src/coordinax/_src/internal/quantity_matrix/__init__.py new file mode 100644 index 00000000..1ec23fa6 --- /dev/null +++ b/src/coordinax/_src/internal/quantity_matrix/__init__.py @@ -0,0 +1,45 @@ +"""Heterogeneous unit containers for vectors and matrices. + +This package provides two closely related building blocks: + +- `UnitsMatrix`, an immutable nested tuple of units with indexing support +- `QMatrix`, a quantity-like wrapper around one array plus a matching + static `UnitsMatrix` + +The numeric payload is a single JAX array of shape ``(..., *shape)`` where the +trailing dimensions are the logical vector or matrix dimensions and any leading +dimensions are batch dimensions. Units are stored separately as a static nested +tuple structure with the same logical shape, allowing every element to carry +its own physical unit. + +Currently the public surface supports only 1-D and 2-D structures: + +- 1-D: ``(..., N)`` with units ``(u0, u1, ..., uN-1)`` +- 2-D: ``(..., N, M)`` with units ``((u00, u01, ...), (u10, u11, ...), ...)`` + +Quax primitive dispatches (``add_p``, ``dot_general_p``) perform the +necessary per-element unit conversions via `unxt.uconvert_value` — which +correctly handles affine conversions (e.g. °F → °C), not just +multiplicative scale factors. +""" + +__all__ = ( + "QMatrix", + "UnitsMatrix", + "cdict_units", + "det", + "det_p", + "inv", + "inv_p", +) + +from . import _register_primitives # noqa: F401 +from ._det import det, det_p +from ._inv import inv, inv_p +from ._quantity_matrix import ( # noqa: F401 + QMatrix, + _convert_value_matrix, + _convert_value_vector, +) +from ._units_matrix import UnitsMatrix +from ._utils import cdict_units diff --git a/src/coordinax/_src/internal/quantity_matrix/_det.py b/src/coordinax/_src/internal/quantity_matrix/_det.py new file mode 100644 index 00000000..f1778966 --- /dev/null +++ b/src/coordinax/_src/internal/quantity_matrix/_det.py @@ -0,0 +1,210 @@ +"""Custom ``det_p`` JAX primitive with full JAX transform support. + +JAX has no built-in ``det`` primitive. ``jnp.linalg.det`` decomposes into +arithmetic (2x2 / 3x3) or ``lu_p`` + log/exp (larger matrices). We define +``det_p`` here so that Quax can intercept determinant calls on ``QMatrix`` +objects. + +Supported transforms: + - JIT via MLIR lowering that delegates to ``jnp.linalg.det`` + - Forward-mode autodiff (JVP): d(det A)(dA) = det(A) · tr(A⁻¹ dA) + - Reverse-mode autodiff (VJP): derived automatically from JVP via + transposition — no explicit transpose rule needed because the JVP tangent + only uses existing primitives (linalg.solve, trace, mul) that already + carry transpose rules. + - Batching (vmap): move the batch axis to the front and call det_p; the MLIR + lowering handles any (*batch, n, n) shape natively. +""" + +import functools as ft +import operator + +from jaxtyping import Array + +import jax +import jax.core +import jax.numpy as jnp +import quax +from jax import lax +from jax.extend import core as jexc +from jax.interpreters import ad as jax_ad, batching as jax_batching, mlir as jax_mlir + +import unxt as u + +from ._quantity_matrix import QMatrix + +det_p = jexc.Primitive("det") +det_p.multiple_results = False + + +def det(x: Array, /) -> Array: + """Compute the determinant of a square matrix via the ``det_p`` primitive. + + Delegates to ``det_p``, a custom JAX primitive that supports JIT, + forward and reverse differentiation, and batching (vmap). + + For plain arrays the result is a bare :class:`~jaxtyping.Array`. + For :class:`~coordinax.internal.QMatrix` inputs the Quax + dispatch intercepts the call (see ``_det_p_QMatrix``) and + returns a :class:`~unxt.AbstractQuantity`. + + Parameters + ---------- + x : Array, shape ``(*batch, n, n)`` + Square matrix or batch of square matrices. + + Returns + ------- + Array, shape ``(*batch,)`` + Determinant of each matrix. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.internal.quantity_matrix import det + + Plain 2x2 diagonal matrix: + + >>> det(jnp.array([[2.0, 0.0], [0.0, 3.0]])) + Array(6., dtype=float64) + + Under JIT: + + >>> import jax + >>> jax.jit(det)(jnp.array([[2.0, 0.0], [0.0, 3.0]])) + Array(6., dtype=float64) + + Gradient (via reverse-mode autodiff): + + >>> jax.grad(det)(jnp.array([[2.0, 0.0], [0.0, 3.0]])) + Array([[3., 0.], + [0., 2.]], dtype=float64) + + Batched (vmap): + + >>> A = jnp.stack([jnp.diag(jnp.array([2.0, 3.0])), + ... jnp.diag(jnp.array([4.0, 5.0]))]) + >>> jax.vmap(det)(A) + Array([ 6., 20.], dtype=float64) + + """ + return det_p.bind(x) + + +# ── 1. Primal evaluation rule (eager / concrete values) ────────────────── + + +def _det_impl(x: Array, /) -> Array: + return jnp.linalg.det(x) + + +det_p.def_impl(_det_impl) + + +# ── 2. Abstract evaluation rule (shape / dtype inference for JIT) ──────── + + +def _det_abstract_eval(x: jax.core.ShapedArray, /) -> jax.core.ShapedArray: + if x.ndim < 2: + raise ValueError(f"det_p requires at least 2-D input, got ndim={x.ndim}") + if x.shape[-1] != x.shape[-2]: + raise ValueError( + f"det_p requires a square matrix " + f"(shape[-2] == shape[-1]), got shape={x.shape}" + ) + # (*batch, n, n) → (*batch,) + return x.update(shape=x.shape[:-2]) + + +det_p.def_abstract_eval(_det_abstract_eval) + + +# ── 3. MLIR / XLA lowering (JIT compilation) ───────────────────────────── + +jax_mlir.register_lowering( + det_p, + jax_mlir.lower_fun(_det_impl, multiple_results=False), +) + + +# ── 4. Forward-mode differentiation (JVP) ──────────────────────────────── +# +# Jacobi's formula: d(det A)(dA) = det(A) · tr(A⁻¹ dA) +# +# We use `jnp.linalg.solve(A, dA)` to compute A⁻¹ dA without explicitly +# forming the inverse — more numerically stable. The tangent uses only +# existing JAX primitives (solve, trace, scalar mul), so JAX derives the +# reverse-mode (VJP) rule automatically via transposition; no explicit +# `ad.primitive_transposes` registration is needed. + + +def _det_jvp(primals: tuple, tangents: tuple) -> tuple: + (x,) = primals + (dx,) = tangents + primal_out = det_p.bind(x) + if type(dx) is jax_ad.Zero: + tangent_out = lax.full_like(primal_out, 0.0) + else: + # tr(A⁻¹ dA) via solve — avoids explicit matrix inversion + tangent_out = primal_out * jnp.trace(jnp.linalg.solve(x, dx)) + return primal_out, tangent_out + + +jax_ad.primitive_jvps[det_p] = _det_jvp + + +# ── 5. Batching rule (vmap) ─────────────────────────────────────────────── +# +# `jnp.linalg.det` (used in the MLIR lowering) already operates correctly +# on batched (*batch, n, n) arrays. The batching rule moves the vmap batch +# axis to the front (just before the matrix dims) and calls det_p; the +# result carries the batch axis at position 0. + + +def _det_batch(args: tuple, batch_axes: tuple) -> tuple: + (x,) = args + (ax,) = batch_axes + x = jnp.moveaxis(x, ax, 0) + return det_p.bind(x), 0 + + +jax_batching.primitive_batchers[det_p] = _det_batch + + +# ── 6. Quax dispatch for QMatrix ────────────────────────────────── + + +@quax.register(det_p) +def _det_p_QMatrix(x: QMatrix, /) -> "u.AbstractQuantity": + """Compute the determinant of a 2-D :class:`~coordinax.internal.QMatrix`. + + The numeric value is computed via ``det_p.bind(x.value)``. The unit + is the product of the main-diagonal units — valid for diagonal metrics + and any matrix where all cofactor products share the same physical + dimension (e.g. coordinate metric tensors). + + Examples + -------- + >>> import jax + >>> import jax.numpy as jnp + >>> import quax + >>> import unxt as u + >>> from coordinax.internal import QMatrix + >>> from coordinax._src.internal.quantity_matrix import det + + >>> A = QMatrix( + ... jnp.array([[2.0, 0.0], [0.0, 3.0]]), + ... unit=(("m2", "m2"), ("m2", "m2")), + ... ) + >>> quax.quaxify(det)(A) + Q(6., 'm4') + + """ + if x.ndim != 2: + raise ValueError( + f"det_p QMatrix dispatch requires a 2-D unit structure, got ndim={x.ndim}" + ) + det_val = det_p.bind(x.value) + n = x.unit.shape[0] + det_unit = ft.reduce(operator.mul, (x.unit[i, i] for i in range(n))) + return u.Q(det_val, det_unit) diff --git a/src/coordinax/_src/internal/quantity_matrix/_inv.py b/src/coordinax/_src/internal/quantity_matrix/_inv.py new file mode 100644 index 00000000..0c6d37d0 --- /dev/null +++ b/src/coordinax/_src/internal/quantity_matrix/_inv.py @@ -0,0 +1,193 @@ +"""Custom ``inv_p`` JAX primitive with full JAX transform support. + +JAX has no standalone ``inv`` primitive. ``jnp.linalg.inv`` decomposes into +``lu_p`` + ``triangular_solve_p``, and Quax has no dispatch for those on +``QMatrix`` (which raises on materialise). We define ``inv_p`` to give +Quax a single interception point with full unit tracking. + +Supported transforms: + - JIT via MLIR lowering that delegates to ``jnp.linalg.inv`` + - Forward-mode autodiff (JVP): d(A^{-1}) = -A^{-1} dA A^{-1} + - Reverse-mode autodiff (VJP): derived automatically from JVP + - Batching (vmap): move batch axis to front, call inv_p +""" + +from jaxtyping import Array + +import jax +import jax.numpy as jnp +import quax +from jax.extend import core as jexc +from jax.interpreters import ad as jax_ad, batching as jax_batching, mlir as jax_mlir + +from ._quantity_matrix import QMatrix + +inv_p = jexc.Primitive("inv") +inv_p.multiple_results = False + + +def inv(x: Array, /) -> Array: + """Compute the matrix inverse of a square matrix via the ``inv_p`` primitive. + + Delegates to ``inv_p``, a custom JAX primitive that supports JIT, + forward and reverse differentiation, and batching (vmap). + + For plain arrays the result is a bare :class:`~jaxtyping.Array`. + For :class:`~coordinax.internal.QMatrix` inputs the Quax + dispatch intercepts the call (see ``_inv_p_QMatrix``) and + returns a :class:`~coordinax.internal.QMatrix` with + reciprocal units. + + Parameters + ---------- + x : Array, shape ``(*batch, n, n)`` + Square matrix or batch of square matrices. + + Returns + ------- + Array, shape ``(*batch, n, n)`` + Matrix inverse of each square matrix. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.internal.quantity_matrix import inv + + Plain 2x2 diagonal matrix: + + >>> inv(jnp.array([[2.0, 0.0], [0.0, 4.0]])) + Array([[0.5 , 0. ], + [0. , 0.25]], dtype=float64) + + Under JIT: + + >>> import jax + >>> jax.jit(inv)(jnp.array([[2.0, 0.0], [0.0, 4.0]])) + Array([[0.5 , 0. ], + [0. , 0.25]], dtype=float64) + + Gradient (via reverse-mode autodiff) — returns a rank-4 Jacobian: + + >>> jac = jax.jacobian(inv)(jnp.array([[2.0, 0.0], [0.0, 4.0]])) + >>> jac.shape + (2, 2, 2, 2) + + Batched (vmap): + + >>> A = jnp.stack([jnp.diag(jnp.array([2.0, 4.0])), + ... jnp.diag(jnp.array([1.0, 2.0]))]) + >>> jax.vmap(inv)(A) + Array([[[0.5 , 0. ], + [0. , 0.25]], + + [[1. , 0. ], + [0. , 0.5 ]]], dtype=float64) + + """ + return inv_p.bind(x) + + +# ── 1. Primal evaluation rule ───────────────────────────────────────────── + + +def _inv_impl(x: Array, /) -> Array: + return jnp.linalg.inv(x) + + +inv_p.def_impl(_inv_impl) + + +# ── 2. Abstract evaluation rule ─────────────────────────────────────────── + + +def _inv_abstract_eval(x: "jax.core.ShapedArray", /) -> "jax.core.ShapedArray": + if x.ndim < 2: + raise ValueError(f"inv_p requires at least 2-D input, got ndim={x.ndim}") + if x.shape[-1] != x.shape[-2]: + raise ValueError( + f"inv_p requires a square matrix " + f"(shape[-2] == shape[-1]), got shape={x.shape}" + ) + return x.update(shape=x.shape) # same shape as input + + +inv_p.def_abstract_eval(_inv_abstract_eval) + + +# ── 3. MLIR / XLA lowering ──────────────────────────────────────────────── + +jax_mlir.register_lowering( + inv_p, + jax_mlir.lower_fun(_inv_impl, multiple_results=False), +) + + +# ── 4. Forward-mode differentiation (JVP) ──────────────────────────────── +# +# d(A^{-1})(dA) = -A^{-1} dA A^{-1} + + +def _inv_jvp(primals: tuple, tangents: tuple) -> tuple[Array, Array]: + (x,) = primals + (dx,) = tangents + primal_out = inv_p.bind(x) + if type(dx) is jax_ad.Zero: + tangent_out = jax_ad.Zero.from_primal_value(primal_out) # ty: ignore[unresolved-attribute] + else: + tangent_out = -primal_out @ dx @ primal_out + return primal_out, tangent_out + + +jax_ad.primitive_jvps[inv_p] = _inv_jvp + + +# ── 5. Batching rule (vmap) ─────────────────────────────────────────────── + + +def _inv_batch(args: tuple, batch_axes: tuple) -> tuple: + (x,) = args + (ax,) = batch_axes + x = jnp.moveaxis(x, ax, 0) + return inv_p.bind(x), 0 + + +jax_batching.primitive_batchers[inv_p] = _inv_batch + + +# ── 6. Quax dispatch for QMatrix ────────────────────────────────── + + +@quax.register(inv_p) +def _inv_p_QMatrix(x: QMatrix, /) -> QMatrix: + """Compute the inverse of a 2-D :class:`~coordinax.internal.QMatrix`. + + The numeric value is computed via ``inv_p.bind(x.value)``. + Units are assumed uniform (all entries share the same physical unit, + as is the case for metrics produced by the Cartesian-Jacobian pullback); + the inverse carries the reciprocal unit throughout. + + Examples + -------- + >>> import jax + >>> import jax.numpy as jnp + >>> import quax + >>> import unxt as u + >>> from coordinax.internal import QMatrix, UnitsMatrix + >>> from coordinax._src.internal.quantity_matrix import inv + + >>> A = QMatrix( + ... jnp.array([[4.0, 0.0], [0.0, 1.0]]), + ... unit=UnitsMatrix((('m2', 'm2'), ('m2', 'm2'))), + ... ) + >>> quax.quaxify(inv)(A) + QMatrix([[0.25, 0. ], + [0. , 1. ]], '((1 / m2, 1 / m2), (1 / m2, 1 / m2))') + + """ + if x.ndim != 2: + raise ValueError( + f"inv_p QMatrix dispatch requires a 2-D unit structure, got ndim={x.ndim}" + ) + + inv_val = inv_p.bind(x.value) + return QMatrix(inv_val, unit=x.unit.inverse()) diff --git a/src/coordinax/_src/internal/quantity_matrix/_quantity_matrix.py b/src/coordinax/_src/internal/quantity_matrix/_quantity_matrix.py new file mode 100644 index 00000000..cedc9256 --- /dev/null +++ b/src/coordinax/_src/internal/quantity_matrix/_quantity_matrix.py @@ -0,0 +1,423 @@ +"""QMatrix class and unit-conversion helpers.""" + +from jaxtyping import Array, Shaped +from typing import Any, NoReturn + +import equinox as eqx +import jax +import jax.core +import jax.numpy as jnp +import jax.tree_util as jtu +import plum + +import unxt as u +from unxt.quantity import AllowValue + +from ._units_matrix import UnitsMatrix +from ._utils import _DMLS, CDict, strict_zip + + +class QMatrix(u.AbstractQuantity): + """Quantity container whose elements may each carry different units. + + `QMatrix` stores one numeric array together with a static + `UnitsMatrix` describing the unit of each logical element. The shape of the + unit structure determines whether the object behaves as a heterogeneous + vector or matrix. + + Only 1-D and 2-D logical structures are supported. + + Parameters + ---------- + value : Array, shape ``(..., *shape)`` + Numeric payload. For 1D: ``(..., N)``. For 2D: ``(..., N, M)``. + The value of element ``[i]`` (1D) or ``[i, j]`` (2D) is expressed + in the corresponding unit. + unit : UnitsMatrix + Per-element units. For 1D: ``(u0, u1, ...)``. + For 2D: ``((u00, u01, ...), (u10, u11, ...), ...)``. + Must be a static (hashable) nested tuple structure whose shape + matches the trailing dimensions of ``value``. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + 1D case (vector): + + >>> qv = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "s", "kg")) + >>> qv.value + Array([1., 2., 3.], dtype=float64) + >>> qv.unit.shape + (3,) + + >>> 2 * qv + QMatrix([2., 4., 6.], '(m, s, kg)') + + >>> qv2 = QMatrix(jnp.array([0.1, 200.0, 300.0]), unit=("km", "ms", "g")) + >>> qv + qv2 + QMatrix([101. , 2.2, 3.3], '(m, s, kg)') + + 2D case (matrix): + + >>> qm = QMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) + >>> qm.value.shape + (2, 2) + >>> qm.unit.shape + (2, 2) + + >>> 2 * qm + QMatrix([[2., 2.], + [2., 2.]], '((m, s), (kg, rad))') + + >>> qm2 = QMatrix(jnp.array([[0.1, 200.0], [300.0, 0.5]]), + ... unit=(("km", "ms"), ("g", "deg"))) + >>> qm + qm2 + QMatrix([[101. , 1.2 ], + [ 1.3 , 1.00872665]], '((m, s), (kg, rad))') + + Indexing: + + >>> qv[0] + Q(1., 'm') + >>> qm[0] + QMatrix([1., 1.], '(m, s)') + >>> qm[1, 0] + Q(1., 'kg') + + """ + + value: Shaped[Array, "..."] = eqx.field() + unit: UnitsMatrix = eqx.field(static=True, converter=u.unit) # ty: ignore[invalid-assignment] + + @property + def ndim(self) -> int: + """Number of real dimensions (1 for vector, 2 for matrix).""" + return self.unit.ndim + + @property + def shape(self) -> tuple[int, ...]: + """Shape, including batch dimensions.""" + return self.value.shape + + @classmethod + def from_cdict(cls, v: CDict, /, keys: tuple[str, ...] | None = None) -> "QMatrix": + """Pack a component dictionary into a 1-D ``QMatrix``. + + Each value in *v* is stripped to its numeric value and stacked into a + single JAX array. Values that carry units (``unxt.Quantity``) retain + those units in the resulting ``UnitsMatrix``; plain arrays are treated + as dimensionless. + + Examples + -------- + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + From a dictionary of quantities: + + >>> v = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "s"), "z": u.Q(3.0, "kg")} + >>> qv = QMatrix.from_cdict(v) + >>> qv.unit.to_string() + '(m, s, kg)' + >>> qv.value + Array([1., 2., 3.], dtype=float64, ...) + + Selecting and reordering a subset of keys: + + >>> qv2 = QMatrix.from_cdict(v, keys=("z", "x")) + >>> qv2.unit.to_string() + '(kg, m)' + >>> qv2.value + Array([3., 1.], dtype=float64, ...) + + Dimensionless entries (bare arrays) are accepted: + + >>> import jax.numpy as jnp + >>> v2 = {"a": jnp.array(4.0), "b": u.Q(5.0, "m")} + >>> qv3 = QMatrix.from_cdict(v2) + >>> qv3.unit.to_string() + '(, m)' + + """ + keys = tuple(v) if keys is None else keys + vs = [v[k] for k in keys] + us = [u.unit_of(x) or _DMLS for x in vs] + svs = jnp.stack([u.ustrip(AllowValue, unt, x) for x, unt in strict_zip(vs, us)]) + return cls(svs, unit=UnitsMatrix(us)) + + def __getitem__(self, index: Any, /) -> "u.Q | QMatrix": # ty: ignore[invalid-method-override] + """Index into the QMatrix to retrieve a specific element. + + Indexing a logical dimension returns a ``Quantity`` when the result is + a scalar unit, or a ``QMatrix`` when the result still has + structure. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + **1-D vector** — indexing a single element returns a ``Quantity``: + + >>> qv = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "s", "kg")) + >>> qv[0] + Q(1., 'm') + >>> qv[2] + Q(3., 'kg') + + **2-D matrix** — indexing a row returns a ``QMatrix``: + + >>> qm = QMatrix(jnp.ones((2, 3)), + ... unit=(("m", "s", "kg"), ("rad", "deg", "m"))) + >>> qm[0] + QMatrix([1., 1., 1.], '(m, s, kg)') + + Indexing a specific element returns a ``Quantity``: + + >>> qm[1, 2] + Q(1., 'm') + + """ + value_item = self.value[index] + unit_item = self.unit[index] + if isinstance(unit_item, UnitsMatrix): + return QMatrix(value=value_item, unit=unit_item) + return u.Q(value_item, unit_item) + + # ── Quax API ───────────────────────────────────────────────────── + + def aval(self) -> jax.core.ShapedArray: + return jax.core.ShapedArray(self.value.shape, self.value.dtype) + + def materialise(self) -> NoReturn: + msg = "Refusing to materialise `QMatrix`." + raise RuntimeError(msg) + + def diag(self) -> "QMatrix": + """Return a 1-D ``QMatrix`` containing the diagonal of this matrix. + + Unlike ``qnp.diag``, this method operates directly on the static + ``unit`` structure and the raw value array, so it works correctly under + ``jax.jit`` and with heterogeneous-unit matrices. + + Only supported for 2-D ``QMatrix`` objects. + + Returns + ------- + QMatrix + 1-D ``QMatrix`` of length ``min(n_rows, n_cols)`` whose + ``unit[i]`` is ``self.unit[i, i]`` and whose ``value[..., i]`` is + ``self.value[..., i, i]``. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax.internal import QMatrix + + Uniform units: + + >>> A = QMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), + ... unit=(("m", "m", "m"), ("m", "m", "m"), ("m", "m", "m"))) + >>> d = A.diag() + >>> d.unit.shape + (3,) + >>> d.value + Array([1., 4., 9.], dtype=float64) + + Heterogeneous units — works under jit: + + >>> B = QMatrix(jnp.diag(jnp.array([1.0, 2.0, 3.0])), + ... unit=(("m", "s", "kg"), + ... ("m", "s", "kg"), + ... ("m", "s", "kg"))) + >>> db = B.diag() + >>> db.unit.to_string() + '(m, s, kg)' + >>> db.value + Array([1., 2., 3.], dtype=float64) + + """ + if self.ndim != 2: + raise ValueError( + f"QMatrix.diag() requires a 2D matrix, got ndim={self.ndim}" + ) + n = min(self.shape[-2], self.shape[-1]) + diag_value = jnp.stack([self.value[..., i, i] for i in range(n)], axis=-1) + diag_unit = UnitsMatrix(self.unit._units.diagonal()) + return QMatrix(value=diag_value, unit=diag_unit) + + @property + def T(self) -> "QMatrix": + """Transpose a 2-D ``QMatrix`` (swap rows/columns and units). + + Returns a new ``QMatrix`` whose value array and unit structure + are both transposed. Only 2-D matrices are supported. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> from coordinax.internal import QMatrix + + >>> a = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), + ... unit=(("m", "s"), ("kg", "rad"))) + >>> aT = a.T + >>> aT.value + Array([[1., 3.], + [2., 4.]], dtype=float64) + >>> aT.unit.to_string() + '((m, kg), (s, rad))' + + Also accessible via ``jax.numpy.transpose``: + + >>> aT2 = qnp.matrix_transpose(a) + >>> aT2.value + Array([[1., 3.], + [2., 4.]], dtype=float64) + >>> aT2.unit.to_string() + '((m, kg), (s, rad))' + + """ + if self.ndim != 2: + msg = f"QMatrix.T requires a 2-D matrix, got ndim={self.ndim}" + raise ValueError(msg) + return QMatrix(value=jnp.swapaxes(self.value, -2, -1), unit=self.unit.T) + + +############################################################################## +# Unit-conversion helpers + + +def _convert_value_vector( + value: Shaped[Array, "*batch N"], + from_units: tuple[u.AbstractUnit, ...], + to_units: tuple[u.AbstractUnit, ...], +) -> Shaped[Array, "*batch N"]: + """Convert every element of *value* from *from_units* to *to_units* (1D case). + + Each ``value[..., i]`` is converted individually via + `u.uconvert_value` so that **all** conversion types are handled + correctly. + """ + n = len(to_units) + return jnp.stack( + [u.uconvert_value(to_units[i], from_units[i], value[..., i]) for i in range(n)], + axis=-1, + ) + + +def _convert_value_matrix( + value: Shaped[Array, "*batch N M"], + from_units: tuple[tuple[u.AbstractUnit, ...], ...], + to_units: tuple[tuple[u.AbstractUnit, ...], ...], +) -> Shaped[Array, "*batch N M"]: + """Convert every element of *value* from *from_units* to *to_units* (2D case). + + Each ``value[..., i, j]`` is converted individually via + `u.uconvert_value` so that **all** conversion types are handled + correctly — including nonlinear ones like dB, mag, and dex (which + are logarithmic, not affine). + """ + n = len(to_units) + m = len(to_units[0]) + return jnp.stack( + [ + jnp.stack( + [ + u.uconvert_value(to_units[i][j], from_units[i][j], value[..., i, j]) + for j in range(m) + ], + axis=-1, + ) + for i in range(n) + ], + axis=-2, + ) + + +@plum.conversion_method(type_from=QMatrix, type_to=u.Q) +def QMatrix_to_quantity(x: QMatrix, /) -> u.Q: + """Convert a ``QMatrix`` to a regular ``Quantity``. + + Conversion is only valid when all elements of ``x`` share the same unit. If + units are heterogeneous, this conversion is ambiguous and raises + ``ValueError``. + + >>> import plum + >>> import jax.numpy as jnp + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + Uniform units convert to a plain quantity: + + >>> qmat = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "m", "m")) + >>> plum.convert(qmat, u.Q) + Q([1., 2., 3.], 'm') + + Mixed units are rejected: + + >>> bad = QMatrix(jnp.array([1.0, 2.0]), unit=("m", "s")) + >>> plum.convert(bad, u.Q) + Traceback (most recent call last): + ... + ValueError: Cannot convert QMatrix to Quantity unless all units are + identical. + + """ + units = jtu.tree_leaves(x.unit.to_tuple()) + + if not units: + msg = "Cannot convert QMatrix with no unit entries." + raise ValueError(msg) + + first = units[0] + if any(unit != first for unit in units[1:]): + msg = "Cannot convert QMatrix to Quantity unless all units are identical." + raise ValueError(msg) + + return u.Q(x.value, first) + + +def _convert_value( + value: Array, + from_units: UnitsMatrix, + to_units: UnitsMatrix, +) -> Array: + """Convert value with heterogeneous units (works for both 1D and 2D).""" + from_tup = from_units.to_tuple() + to_tup = to_units.to_tuple() + if from_units.ndim == 1: + return _convert_value_vector(value, from_tup, to_tup) + if from_units.ndim == 2: + return _convert_value_matrix(value, from_tup, to_tup) + msg = f"Unsupported ndim={from_units.ndim}" + raise NotImplementedError(msg) + + +@plum.dispatch +def uconvert(to_units: UnitsMatrix, x: QMatrix, /) -> QMatrix: + """Convert a ``QMatrix`` to different (but compatible) units. + + Unlike the generic astropy ``StructuredUnit.to()`` path, this dispatch uses + ``_convert_value`` directly so that the regular 2D JAX array in ``x.value`` + is converted element-by-element without requiring a numpy structured array. + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + >>> x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + >>> q = QMatrix(x, (("m", "rad"), ("m", "rad"))) + >>> target = u.unit((("km", "deg"), ("km", "deg"))) + >>> q.uconvert(target).unit.to_string() + '((km, deg), (km, deg))' + + """ + if x.unit == to_units: + return x + value = _convert_value(x.value, x.unit, to_units) + return QMatrix(value=value, unit=to_units) diff --git a/src/coordinax/_src/internal/quantity_matrix/_register_primitives.py b/src/coordinax/_src/internal/quantity_matrix/_register_primitives.py new file mode 100644 index 00000000..a296445a --- /dev/null +++ b/src/coordinax/_src/internal/quantity_matrix/_register_primitives.py @@ -0,0 +1,808 @@ +"""Quax primitive registrations for QMatrix arithmetic. + +Registers handlers for the following JAX primitives: +- ``lax.add_p`` — element-wise addition +- ``lax.sub_p`` — element-wise subtraction +- ``lax.dot_general_p`` — dot product / matrix multiply +- ``lax.transpose_p`` — matrix transpose +- ``lax.gather_p`` — element-selection gather (e.g. jnp.diag) +- ``lax.reduce_sum_p`` — summation reduction +""" + +from typing import Any + +import jax +import jax.numpy as jnp +import jax.tree_util as jtu +import numpy as np +import quax +from jax import lax + +import unxt as u +from unxt.quantity import AllowValue + +from ._quantity_matrix import QMatrix, _convert_value +from ._units_matrix import UnitsMatrix + +# Vectorised uconvert_value — used by dot-product helpers. +vec_uconvert_value = np.vectorize(u.uconvert_value) + +_DMLS = u.unit("") + + +# ── add / sub ──────────────────────────────────────────────────────────── + + +@quax.register(lax.add_p) +def add_qm_qm(x: QMatrix, y: QMatrix, /) -> QMatrix: + """Element-wise addition of two `QMatrix` objects. + + The result adopts the units of *x*. Each element is converted from + ``y.unit`` → ``x.unit`` before the numeric add. + + Works for both 1D (vector) and 2D (matrix) cases. + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + 2D case: + + >>> a = QMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) + >>> b = QMatrix(jnp.ones((2, 2)), unit=(("km", "ms"), ("g", "deg"))) + + >>> result = qnp.add(a, b) + >>> result.unit.to_string() + '((m, s), (kg, rad))' + + >>> result.value + Array([[1.00100000e+03, 1.00100000e+00], + [1.00100000e+00, 1.01745329e+00]], dtype=float64) + + 1D case: + + >>> a1d = QMatrix(jnp.ones(3), unit=("m", "s", "kg")) + >>> b1d = QMatrix(jnp.ones(3), unit=("km", "ms", "g")) + + >>> result1d = qnp.add(a1d, b1d) + >>> result1d.unit.to_string() + '(m, s, kg)' + + >>> result1d.value + Array([1.001e+03, 1.001e+00, 1.001e+00], dtype=float64) + + """ + y_converted = _convert_value(y.value, y.unit, x.unit) + return QMatrix(value=lax.add(x.value, y_converted), unit=x.unit) + + +@quax.register(lax.sub_p) +def sub_qm_qm(x: QMatrix, y: QMatrix, /) -> QMatrix: + """Element-wise subtraction of two `QMatrix` objects. + + The result adopts the units of *x*. Each element is converted from + ``y.unit`` → ``x.unit`` before the numeric subtract. + + Works for both 1D (vector) and 2D (matrix) cases. + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + 2D case: + + >>> a = QMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad"))) + >>> b = QMatrix( + ... value=jnp.ones((2, 2)), + ... unit=(("km", u.unit("ms")), (u.unit("g"), u.unit("deg")))) + + >>> result = qnp.subtract(a, b) + >>> result.unit.to_string() + '((m, s), (kg, rad))' + + >>> result.value + Array([[-9.99000000e+02, 9.99000000e-01], + [ 9.99000000e-01, 9.82546707e-01]], dtype=float64) + + 1D case: + + >>> a1d = QMatrix(value=jnp.ones(3), + ... unit=("m", "s", "kg")) + >>> b1d = QMatrix(value=jnp.ones(3), + ... unit=("km", u.unit("ms"), u.unit("g"))) + + >>> result1d = qnp.subtract(a1d, b1d) + >>> result1d.unit.to_string() + '(m, s, kg)' + + >>> result1d.value + Array([-999. , 0.999, 0.999], dtype=float64) + + """ + y_converted = _convert_value(y.value, y.unit, x.unit) + return QMatrix(value=lax.sub(x.value, y_converted), unit=x.unit) + + +# ── dot_general helpers ─────────────────────────────────────────────────── + + +def _dot_general_1d_1d( + lhs: QMatrix, + rhs: QMatrix, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> u.Q: + """Vector dot product: (N,) @ (N,) → scalar. + + Result = Σ_i lhs[i] * rhs[i] + + All terms must be unit-compatible. We convert to the unit of the first term. + """ + n = lhs.shape[-1] + assert n == rhs.shape[-1] # noqa: S101 + + # Reference unit: lhs.unit[0] * rhs.unit[0] + ref_unit = lhs.unit[0] * rhs.unit[0] + + # Compute scale factors + scales = jnp.array( + [u.uconvert_value(ref_unit, lhs.unit[i] * rhs.unit[i], 1.0) for i in range(n)] + ) + + # Compute dot product with rescaling + result_value = jnp.sum(scales * lhs.value * rhs.value, axis=-1) + + return u.Q(result_value, ref_unit) + + +def _dot_general_2d_1d( + lhs: QMatrix, + rhs: QMatrix, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> QMatrix: + """Matrix-vector multiply: (N, K) @ (K,) → (N,). + + For ``w = A @ v`` where ``A`` is ``(N, K)`` and ``v`` is ``(K,)``: + + ``w[i] = Σ_j A[i, j] * v[j]`` + + Each product ``A[i,j] * v[j]`` has unit ``A.unit[i][j] * v.unit[j]``. All + ``K`` terms in the sum for output row ``i`` must be unit-compatible. We + convert every term to the unit of the *first* term (``j = 0``) for each + output row ``i``: ``ref[i] = A.unit[i][0] * v.unit[0]``. + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + Identity matrix times a vector: + + >>> A = QMatrix(jnp.eye(3, dtype=jnp.float64), + ... unit=(("", "", ""), ("", "", ""), ("", "", ""))) + >>> v = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "m", "m")) + >>> w = qnp.matmul(A, v) + >>> w.value + Array([1., 2., 3.], dtype=float64) + + Mixed units on contraction axis (km column converted to m): + + >>> A2 = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), + ... unit=(("m", "km"), ("m", "km"))) + >>> v2 = QMatrix(jnp.array([1.0, 1.0]), unit=("s", "s")) + >>> w2 = qnp.matmul(A2, v2) + >>> w2.value + Array([2001., 4003.], dtype=float64) + >>> w2.unit.to_string() + '(m s, m s)' + + """ + assert rhs.shape[-1] == lhs.shape[-1] # noqa: S101 + + # 1) Output units: ref[i] = lhs.unit[i][0] * rhs.unit[0] + out_unit = UnitsMatrix(np.multiply(lhs.unit._units[:, 0], rhs.unit._units[0])) + + # 2) Precompute scale factors: scale[i, j] converts + # lhs.unit[i][j]*rhs.unit[j] → ref[i] + scale_2d = jnp.array( + vec_uconvert_value( + out_unit._units[:, None], # (N, 1) — broadcast over K + np.multiply(lhs.unit._units, rhs.unit._units[None, :]), # (N, K) + 1.0, + ) + ) + + # 3) Vectorised contraction: + # w[..., i] = Σ_j scale[i, j] * A[..., i, j] * v[..., j] + accum = jnp.einsum("ij,...ij,...j->...i", scale_2d, lhs.value, rhs.value) + + return QMatrix(value=accum, unit=out_unit) + + +def _dot_general_2d_2d( + lhs: QMatrix, + rhs: QMatrix, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> QMatrix: + """Matrix multiply: (N, K) @ (K, M) → (N, M). + + For ``C = A @ B`` where ``A`` is ``(N, K)`` and ``B`` is ``(K, M)``: + + ``C[i, k] = Σ_j A[i, j] * B[j, k]`` + + Each product ``A[i,j] * B[j,k]`` has unit ``A.unit[i][j] * B.unit[j][k]``. + All ``K`` terms in the sum **must** be unit-compatible. We convert every + term to the unit of the *first* term (``j = 0``) using `u.uconvert_value`, + then sum with a plain matmul. + + The strategy: + 1. Pick a reference unit for each ``(i, k)`` output element: + ``ref[i][k] = A.unit[i][0] * B.unit[0][k]``. + 2. For each contraction index ``j``, compute per-element conversion + factors from ``A.unit[i][j] * B.unit[j][k]`` to ``ref[i][k]``. + Because the products are *multiplicative* compositions, the + conversion from ``u_A * u_B`` to ``ref`` is multiplicative even + when the individual units are affine — the product of two + absolute quantities is always absolute. + So we can safely compute a scale factor: + ``scale[i][j][k] = uconvert_value(ref[i][k], A.unit[i][j] * B.unit[j][k], 1.0)`` + 3. Build the rescaled sum as: + ``C_val[i, k] = Σ_j scale[i][j][k] * A_val[i, j] * B_val[j, k]`` + Done via ``C_val = (A_val * S_ij) @ B_val`` per output column, or + equivalently with a loop + accumulate. + """ + # Check contraction axis + assert lhs.shape[-1] == rhs.shape[-2] # noqa: S101 + + # 1) Compute output units: ref[i][k] = lhs.unit[i][0] * rhs.unit[0][k] + out_unit = np.multiply(lhs.unit._units[:, 0:1], rhs.unit._units[0:1, :]) + + # 2) Precompute all scale factors as a (N, K, M) constant array. + # scale[i, j, k] converts lhs.unit[i][j]*rhs.unit[j][k] → out_unit[i][k]. + # + # CORRECTNESS NOTE — why a multiplicative scale factor is exact: + # Affine units (°C, °F) are the only units where a bare + # multiplicative scale would be wrong (they have an additive + # offset). But astropy rejects product conversions involving + # affine units — e.g. ``(deg_C * s).to(deg_F * s)`` raises + # ``UnitConversionError``. Every product unit that astropy + # *does* accept (including logarithmic units like dex, mag) is + # a plain ``CompositeUnit`` whose conversion is purely + # multiplicative. So ``uconvert_value(to, from, 1.0)`` yields + # an exact scale factor for all valid product units. + # + # The tests in ``TestAffineProductUnitsRejected`` assert that + # astropy keeps rejecting affine product conversions. If that + # ever changes, those tests will fail, alerting us that this + # assumption needs revisiting. + scale_3d = jnp.array( + vec_uconvert_value( + out_unit[:, None, :], # (N, 1, M) + np.multiply(lhs.unit._units[:, :, None], rhs.unit._units[None, :, :]), + 1.0, # ꜛ (N, K, M) + ) + ) + + # 3) Vectorised contraction — no Python loop, no accumulator. + # C[..., i, k] = Σ_j scale[i, j, k] * A[..., i, j] * B[..., j, k] + accum = jnp.sum( # (N, K, M) * (..., N, K, 1) * (..., 1, K, M) + scale_3d * lhs.value[..., :, :, None] * rhs.value[..., None, :, :], + axis=-2, + ) + + return QMatrix(value=accum, unit=out_unit) + + +# ── dot_general dispatch ────────────────────────────────────────────────── + + +@quax.register(lax.dot_general_p) +def dot_general_qm_qm( + lhs: QMatrix, + rhs: QMatrix, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> QMatrix | u.Q: + """Dot product / matrix multiply two `QMatrix` objects. + + Delegates to specialized implementations based on the dimensionality: + - 1D @ 1D → scalar (vector dot product) + - 2D @ 2D → 2D (matrix-matrix multiply) + + For the standard matmul contraction: contracting_dims = ((-1,), (-2,)), + with no batch dims (batch is handled by leading dims in QMatrix). + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> import unxt as u + >>> from coordinax.internal import QMatrix + + 1D @ 1D (dot product): + + >>> v1 = QMatrix(jnp.array([1.0, 2.0]), unit=("m", "km")) + >>> v2 = QMatrix(jnp.array([3.0, 4.0]), unit=("s", "s")) + >>> result = qnp.dot(v1, v2) + >>> result.value + Array(8003., dtype=float64) + >>> result.unit + Unit("m s") + + 2D @ 2D (matrix multiply): + + >>> a = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), + ... unit=(("m", "km"), ("m", "km"))) + >>> b = QMatrix(jnp.array([[1.0, 0.0], [0.0, 1.0]]), + ... unit=(("s", "s"), ("s", "s"))) + + >>> c = qnp.matmul(a, b) + >>> c.unit.to_string() + '((m s, m s), (m s, m s))' + + >>> c.value + Array([[1.e+00, 2.e+03], + [3.e+00, 4.e+03]], dtype=float64) + + """ + # For now, we only handle the standard matmul/dot contraction + (contract, batch) = dimension_numbers + assert len(contract[0]) == 1 and len(contract[1]) == 1 # noqa: PT018, S101 + assert len(batch[0]) == 0 and len(batch[1]) == 0 # noqa: PT018, S101 + + # Delegate based on dimensionality + if lhs.ndim == 1 and rhs.ndim == 1: + return _dot_general_1d_1d( + lhs, + rhs, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + if lhs.ndim == 2 and rhs.ndim == 2: + return _dot_general_2d_2d( + lhs, + rhs, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + if lhs.ndim == 2 and rhs.ndim == 1: + return _dot_general_2d_1d( + lhs, + rhs, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + msg = f"Unsupported dimensionality: lhs.ndim={lhs.ndim}, rhs.ndim={rhs.ndim}" + raise NotImplementedError(msg) + + +@quax.register(lax.dot_general_p) +def dot_general_qm_arr( + lhs: QMatrix, + rhs: jax.Array, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> "QMatrix | u.Q": + """Dot product of a :class:`QMatrix` with a plain JAX array. + + The plain array is treated as dimensionless. Delegates to + :func:`dot_general_qm_qm` after wrapping ``rhs`` in a dimensionless + :class:`QMatrix`. + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> from coordinax.internal import QMatrix, UnitsMatrix + + 2D metric x 1D plain vector: + + >>> g = QMatrix( + ... jnp.array([[2.0, 0.0], [0.0, 3.0]]), + ... unit=UnitsMatrix((("m2", "m2"), ("m2", "m2"))), + ... ) + >>> v = jnp.array([1.0, 1.0]) + >>> w = qnp.matmul(g, v) + >>> w.unit.to_string() + '(m2, m2)' + >>> w.value + Array([2., 3.], dtype=float64) + + """ + if rhs.ndim == 1: + n = rhs.shape[0] + rhs_qm = QMatrix(rhs, unit=UnitsMatrix(tuple(_DMLS for _ in range(n)))) + else: + nr, nc = rhs.shape[-2], rhs.shape[-1] + rhs_qm = QMatrix( + rhs, + unit=UnitsMatrix(tuple(tuple(_DMLS for _ in range(nc)) for _ in range(nr))), + ) + return dot_general_qm_qm( + lhs, + rhs_qm, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + + +@quax.register(lax.dot_general_p) +def dot_general_qm_qty( + lhs: QMatrix, + rhs: u.AbstractQuantity, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> "QMatrix | u.Q": + """Dot product of a :class:`QMatrix` with a :class:`~unxt.AbstractQuantity`. + + The Quantity carries a single scalar unit that applies uniformly to all + elements. The ``rhs`` is wrapped as a uniform-unit + :class:`QMatrix` and delegated to :func:`dot_general_qm_qm`. + + Note that :class:`QMatrix` is itself a subtype of + :class:`~unxt.AbstractQuantity`, so :func:`dot_general_qm_qm` takes + precedence when both sides are :class:`QMatrix`. + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import quaxed.numpy as qnp + >>> from coordinax.internal import QMatrix, UnitsMatrix + + 2D metric with units @ uniform-unit Quantity vector: + + >>> g = QMatrix( + ... jnp.array([[2.0, 0.0], [0.0, 3.0]]), + ... unit=UnitsMatrix((("m2 / rad2", "m2 / rad2"), ("m2 / rad2", "m2 / rad2"))), + ... ) + >>> v = u.Q(jnp.array([1.0, 1.0]), "rad") + >>> w = qnp.matmul(g, v) + >>> w.unit.to_string() + '(m2 / rad, m2 / rad)' + >>> w.value + Array([2., 3.], dtype=float64) + + """ + rhs_unit = u.unit_of(rhs) + rhs_val = u.ustrip(AllowValue, rhs_unit, rhs) + if rhs_val.ndim == 1: + n = rhs_val.shape[0] + rhs_qm = QMatrix(rhs_val, unit=UnitsMatrix(tuple(rhs_unit for _ in range(n)))) + else: + nr, nc = rhs_val.shape[-2], rhs_val.shape[-1] + rhs_qm = QMatrix( + rhs_val, + unit=UnitsMatrix( + tuple(tuple(rhs_unit for _ in range(nc)) for _ in range(nr)) + ), + ) + return dot_general_qm_qm( + lhs, + rhs_qm, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + + +@quax.register(lax.dot_general_p) +def dot_general_arr_qm( + lhs: jax.Array, + rhs: QMatrix, + /, + *, + dimension_numbers: lax.DotDimensionNumbers, + precision: Any = None, + preferred_element_type: Any = None, + **kw: Any, +) -> "QMatrix | u.Q": + """Dot product of a plain JAX array with a :class:`QMatrix`. + + The plain array is treated as dimensionless. Delegates to + :func:`dot_general_qm_qm` after wrapping ``lhs`` in a dimensionless + :class:`QMatrix`. + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> from coordinax.internal import QMatrix + + Dimensionless identity @ QMatrix vector: + + >>> A = jnp.eye(2, dtype=jnp.float64) + >>> v = QMatrix(jnp.array([2.0, 3.0]), unit=("m / s", "m / s")) + >>> w = qnp.matmul(A, v) + >>> w.unit.to_string() + '(m / s, m / s)' + >>> w.value + Array([2., 3.], dtype=float64) + + """ + if lhs.ndim == 1: + n = lhs.shape[0] + lhs_qm = QMatrix(lhs, unit=UnitsMatrix(tuple(_DMLS for _ in range(n)))) + else: + nr, nc = lhs.shape[-2], lhs.shape[-1] + lhs_qm = QMatrix( + lhs, + unit=UnitsMatrix(tuple(tuple(_DMLS for _ in range(nc)) for _ in range(nr))), + ) + return dot_general_qm_qm( + lhs_qm, + rhs, + dimension_numbers=dimension_numbers, + precision=precision, + preferred_element_type=preferred_element_type, + **kw, + ) + + +# ── transpose ──────────────────────────────────────────────────────────── + + +@quax.register(lax.transpose_p) +def transpose_qm(x: QMatrix, /, *, permutation: tuple[int, ...]) -> QMatrix: + """Transpose a ``QMatrix``, swapping only the last two (matrix) axes. + + Leading batch dimensions must be preserved unchanged. Only permutations + that swap the last two axes while keeping all batch axes in place are + supported, because the unit structure is purely 2-D and cannot represent + arbitrary axis re-orderings. + + >>> import jax.numpy as jnp + >>> import quaxed.numpy as qnp + >>> from coordinax.internal import QMatrix + + 2-D (no batch): + + >>> a = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]), + ... unit=(("m", "s"), ("kg", "rad"))) + >>> aT = qnp.matrix_transpose(a) + >>> aT.value + Array([[1., 3.], + [2., 4.]], dtype=float64) + >>> aT.unit.to_string() + '((m, kg), (s, rad))' + + Batched ``(B, N, M)`` — batch axis is preserved: + + >>> import jax + >>> b = QMatrix(jnp.ones((3, 2, 2)), + ... unit=(("m", "s"), ("kg", "rad"))) + >>> bT = qnp.matrix_transpose(b) + >>> bT.shape + (3, 2, 2) + + """ + ndim_val = len(permutation) # full ndim of the value array (includes batch dims) + if ndim_val < 2: + msg = f"transpose_qm requires ndim >= 2, got ndim={ndim_val}" + raise NotImplementedError(msg) + # Validate: batch axes must be unchanged, last two must be swapped. + expected = (*range(ndim_val - 2), ndim_val - 1, ndim_val - 2) + if tuple(permutation) != expected: + msg = ( + f"transpose_qm only supports matrix transpose of the last two axes " + f"(expected permutation {expected}), got {tuple(permutation)}" + ) + raise NotImplementedError(msg) + transposed_value = lax.transpose(x.value, permutation) + return QMatrix(value=transposed_value, unit=x.unit.T) + + +# ── gather ─────────────────────────────────────────────────────────────── + + +def _jit_fallback_uniform_unit(units: UnitsMatrix, out_size: int) -> UnitsMatrix: + """Return a 1-D ``UnitsMatrix`` of length *out_size* if all units are equal. + + Used as a JIT-mode fallback inside ``gather_qm`` when the concrete gather + indices are not available. Raises ``ValueError`` for heterogeneous inputs. + """ + all_units = jtu.tree_leaves(units.to_tuple()) + first = all_units[0] + if any(u_i != first for u_i in all_units[1:]): + msg = ( + "QMatrix gather (e.g. jnp.diag) under jit requires all units " + "to be equal when indices cannot be concretized. " + "Call eagerly (outside jit) for heterogeneous-unit QMatrix." + ) + raise ValueError(msg) + return UnitsMatrix(np.full((out_size,), first, dtype=object)) + + +@quax.register(lax.gather_p) +def gather_qm( + x: QMatrix, + start_indices: jax.Array, + /, + *, + dimension_numbers: lax.GatherDimensionNumbers, + slice_sizes: tuple[int, ...], + indices_are_sorted: bool = False, + mode: Any = None, + fill_value: Any = None, + unique_indices: bool = False, + **kwargs: Any, +) -> QMatrix: + """Handle element-selection gathers (e.g. ``jnp.diag``) for ``QMatrix``. + + Supports only *element-selection* gathers where every input dimension is + collapsed (``offset_dims == ()`` and all ``slice_sizes == 1``). This + covers ``jnp.diag``, ``jnp.diagonal``, and integer-array fancy indexing on + ``QMatrix`` objects. + + Unit extraction: + + ``QMatrix.unit`` is declared ``static=True`` and is therefore always + a concrete Python object, even inside ``jax.jit``. The *indices*, however, + are traced under JIT and cannot be read concretely. Because JAX's + ``jnp.diag`` uses ``platform_dependent`` internally, quax always traces + both branches via ``make_jaxpr``, so the JIT fallback path is taken for + unit resolution. Consequently, all units in the input must be equal; + heterogeneous-unit inputs raise ``ValueError``. + + >>> import jax.numpy as jnp + >>> from coordinax.internal import QMatrix + + Diagonal of a 3x3 dimensionless matrix: + + >>> A = QMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), + ... unit=(("", "", ""), ("", "", ""), ("", "", ""))) + >>> d = A.diag() + >>> d.unit.shape + (3,) + >>> d.unit.ndim + 1 + >>> d.value + Array([1., 4., 9.], dtype=float64) + + ```{note} + ``jnp.diag`` uses JAX's ``platform_dependent`` internally, which causes + quax to trace both branches via ``make_jaxpr`` even in eager mode. This + means the JIT fallback path is always taken for the unit computation, so + **heterogeneous-unit matrices are not supported** with ``qnp.diag``. + All units in the input must be equal; otherwise a ``ValueError`` is raised. + ``` + + """ + result_value = lax.gather( + x.value, + start_indices, + dimension_numbers, + slice_sizes, + indices_are_sorted=indices_are_sorted, + mode=mode, + fill_value=fill_value, + unique_indices=unique_indices, + ) + + # Only element-selection gathers are supported: all input dimensions must + # be collapsed and every slice_size must be 1. + n_input_dims = x.value.ndim + normalized_collapsed = { + d % n_input_dims for d in dimension_numbers.collapsed_slice_dims + } + is_element_selection = ( + dimension_numbers.offset_dims == () + and normalized_collapsed == set(range(n_input_dims)) + and all(s == 1 for s in slice_sizes) + ) + if not is_element_selection: + msg = ( + "QMatrix: only element-selection gathers (all input dims " + "collapsed, all slice_sizes == 1) are supported. " + f"Got offset_dims={dimension_numbers.offset_dims}, " + f"collapsed_slice_dims={dimension_numbers.collapsed_slice_dims}, " + f"slice_sizes={slice_sizes}." + ) + raise NotImplementedError(msg) + + # Number of output elements — start_indices.shape is always concrete in JAX. + out_size = start_indices.shape[0] + + if isinstance(start_indices, jax.core.Tracer): + # JIT path: indices are traced — fall back to uniform-unit check. + out_unit = _jit_fallback_uniform_unit(x.unit, out_size) + else: + # Eager path: indices are concrete — look up units directly. + idx_np = np.asarray(start_indices) + if x.unit.ndim == 1: + out_unit = UnitsMatrix(x.unit._units[idx_np[:, 0]]) + else: # x.unit.ndim == 2 + out_unit = UnitsMatrix(x.unit._units[idx_np[:, 0], idx_np[:, 1]]) + + return QMatrix(value=result_value, unit=out_unit) + + +# ── reduce_sum ─────────────────────────────────────────────────────────── + + +@quax.register(lax.reduce_sum_p) +def reduce_sum_p_qm(operand: QMatrix, /, *, axes: Any, **kwargs: Any) -> QMatrix: + """Handle ``lax.reduce_sum`` for ``QMatrix``. + + ``jnp.diag`` on a square 2-D matrix uses ``platform_dependent`` which traces + *both* the default (gather-based) and Mosaic implementation. The Mosaic + path computes ``reduce(mul(eye, A), axis=0)`` — JAX's JIT optimises + ``lax.reduce(x, 0, lax.add, (0,))`` to the simpler ``reduce_sum_p`` + primitive. This handler ensures the output carries the correct 1-D unit + structure so that both branches produce the *same* pytree — required by + ``platform_dependent`` / ``lax.switch``. + + Unit reduction rule: + + When reducing a 2-D ``QMatrix`` along ``axes=(0,)`` (rows): the + output unit for column *j* is taken from ``operand.unit[0, j]`` (the first + row). All elements being summed along a column must be unit-compatible for + the sum to be physically meaningful. + + Analogously for ``axes=(1,)`` (column reduction), the output unit for row + *i* is ``operand.unit[i, 0]``. + + >>> import jax.numpy as jnp + >>> from coordinax.internal import QMatrix + + ``QMatrix.diag()`` on a 3x3 uniform-unit matrix: + + >>> A = QMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])), + ... unit=(("m", "m", "m"), ("m", "m", "m"), ("m", "m", "m"))) + >>> d = A.diag() + >>> d.unit.shape + (3,) + >>> d.unit.ndim + 1 + + """ + result_value = lax.reduce_sum_p.bind(operand.value, axes=axes, **kwargs) + + # Reduce the unit structure by dropping the summed axes. + axset = set(axes) + if operand.ndim == 2: + if axset == {0}: + # Row reduction → 1-D output; unit = first row's units. + out_unit = UnitsMatrix(operand.unit._units[0]) + elif axset == {1}: + # Column reduction → 1-D output; unit = first column's units. + out_unit = UnitsMatrix(operand.unit._units[:, 0]) + else: + msg = f"reduce_sum_p_qm: unsupported axes={axes} for 2-D QMatrix." + raise NotImplementedError(msg) + else: + msg = ( + f"reduce_sum_p_qm: only 2-D QMatrix is supported, got ndim={operand.ndim}." + ) + raise NotImplementedError(msg) + + return QMatrix(value=result_value, unit=out_unit) diff --git a/src/coordinax/_src/internal/quantity_matrix/_units_matrix.py b/src/coordinax/_src/internal/quantity_matrix/_units_matrix.py new file mode 100644 index 00000000..07224143 --- /dev/null +++ b/src/coordinax/_src/internal/quantity_matrix/_units_matrix.py @@ -0,0 +1,380 @@ +"""Immutable, hashable unit structure for QMatrix.""" + +from typing import Any, TypeAlias, TypeVar, final + +import numpy as np +import plum + +import unxt as u + +T = TypeVar("T") + +NestedTuple: TypeAlias = T | tuple["NestedTuple[T]", ...] +UnitTree: TypeAlias = NestedTuple[u.AbstractUnit] + + +def _normalize_unit(x: Any, /) -> u.AbstractUnit: + """Convert *x* to an ``AbstractUnit``; accept unit strings and AbstractUnit. + + Raises ``TypeError`` for unsupported types. + """ + if isinstance(x, str): + return u.unit(x) # ty: ignore[invalid-return-type] + if isinstance(x, u.AbstractUnit): + return x + msg = f"Expected an AbstractUnit or unit string; got {type(x).__name__!r}" + raise TypeError(msg) + + +def _build_object_array(iterable: Any, /) -> np.ndarray: # noqa: C901 + """Build a 1-D or 2-D numpy object array of ``AbstractUnit`` from *iterable*. + + Accepts: + + - A numpy object array (element-normalize and validate ndim). + - A plain tuple/list of units or unit strings → 1-D output. + - A plain tuple/list of tuples of units or unit strings → 2-D output. + + Raises ``TypeError`` if a non-sequence is passed, ``ValueError`` if the + structure is ragged or has unsupported ndim. + """ + if isinstance(iterable, np.ndarray) and iterable.dtype == object: + if iterable.ndim not in (1, 2): + msg = f"UnitsMatrix only supports 1D or 2D; got ndim={iterable.ndim}" + raise ValueError(msg) + flat = [_normalize_unit(v) for v in iterable.flat] + data: np.ndarray = np.empty(iterable.shape, dtype=object) + data.flat[:] = flat + return data + + # Sequence path: tuple, list, or any iterable + items = list(iterable) # raises TypeError if not iterable + + if not items: + raise ValueError("UnitsMatrix requires at least one element") + + first = items[0] + if isinstance(first, (tuple, list)): + # 2-D: sequence of rows — validate and fill in one pass + n, m = len(items), len(first) + data = np.empty((n, m), dtype=object) + for i, row in enumerate(items): + if not isinstance(row, (tuple, list)) or len(row) != m: + raise ValueError("ragged structure") + for j, v in enumerate(row): + if isinstance(v, (tuple, list)): + raise ValueError("ragged structure") # noqa: TRY004 + data[i, j] = _normalize_unit(v) + return data + + # 1-D: sequence of units / unit strings + n = len(items) + data = np.empty(n, dtype=object) + for i, v in enumerate(items): + if isinstance(v, (tuple, list)): # Mixed leaf/nested → ragged + raise ValueError("ragged structure") # noqa: TRY004 + data[i] = _normalize_unit(v) + return data + + +@final +class UnitsMatrix: + """Immutable, hashable unit structure for `QMatrix`. + + `UnitsMatrix` wraps a numpy object array (``dtype=object``) of + `~unxt.AbstractUnit` elements. Only 1-D and 2-D structures are accepted. + + The class supports tuple-style indexing, iteration, `to_tuple()`, and + `to_string()`. It is **not** a subclass of `astropy.units.StructuredUnit`; + bidirectional converters to/from ``StructuredUnit`` are provided in + ``coordinax.interop.astropy``. + + Hashability is achieved via ``hash(self.to_tuple())``, so the underlying + ``AbstractUnit`` objects must themselves be hashable (they are). + + For 1D: ``UnitsMatrix(("m", "s", "kg"))`` + For 2D: ``UnitsMatrix((("m", "s"), ("kg", "rad")))`` + + Examples + -------- + >>> import unxt as u + >>> from coordinax.internal import UnitsMatrix + + 1D case: + + >>> units_1d = UnitsMatrix(("m", "s", "kg")) + >>> units_1d.shape + (3,) + >>> units_1d[0] + Unit("m") + >>> units_1d.to_string() + '(m, s, kg)' + + 2D case: + + >>> units_2d = UnitsMatrix((("m", "s"), ("kg", "rad"))) + >>> units_2d.shape + (2, 2) + >>> units_2d[0, 1] + Unit("s") + >>> units_2d.to_string() + '((m, s), (kg, rad))' + + """ + + __slots__ = ("_units",) + + def __init__(self, iterable: Any, /) -> None: + if isinstance(iterable, UnitsMatrix): + # Copy from another UnitsMatrix — avoids sharing the mutable array. + data = iterable._units.copy() + else: + data = _build_object_array(iterable) + if data.ndim not in (1, 2): + msg = f"UnitsMatrix only supports 1D or 2D, but got ndim={data.ndim}" + raise ValueError(msg) + self._units = data + + # ── Shape / structure ───────────────────────────────────────────── + + @property + def shape(self) -> tuple[int, ...]: + """Shape of the N-D unit structure.""" + return tuple(self._units.shape) + + @property + def ndim(self) -> int: + """Number of dimensions.""" + return int(self._units.ndim) + + @property + def T(self) -> "UnitsMatrix": + """Compute the all-axis units array transpose. + + Examples + -------- + >>> from coordinax.internal import UnitsMatrix + + >>> units = UnitsMatrix(("m", "s")) + >>> units.T + UnitsMatrix("(m, s)") + + >>> units = UnitsMatrix((("m", "s"), ("kg", "rad"))) + >>> units.T + UnitsMatrix("((m, kg), (s, rad))") + + >>> units = UnitsMatrix((("m", "s", "kg"), ("Hz", "candela", "km"))) + >>> units.T + UnitsMatrix("((m, Hz), (s, cd), (kg, km))") + + """ + return UnitsMatrix(self._units.T) + + def inverse(self) -> "UnitsMatrix": + r"""Inverse unit structure — each unit raised to the power -1. + + For a 1-D (diagonal) ``UnitsMatrix`` the inversion is done + entry-by-entry in *O(n)*, providing a speedup over the general 2-D + case. For a 2-D ``UnitsMatrix`` with a uniform unit (all entries + equal) the reciprocal is computed once and broadcast in *O(1)*; + mixed-unit 2-D structures fall back to an element-wise *O(nm)* loop. + + Examples + -------- + >>> from coordinax.internal import UnitsMatrix + + 1-D (diagonal) case — element-wise reciprocal: + + >>> UnitsMatrix(("m2", "s2")).inverse() + UnitsMatrix("(1 / m2, 1 / s2)") + + 2-D uniform-unit case: + + >>> UnitsMatrix((("m2", "m2"), ("m2", "m2"))).inverse() + UnitsMatrix("((1 / m2, 1 / m2), (1 / m2, 1 / m2))") + + 2-D mixed-unit case: + + >>> UnitsMatrix((("m2", "s2"), ("s2", "rad2"))).inverse() + UnitsMatrix("((1 / m2, 1 / s2), (1 / s2, 1 / rad2))") + + """ + inv_data = np.empty(self._units.shape, dtype=object) + if self._units.ndim == 1: + # Diagonal speedup: 1-D represents a diagonal metric's units. + for i in range(self._units.shape[0]): + inv_data[i] = self._units[i] ** (-1) + else: + # 2-D: fast path when all entries share the same unit. + flat = self._units.ravel() + first = flat[0] + if all(u == first for u in flat[1:]): + inv_unit = first ** (-1) + inv_data[:] = inv_unit + else: + n, m = self._units.shape + for i in range(n): + for j in range(m): + inv_data[i, j] = self._units[i, j] ** (-1) + return UnitsMatrix(inv_data) + + # ── Serialization ───────────────────────────────────────────────── + + def to_tuple(self) -> UnitTree: + """Convert to a nested tuple of `~unxt.AbstractUnit` objects. + + Examples + -------- + >>> from coordinax.internal import UnitsMatrix + >>> import unxt as u + >>> UnitsMatrix(("m", "s")).to_tuple() + (Unit("m"), Unit("s")) + + """ + if self._units.ndim == 1: + return tuple(self._units) + return tuple(map(tuple, self._units)) + + def to_string(self) -> str: + """Return a human-readable string representation of the unit structure. + + Examples + -------- + >>> from coordinax.internal import UnitsMatrix + >>> UnitsMatrix(("m", "s", "kg")).to_string() + '(m, s, kg)' + >>> UnitsMatrix((("m", "s"), ("kg", "rad"))).to_string() + '((m, s), (kg, rad))' + + """ + if self._units.ndim == 1: + inner = ", ".join(str(x) for x in self._units) + if len(self._units) == 1: + return f"({inner},)" + return f"({inner})" + # 2D + row_strs = [] + for row in self._units: + inner = ", ".join(str(x) for x in row) + row_strs.append(f"({inner},)" if len(row) == 1 else f"({inner})") + if len(self._units) == 1: + return f"({row_strs[0]},)" + return f"({', '.join(row_strs)})" + + # ── Python data model ───────────────────────────────────────────── + + def __repr__(self) -> str: + return f'UnitsMatrix("{self.to_string()}")' + + def __eq__(self, other: Any, /) -> bool: + if isinstance(other, UnitsMatrix): + if self._units.shape != other._units.shape: + return False + return bool(np.all(self._units == other._units)) + if isinstance(other, (tuple, list)): + try: + return self == UnitsMatrix(other) + except (TypeError, ValueError): + return False + return NotImplemented + + def __hash__(self) -> int: + return hash(self.to_tuple()) + + def __iter__(self) -> Any: + """Iterate over elements (1D) or row ``UnitsMatrix`` objects (2D). + + Examples + -------- + >>> from coordinax.internal import UnitsMatrix + >>> list(UnitsMatrix(("m", "rad", "rad"))) + [Unit("m"), Unit("rad"), Unit("rad")] + + """ + if self._units.ndim == 1: + yield from self._units + return + for row in self._units: + yield UnitsMatrix(row) + + def __getitem__(self, index: Any, /) -> Any: + """Index into the UnitsMatrix to retrieve a unit or sub-structure. + + >>> from coordinax.internal import UnitsMatrix + >>> units = UnitsMatrix((("m", "s"), ("kg", "rad"))) + + Indexing a single element returns a unit: + + >>> units[0, 1] + Unit("s") + + Indexing a row returns a UnitsMatrix: + + >>> units[0] + UnitsMatrix("(m, s)") + + """ + result = self._units[index] + if isinstance(result, np.ndarray): + if result.ndim == 0: # 0-d array -> extract the contained unit. + return result.item() + return UnitsMatrix(result) + return result + + +@plum.dispatch +def unit(tuple_of_units: tuple[Any, ...], /) -> UnitsMatrix: + """Convert a nested tuple of units into a ``UnitsMatrix``. + + This allows users to specify units in a convenient nested tuple format when + constructing ``QMatrix`` instances, and have them automatically + converted to the appropriate ``UnitsMatrix``. + + >>> import unxt as u + + 1D case: + + >>> u.unit(("m", "s", "kg")) + UnitsMatrix("(m, s, kg)") + + 2D case: + + >>> u.unit((("m", "s"), ("kg", "rad"))) + UnitsMatrix("((m, s), (kg, rad))") + + """ + return UnitsMatrix(tuple_of_units) + + +@plum.dispatch +def unit(arr: np.ndarray, /) -> UnitsMatrix: + """Convert a numpy object array of units into a ``UnitsMatrix``. + + >>> import numpy as np + >>> import unxt as u + >>> from coordinax.internal import UnitsMatrix + >>> arr = np.array([u.unit("m"), u.unit("s")], dtype=object) + >>> u.unit(arr) + UnitsMatrix("(m, s)") + + """ + return UnitsMatrix(arr) + + +@plum.dispatch +def unit(obj: UnitsMatrix, /) -> UnitsMatrix: + """Identity: a UnitsMatrix is returned unchanged by the unit converter.""" + return obj + + +@plum.dispatch +def unit_of(obj: UnitsMatrix, /) -> UnitsMatrix: + """Identity conversion for UnitsMatrix to itself. + + >>> import unxt as u + >>> unit = u.unit(("m", "s", "kg")) + >>> u.unit_of(unit) is unit + True + + """ + return obj diff --git a/src/coordinax/_src/internal/quantity_matrix/_utils.py b/src/coordinax/_src/internal/quantity_matrix/_utils.py new file mode 100644 index 00000000..e554b594 --- /dev/null +++ b/src/coordinax/_src/internal/quantity_matrix/_utils.py @@ -0,0 +1,31 @@ +"""Shared utilities for the quantity_matrix package.""" + +from typing import Any, TypeAlias, cast + +import unxt as u + +CDict: TypeAlias = dict[str, Any] +_DMLS = u.unit("") + +PackedUnitOutput: TypeAlias = tuple[u.AbstractUnit | None, ...] + + +def strict_zip(*args: Any) -> zip: + """Zip iterables while enforcing equal lengths.""" + return zip(*args, strict=True) + + +def cdict_units(p: CDict, keys: tuple[str, ...], /) -> PackedUnitOutput: + """Extract per-key units from a component dictionary. + + Non-quantity entries yield `None`, so the output tuple can be used for + heterogeneous dictionaries containing both quantity and non-quantity data. + + >>> import unxt as u + >>> d = {'x': u.Q(1.0, 'm'), 'y': 2.0, 'z': u.Q(3.0, 'kg')} + >>> cdict_units(d, ('x', 'y', 'z')) + (Unit("m"), None, Unit("kg")) + + """ + # `unit_of()` returns None for non-quantities, so this works for both cases. + return cast("PackedUnitOutput", tuple(u.unit_of(p[k]) for k in keys)) diff --git a/tests/unit/internal/test_quantity_matrix.py b/tests/unit/internal/test_quantity_matrix.py index 2a752f25..3a9c93c5 100644 --- a/tests/unit/internal/test_quantity_matrix.py +++ b/tests/unit/internal/test_quantity_matrix.py @@ -1,4 +1,4 @@ -"""Tests for coordinax._src.quantity_matrix.QuantityMatrix.""" +"""Tests for coordinax._src.quantity_matrix.QMatrix.""" import math @@ -17,8 +17,10 @@ from coordinax._src.internal.quantity_matrix import ( _convert_value_matrix, _convert_value_vector, + det as qm_det, + inv as qm_inv, ) -from coordinax.internal import QuantityMatrix as QMat, UnitsMatrix +from coordinax.internal import QMatrix as QMat, UnitsMatrix # --------------------------------------------------------------------------- # Unit shorthands (visual noise reduction) @@ -48,11 +50,8 @@ def unit_2x2(): @pytest.fixture def qm_2x2(unit_2x2): - """Return a 2x2 QuantityMatrix with values 1-4.""" - return QMat( - value=jnp.array([[1.0, 2.0], [3.0, 4.0]]), - unit=unit_2x2, - ) + """Return a 2x2 QMatrix with values 1-4.""" + return QMat(value=jnp.array([[1, 2], [3, 4]]), unit=unit_2x2) @pytest.fixture @@ -69,11 +68,8 @@ def unit_1d(): @pytest.fixture def qm_1d(unit_1d): - """Return a 1D QuantityMatrix (vector) with values 1-3.""" - return QMat( - value=jnp.array([1.0, 2.0, 3.0]), - unit=unit_1d, - ) + """Return a 1D QMatrix (vector) with values 1-3.""" + return QMat(value=jnp.array([1, 2, 3]), unit=unit_1d) @pytest.fixture @@ -88,7 +84,7 @@ def unit_1d_alt(): class TestConstruction: - """Tests for QuantityMatrix construction and basic properties.""" + """Tests for QMatrix construction and basic properties.""" def test_shape(self, qm_2x2): assert qm_2x2.shape == (2, 2) @@ -100,7 +96,7 @@ def test_n_cols(self, qm_2x2): assert qm_2x2.shape[-1] == 2 def test_value(self, qm_2x2): - expected = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + expected = jnp.array([[1, 2], [3, 4]]) assert jnp.array_equal(qm_2x2.value, expected) def test_unit(self, qm_2x2, unit_2x2): @@ -124,7 +120,7 @@ def test_batch_dims(self): def test_1x1(self): """Degenerate 1x1 matrix.""" - qm = QMat(jnp.array([[42.0]]), unit=((_m,),)) + qm = QMat(jnp.array([[42]]), unit=((_m,),)) assert qm.shape[-2] == 1 assert qm.shape[-1] == 1 @@ -140,7 +136,7 @@ def test_unit_is_unitsmatrix(self, qm_2x2): def test_unit_converter_from_plain_tuples(self): """Plain nested tuples (of strings) are converted to ``UnitsMatrix``.""" - qm = QMat(jnp.array([[1.0]]), unit=(("m",),)) + qm = QMat(jnp.array([[1]]), unit=(("m",),)) assert isinstance(qm.unit, UnitsMatrix) assert qm.unit[0, 0] == _m @@ -153,12 +149,12 @@ def test_1d_construction(self, qm_1d, unit_1d): def test_1d_value(self, qm_1d): """1D vector value.""" - expected = jnp.array([1.0, 2.0, 3.0]) + expected = jnp.array([1, 2, 3]) assert jnp.array_equal(qm_1d.value, expected) def test_1d_from_strings(self): """1D vector from unit strings.""" - qm = QMat(jnp.array([7.0, 8.0]), unit=("m", "s")) + qm = QMat(jnp.array([7, 8]), unit=("m", "s")) assert isinstance(qm.unit, UnitsMatrix) assert qm.unit[0] == _m assert qm.unit[1] == _s @@ -178,16 +174,16 @@ def test_ndim_property_2d(self, qm_2x2): assert qm_2x2.ndim == 2 def test_repr(self, qm_2x2): - """``repr(QuantityMatrix(...))`` succeeds and contains key info.""" + """``repr(QMatrix(...))`` succeeds and contains key info.""" r = repr(qm_2x2) - assert "QuantityMatrix" in r + assert "QMatrix" in r assert "((m, s), (kg, rad))" in r def test_repr_1x1(self): """Repr for a 1x1 matrix includes trailing-comma tuple syntax.""" - qm = QMat(jnp.array([[42.0]]), unit=((_m,),)) + qm = QMat(jnp.array([[42]]), unit=((_m,),)) r = repr(qm) - assert "QuantityMatrix" in r + assert "QMatrix" in r assert "((m,),)" in r @@ -379,31 +375,31 @@ class TestConvertValueMatrix: def test_noop_same_units(self, unit_2x2): """If from_units == to_units no conversion happens.""" - val = jnp.array([[7.0, 8.0], [9.0, 10.0]]) + val = jnp.array([[7, 8], [9, 10]]) out = _convert_value_matrix(val, unit_2x2, unit_2x2) assert jnp.array_equal(out, val) def test_km_to_m(self): """1 km → 1000 m.""" - out = _convert_value_matrix(jnp.array([[3.0]]), ((_km,),), ((_m,),)) - assert jnp.isclose(out[0, 0], 3000.0) + out = _convert_value_matrix(jnp.array([[3]]), ((_km,),), ((_m,),)) + assert jnp.isclose(out[0, 0], 3000) def test_mixed_conversion(self, unit_2x2, unit_2x2_alt): """Convert from (km, ms, g, deg) → (m, s, kg, rad).""" - val = jnp.array([[1.0, 1000.0], [3000.0, 180.0]]) + val = jnp.array([[1, 1000], [3000, 180]]) out = _convert_value_matrix(val, unit_2x2_alt, unit_2x2) - assert jnp.isclose(out[0, 0], 1000.0) # 1 km -> 1000 m - assert jnp.isclose(out[0, 1], 1.0) # 1000 ms -> 1 s - assert jnp.isclose(out[1, 0], 3.0) # 3000 g -> 3 kg + assert jnp.isclose(out[0, 0], 1000) # 1 km -> 1000 m + assert jnp.isclose(out[0, 1], 1) # 1000 ms -> 1 s + assert jnp.isclose(out[1, 0], 3) # 3000 g -> 3 kg assert jnp.isclose(out[1, 1], math.pi, atol=1e-4) # 180 deg -> pi rad def test_preserves_batch(self): """Batch dimensions are preserved.""" - val = jnp.array([[[2.0]], [[5.0]]]) # (2, 1, 1) + val = jnp.array([[[2]], [[5]]]) # (2, 1, 1) out = _convert_value_matrix(val, ((_km,),), ((_m,),)) assert out.shape == (2, 1, 1) - assert jnp.isclose(out[0, 0, 0], 2000.0) - assert jnp.isclose(out[1, 0, 0], 5000.0) + assert jnp.isclose(out[0, 0, 0], 2000) + assert jnp.isclose(out[1, 0, 0], 5000) class TestConvertValuePoint: @@ -417,24 +413,24 @@ def test_noop_same_units(self, unit_1d): def test_km_to_m(self): """1 km → 1000 m.""" - out = _convert_value_vector(jnp.array([3.0]), (_km,), (_m,)) - assert jnp.isclose(out[0], 3000.0) + out = _convert_value_vector(jnp.array([3]), (_km,), (_m,)) + assert jnp.isclose(out[0], 3000) def test_mixed_conversion(self, unit_1d, unit_1d_alt): """Convert from (km, ms, g) → (m, s, kg).""" - val = jnp.array([1.0, 1000.0, 3000.0]) + val = jnp.array([1, 1000, 3000]) out = _convert_value_vector(val, unit_1d_alt, unit_1d) - assert jnp.isclose(out[0], 1000.0) # 1 km -> 1000 m - assert jnp.isclose(out[1], 1.0) # 1000 ms -> 1 s - assert jnp.isclose(out[2], 3.0) # 3000 g -> 3 kg + assert jnp.isclose(out[0], 1000) # 1 km -> 1000 m + assert jnp.isclose(out[1], 1) # 1000 ms -> 1 s + assert jnp.isclose(out[2], 3) # 3000 g -> 3 kg def test_preserves_batch(self): """Batch dimensions are preserved.""" - val = jnp.array([[2.0], [5.0]]) # (2, 1) + val = jnp.array([[2], [5]]) # (2, 1) out = _convert_value_vector(val, (_km,), (_m,)) assert out.shape == (2, 1) - assert jnp.isclose(out[0, 0], 2000.0) - assert jnp.isclose(out[1, 0], 5000.0) + assert jnp.isclose(out[0, 0], 2000) + assert jnp.isclose(out[1, 0], 5000) # --------------------------------------------------------------------------- @@ -448,16 +444,13 @@ def _add(a, b): class TestAddition: - """Tests for QuantityMatrix + QuantityMatrix.""" + """Tests for QMatrix + QMatrix.""" def test_same_units(self, qm_2x2, unit_2x2): """Simple add, same units.""" - other = QMat( - value=jnp.array([[10.0, 20.0], [30.0, 40.0]]), - unit=unit_2x2, - ) + other = QMat(value=jnp.array([[10, 20], [30, 40]]), unit=unit_2x2) result = _add(qm_2x2, other) - expected = jnp.array([[11.0, 22.0], [33.0, 44.0]]) + expected = jnp.array([[11, 22], [33, 44]]) assert jnp.allclose(result.value, expected) assert result.unit == unit_2x2 @@ -561,7 +554,7 @@ def _sub(a, b): class TestSubtraction: - """Tests for QuantityMatrix - QuantityMatrix.""" + """Tests for QMatrix - QMatrix.""" def test_same_units(self, qm_2x2, unit_2x2): """Simple sub, same units.""" @@ -659,7 +652,7 @@ def _matmul(a, b): class TestDotProduct: - """Tests for QuantityMatrix @ QuantityMatrix.""" + """Tests for QMatrix @ QMatrix.""" def test_simple_matmul_uniform_units(self): """2x2 @ 2x1 with uniform units along contraction axis.""" @@ -747,9 +740,9 @@ def test_1d_dot_product_uniform_units(self): a = QMat(jnp.array([2.0, 3.0]), unit=(_m, _m)) b = QMat(jnp.array([4.0, 5.0]), unit=(_s, _s)) result = _matmul(a, b) - # Result should be a scalar Quantity, not a QuantityMatrix + # Result should be a scalar Quantity, not a QMatrix # 2*4 + 3*5 = 8 + 15 = 23 in m*s - assert isinstance(result, u.Quantity) + assert isinstance(result, u.Q) assert jnp.isclose(result.value, 23.0) assert result.unit == _m * _s @@ -760,7 +753,7 @@ def test_1d_dot_product_mixed_units(self): a = QMat(jnp.array([1.0, 1.0]), unit=(_m, _km)) b = QMat(jnp.array([1.0, 1.0]), unit=(_s, _s)) result = _matmul(a, b) - assert isinstance(result, u.Quantity) + assert isinstance(result, u.Q) assert jnp.isclose(result.value, 1001.0) assert result.unit == _m * _s @@ -787,7 +780,7 @@ def dot_batched(x, y): class TestJaxIntegration: - """QuantityMatrix works with JAX transformations.""" + """QMatrix works with JAX transformations.""" def test_jit_add(self, unit_2x2): """jit-compiled addition.""" @@ -805,7 +798,7 @@ def test_jit_matmul(self): assert jnp.isclose(result.value[0, 0], 23.0) def test_pytree_flatten_unflatten(self, qm_2x2, unit_2x2): - """QuantityMatrix is a proper PyTree.""" + """QMatrix is a proper PyTree.""" leaves, treedef = jax.tree.flatten(qm_2x2) restored = jax.tree.unflatten(treedef, leaves) assert jnp.array_equal(restored.value, qm_2x2.value) @@ -831,41 +824,35 @@ def add_batched(x, y): class TestPlumConversion: - """Tests for ``plum.convert`` registrations involving ``QuantityMatrix``.""" + """Tests for ``plum.convert`` registrations involving ``QMatrix``.""" - def test_quantitymatrix_to_quantity_uniform_1d(self): - """1D uniform-unit ``QuantityMatrix`` converts to ``u.Quantity``.""" + def test_QMatrix_to_quantity_uniform_1d(self): + """1D uniform-unit ``QMatrix`` converts to ``u.Q``.""" qm = QMat(value=jnp.array([1.0, 2.0, 3.0]), unit=(_m, _m, _m)) - result = plum.convert(qm, u.Quantity) + result = plum.convert(qm, u.Q) - assert isinstance(result, u.Quantity) + assert isinstance(result, u.Q) assert result.unit == _m assert jnp.array_equal(result.value, qm.value) - def test_quantitymatrix_to_quantity_uniform_2d(self): - """2D uniform-unit ``QuantityMatrix`` converts to ``u.Quantity``.""" - qm = QMat( - value=jnp.array([[1.0, 2.0], [3.0, 4.0]]), - unit=((_s, _s), (_s, _s)), - ) + def test_QMatrix_to_quantity_uniform_2d(self): + """2D uniform-unit ``QMatrix`` converts to ``u.Q``.""" + qm = QMat(value=jnp.array([[1, 2], [3, 4]]), unit=((_s, _s), (_s, _s))) - result = plum.convert(qm, u.Quantity) + result = plum.convert(qm, u.Q) - assert isinstance(result, u.Quantity) + assert isinstance(result, u.Q) assert result.unit == _s assert result.shape == (2, 2) assert jnp.array_equal(result.value, qm.value) - def test_quantitymatrix_to_quantity_heterogeneous_units_raises(self): - """Mixed units cannot be converted to a single ``u.Quantity``.""" + def test_QMatrix_to_quantity_heterogeneous_units_raises(self): + """Mixed units cannot be converted to a single ``u.Q``.""" qm = QMat(value=jnp.array([1.0, 2.0]), unit=(_m, _s)) - with pytest.raises( - ValueError, - match="all units are identical", - ): - plum.convert(qm, u.Quantity) + with pytest.raises(ValueError, match="all units are identical"): + plum.convert(qm, u.Q) # --------------------------------------------------------------------------- @@ -933,7 +920,7 @@ def test_kelvin_times_s_is_convertible(self): class TestMatVec: - """Tests for 2D `QuantityMatrix` @ 1D `QuantityMatrix`.""" + """Tests for 2D `QMatrix` @ 1D `QMatrix`.""" def test_identity_uniform_units(self): """Identity 3x3 @ uniform-unit vector → same vector.""" @@ -953,8 +940,8 @@ def test_uniform_units_values(self): assert jnp.isclose(w.value[0], 80.0) assert jnp.isclose(w.value[1], 140.0) - def test_output_is_1d_quantitymatrix(self): - """Result of 2D @ 1D is a 1D ``QuantityMatrix``, not a 2D one.""" + def test_output_is_1d_QMatrix(self): + """Result of 2D @ 1D is a 1D ``QMatrix``, not a 2D one.""" A = QMat(jnp.array([[1.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) v = QMat(jnp.array([3.0, 7.0]), unit=(_s, _s)) w = _matmul(A, v) @@ -1130,7 +1117,7 @@ def test_diag_dimensionless_unit_string(self): assert s.count("(") == 1, f"Expected 1D unit string, got: {s!r}" def test_diag_under_jit_uniform_units(self): - """jnp.diag under jit works for uniform-unit QuantityMatrix.""" + """jnp.diag under jit works for uniform-unit QMatrix.""" A = QMat( jnp.diag(jnp.array([1.0, 4.0, 9.0])), unit=((_m, _m, _m), (_m, _m, _m), (_m, _m, _m)), @@ -1143,12 +1130,12 @@ def test_diag_under_jit_uniform_units(self): # --------------------------------------------------------------------------- -# QuantityMatrix.diag() method +# QMatrix.diag() method # --------------------------------------------------------------------------- class TestDiagMethod: - """Tests for ``QuantityMatrix.diag()`` — the method that bypasses JAX gather.""" + """Tests for ``QMatrix.diag()`` — the method that bypasses JAX gather.""" def test_uniform_units_values(self): """Diagonal values are correct for a uniform-unit matrix.""" @@ -1238,7 +1225,7 @@ def test_non_square_picks_min_dim(self): assert d.unit[1] == _km # unit[1,1] def test_1d_raises(self, qm_1d): - """Calling .diag() on a 1-D QuantityMatrix raises ValueError.""" + """Calling .diag() on a 1-D QMatrix raises ValueError.""" with pytest.raises(ValueError, match="2D"): qm_1d.diag() @@ -1280,6 +1267,69 @@ def test_batch_dimensions(self): assert jnp.isclose(d.value[2, 1], 12.0) +# --------------------------------------------------------------------------- +# UnitsMatrix.inverse +# --------------------------------------------------------------------------- + + +class TestUnitsMatrixInverse: + """Tests for ``UnitsMatrix.inverse`` — element-wise unit reciprocal.""" + + def test_1d_values(self): + """1-D: each unit is reciprocated.""" + um = UnitsMatrix((_m, _s)) + inv = um.inverse() + assert inv[0] == _m ** (-1) + assert inv[1] == _s ** (-1) + + def test_1d_shape_preserved(self): + """1-D inverse has the same shape.""" + um = UnitsMatrix((_m, _s, _kg)) + assert um.inverse().shape == (3,) + + def test_1d_returns_unitsmatrix(self): + """Result is a ``UnitsMatrix`` instance.""" + um = UnitsMatrix((_m, _s)) + assert isinstance(um.inverse(), UnitsMatrix) + + def test_1d_double_inverse_is_identity(self): + """Two inversions return the original unit structure.""" + um = UnitsMatrix((_m, _s, _kg)) + assert um == um.inverse().inverse() + + def test_2d_uniform_values(self): + """2-D uniform matrix: every entry is reciprocated.""" + um = UnitsMatrix(((_m, _m), (_m, _m))) + inv = um.inverse() + for i in range(2): + for j in range(2): + assert inv[i, j] == _m ** (-1) + + def test_2d_uniform_shape_preserved(self): + """2-D uniform inverse has the same shape.""" + um = UnitsMatrix(((_m, _m), (_m, _m))) + assert um.inverse().shape == (2, 2) + + def test_2d_mixed_values(self): + """2-D mixed-unit matrix: element-wise reciprocal.""" + um = UnitsMatrix(((_m, _s), (_kg, _rad))) + inv = um.inverse() + assert inv[0, 0] == _m ** (-1) + assert inv[0, 1] == _s ** (-1) + assert inv[1, 0] == _kg ** (-1) + assert inv[1, 1] == _rad ** (-1) + + def test_2d_returns_unitsmatrix(self): + """Result is always a ``UnitsMatrix`` instance.""" + um = UnitsMatrix(((_m, _s), (_kg, _rad))) + assert isinstance(um.inverse(), UnitsMatrix) + + def test_2d_double_inverse_is_identity(self): + """Two inversions return the original unit structure.""" + um = UnitsMatrix(((_m, _s), (_kg, _rad))) + assert um == um.inverse().inverse() + + # --------------------------------------------------------------------------- # UnitsMatrix.T # --------------------------------------------------------------------------- @@ -1341,7 +1391,7 @@ def test_1d_transpose_is_identity(self): # --------------------------------------------------------------------------- -# QuantityMatrix.T +# QMatrix.T # --------------------------------------------------------------------------- @@ -1350,8 +1400,8 @@ def _transpose(x): return x.T -class TestQuantityMatrixTranspose: - """Tests for ``QuantityMatrix.T`` — the matrix transpose property.""" +class TestQMatrixTranspose: + """Tests for ``QMatrix.T`` — the matrix transpose property.""" # -- Basic 2D values and units ---------------------------------------- @@ -1378,8 +1428,8 @@ def test_2d_square_shape_preserved(self, qm_2x2): """Shape is unchanged for a square matrix.""" assert qm_2x2.T.shape == (2, 2) - def test_2d_square_returns_quantitymatrix(self, qm_2x2): - """Result is a ``QuantityMatrix`` instance.""" + def test_2d_square_returns_QMatrix(self, qm_2x2): + """Result is a ``QMatrix`` instance.""" assert isinstance(qm_2x2.T, QMat) def test_2d_nonsquare_values(self): @@ -1494,9 +1544,291 @@ def test_jit_units(self, qm_2x2): # -- 1-D error -------------------------------------------------------- def test_1d_raises(self, qm_1d): - """Accessing ``.T`` on a 1-D ``QuantityMatrix`` raises ``ValueError``. + """Accessing ``.T`` on a 1-D ``QMatrix`` raises ``ValueError``. The ``.T`` property requires a 2-D unit structure to unpack ``(n, m)``. """ with pytest.raises(ValueError, match="requires a 2-D matrix"): _ = qm_1d.T + + +# --------------------------------------------------------------------------- +# det_p custom primitive + Quax dispatch for QMatrix +# --------------------------------------------------------------------------- + + +class TestDetPrimitive: + """Tests for the custom ``det_p`` JAX primitive on plain arrays.""" + + def test_det_2x2_diagonal(self): + """Det on a 2×2 diagonal matrix returns the product of the diagonals.""" + A = jnp.array([[2.0, 0.0], [0.0, 3.0]]) + assert jnp.allclose(qm_det(A), jnp.linalg.det(A)) + + def test_det_3x3_identity(self): + """det(I_3 * 2) == 8.""" + A = jnp.eye(3) * 2.0 + assert jnp.allclose(qm_det(A), 8.0) + + def test_det_matches_jnp_linalg_det(self): + """det_p gives the same result as jnp.linalg.det for a generic matrix.""" + A = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + assert jnp.allclose(qm_det(A), jnp.linalg.det(A)) + + def test_det_jit(self): + """det_p works under jax.jit.""" + A = jnp.array([[2.0, 0.0], [0.0, 3.0]]) + result = jax.jit(qm_det)(A) + assert jnp.allclose(result, 6.0) + + def test_det_jvp(self): + """Forward-mode derivative matches Jacobi's formula. + + d det(A)(dA) = det(A) * tr(A^-1 dA). + """ + A = jnp.array([[2.0, 0.0], [0.0, 3.0]]) + dA = jnp.ones((2, 2)) + primal, tangent = jax.jvp(qm_det, (A,), (dA,)) + # det(A) = 6 + # tr(A⁻¹ dA) = tr([[0.5,0],[0,1/3]] @ [[1,1],[1,1]]) = 0.5 + 1/3 = 5/6 + # tangent = 6 * 5/6 = 5 + assert jnp.allclose(primal, 6.0) + assert jnp.allclose(tangent, 5.0) + + def test_det_grad(self): + """Reverse-mode gradient matches Jacobi's formula: ∂det(A)/∂A = det(A)·A⁻ᵀ.""" + A = jnp.array([[2.0, 0.0], [0.0, 3.0]]) + # grad_A = det(A) * A^{-T} = 6 * diag(0.5, 1/3) = diag(3, 2) + grad_A = jax.grad(qm_det)(A) + expected = jnp.array([[3.0, 0.0], [0.0, 2.0]]) + assert jnp.allclose(grad_A, expected) + + def test_det_jit_grad(self): + """jit(grad(det)) works correctly.""" + A = jnp.array([[2.0, 0.0], [0.0, 3.0]]) + grad_A = jax.jit(jax.grad(qm_det))(A) + expected = jnp.array([[3.0, 0.0], [0.0, 2.0]]) + assert jnp.allclose(grad_A, expected) + + def test_det_vmap(self): + """det_p works under jax.vmap — maps over a batch of matrices.""" + A = jnp.stack( + [jnp.diag(jnp.array([2.0, 3.0])), jnp.diag(jnp.array([4.0, 5.0]))] + ) + results = jax.vmap(qm_det)(A) + expected = jnp.array([6.0, 20.0]) + assert jnp.allclose(results, expected) + + def test_det_jit_vmap(self): + """jit(vmap(det)) works correctly.""" + A = jnp.stack( + [jnp.diag(jnp.array([2.0, 3.0])), jnp.diag(jnp.array([4.0, 5.0]))] + ) + results = jax.jit(jax.vmap(qm_det))(A) + expected = jnp.array([6.0, 20.0]) + assert jnp.allclose(results, expected) + + def test_det_batched_shape(self): + """det_p on a (*batch, n, n) array returns shape (*batch,).""" + A = jnp.ones((3, 4, 2, 2)) + # det of 2×2 ones matrix = 1*1 - 1*1 = 0 + result = qm_det(A) + assert result.shape == (3, 4) + + +class TestDetQMatrix: + """Tests for det_p Quax dispatch on QMatrix.""" + + def test_returns_abstract_quantity(self): + """Det of a 2×2 QMatrix returns an AbstractQuantity.""" + A = QMat(jnp.array([[2.0, 0.0], [0.0, 3.0]]), unit=((_m, _m), (_m, _m))) + result = quax.quaxify(qm_det)(A) + assert isinstance(result, u.AbstractQuantity) + + def test_value_2x2_diagonal(self): + """Numeric value equals jnp.linalg.det of the value array.""" + A = QMat(jnp.array([[2.0, 0.0], [0.0, 3.0]]), unit=((_m, _m), (_m, _m))) + result = quax.quaxify(qm_det)(A) + assert jnp.allclose(result.value, 6.0) + + def test_unit_product_of_diagonal(self): + """Unit is the product of the main-diagonal units: m·m = m².""" + A = QMat(jnp.array([[2.0, 0.0], [0.0, 3.0]]), unit=((_m, _m), (_m, _m))) + result = quax.quaxify(qm_det)(A) + assert result.unit == u.unit("m2") + + def test_unit_heterogeneous_diagonal(self): + """Unit is u00 * u11 for a 2×2 matrix with mixed diagonal units.""" + A = QMat(jnp.eye(2), unit=((_m, _s), (_m, _s))) + result = quax.quaxify(qm_det)(A) + assert result.unit == _m * _s + + def test_unit_3x3_uniform(self): + """Det of 3×3 identity with uniform unit m gives unit m³.""" + A = QMat(jnp.eye(3), unit=((_m, _m, _m),) * 3) + result = quax.quaxify(qm_det)(A) + assert jnp.allclose(result.value, 1.0) + assert result.unit == u.unit("m3") + + def test_jit_QMatrix(self): + """Det of QMatrix works under jax.jit.""" + A = QMat(jnp.array([[2.0, 0.0], [0.0, 3.0]]), unit=((_m, _m), (_m, _m))) + result = jax.jit(quax.quaxify(qm_det))(A) + assert jnp.allclose(result.value, 6.0) + assert result.unit == u.unit("m2") + + +# --------------------------------------------------------------------------- +# inv_p custom primitive + Quax dispatch for QMatrix +# --------------------------------------------------------------------------- + + +class TestInvPrimitive: + """Tests for the custom ``inv_p`` JAX primitive on plain arrays.""" + + def test_inv_2x2_diagonal(self): + """inv_p on a diagonal matrix returns the reciprocal diagonal.""" + A = jnp.array([[2.0, 0.0], [0.0, 4.0]]) + result = qm_inv(A) + expected = jnp.linalg.inv(A) + assert jnp.allclose(result, expected) + + def test_inv_3x3_identity(self): + """inv(I) == I.""" + A = jnp.eye(3) + assert jnp.allclose(qm_inv(A), A) + + def test_inv_matches_jnp_linalg_inv(self): + """inv_p gives the same result as jnp.linalg.inv for a generic matrix.""" + A = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + assert jnp.allclose(qm_inv(A), jnp.linalg.inv(A)) + + def test_inv_jit(self): + """inv_p works under jax.jit.""" + A = jnp.array([[2.0, 0.0], [0.0, 4.0]]) + result = jax.jit(qm_inv)(A) + assert jnp.allclose(result, jnp.linalg.inv(A)) + + def test_inv_jvp(self): + """Forward-mode derivative of A^{-1}: d(A^{-1})(dA) = -A^{-1} dA A^{-1}.""" + A = jnp.array([[2.0, 0.0], [0.0, 4.0]]) + dA = jnp.ones((2, 2)) + primal, tangent = jax.jvp(qm_inv, (A,), (dA,)) + assert jnp.allclose(primal, jnp.linalg.inv(A)) + # expected tangent: -A^{-1} dA A^{-1} + Ainv = jnp.linalg.inv(A) + expected_tangent = -(Ainv @ dA @ Ainv) + assert jnp.allclose(tangent, expected_tangent) + + def test_inv_grad(self): + """Reverse-mode gradient via VJP derived from JVP.""" + A = jnp.array([[2.0, 0.0], [0.0, 4.0]]) + + # Scalar output f(A) = sum(inv(A)) + def f(a): + return jnp.sum(qm_inv(a)) + + grad_A = jax.grad(f)(A) + # Finite-difference check + eps = 1e-5 + fd = jnp.zeros_like(A) + for i in range(2): + for j in range(2): + Ap = A.at[i, j].add(eps) + Am = A.at[i, j].add(-eps) + fd_val = (jnp.sum(jnp.linalg.inv(Ap)) - jnp.sum(jnp.linalg.inv(Am))) / ( + 2 * eps + ) + fd = fd.at[i, j].set(fd_val) + assert jnp.allclose(grad_A, fd, atol=1e-4) + + def test_inv_jit_grad(self): + """jit(grad(sum(inv(A)))) works correctly.""" + A = jnp.array([[2.0, 0.0], [0.0, 4.0]]) + + def f(a): + return jnp.sum(qm_inv(a)) + + grad_A = jax.jit(jax.grad(f))(A) + + def g(a): + return jnp.sum(jnp.linalg.inv(a)) + + expected = jax.grad(g)(A) + assert jnp.allclose(grad_A, expected, atol=1e-6) + + def test_inv_vmap(self): + """inv_p works under jax.vmap — maps over a batch of matrices.""" + A = jnp.stack( + [jnp.diag(jnp.array([2.0, 4.0])), jnp.diag(jnp.array([1.0, 2.0]))] + ) + results = jax.vmap(qm_inv)(A) + expected = jax.vmap(jnp.linalg.inv)(A) + assert jnp.allclose(results, expected) + + def test_inv_jit_vmap(self): + """jit(vmap(inv)) works correctly.""" + A = jnp.stack( + [jnp.diag(jnp.array([2.0, 4.0])), jnp.diag(jnp.array([1.0, 2.0]))] + ) + results = jax.jit(jax.vmap(qm_inv))(A) + expected = jax.vmap(jnp.linalg.inv)(A) + assert jnp.allclose(results, expected) + + def test_inv_batched_shape(self): + """inv_p on a (*batch, n, n) array returns shape (*batch, n, n).""" + A = jnp.stack([jnp.eye(2)] * 6).reshape(3, 2, 2, 2) + result = jax.vmap(jax.vmap(qm_inv))(A) + assert result.shape == (3, 2, 2, 2) + + +class TestInvQMatrix: + """Tests for inv_p Quax dispatch on QMatrix.""" + + def test_returns_QMatrix(self): + """Inv of a 2×2 QMatrix returns a QMatrix.""" + A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) + result = quax.quaxify(qm_inv)(A) + assert isinstance(result, QMat) + + def test_value_2x2_diagonal(self): + """Numeric value equals jnp.linalg.inv of the value array.""" + A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) + result = quax.quaxify(qm_inv)(A) + expected_val = jnp.linalg.inv(jnp.array([[4.0, 0.0], [0.0, 1.0]])) + assert jnp.allclose(result.value, expected_val) + + def test_unit_reciprocal(self): + """Unit of the inverse is the reciprocal of the original unit: 1/m.""" + A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) + result = quax.quaxify(qm_inv)(A) + expected_unit = u.unit("1 / m") + assert result.unit[0, 0] == expected_unit + assert result.unit[1, 1] == expected_unit + + def test_unit_m2_per_rad2(self): + """Inv of a metric with m²/rad² entries carries rad²/m² units.""" + m2_r2 = u.unit("m2 / rad2") + A = QMat( + jnp.array([[4.0, 0.0], [0.0, 1.0]]), + unit=((m2_r2, m2_r2), (m2_r2, m2_r2)), + ) + result = quax.quaxify(qm_inv)(A) + assert result.unit[0, 0] == u.unit("rad2 / m2") + + def test_jit_QMatrix(self): + """Inv of QMatrix works under jax.jit.""" + A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) + result = jax.jit(quax.quaxify(qm_inv))(A) + assert jnp.allclose(result.value, jnp.array([[0.25, 0.0], [0.0, 1.0]])) + assert result.unit[0, 0] == u.unit("1 / m") + + def test_roundtrip_identity(self): + """A @ inv(A) ≈ I for a QMatrix (value check).""" + A = QMat( + jnp.array([[2.0, 1.0], [1.0, 3.0]]), + unit=((_m, _m), (_m, _m)), + ) + Ainv = quax.quaxify(qm_inv)(A) + product = A.value @ Ainv.value + assert jnp.allclose(product, jnp.eye(2), atol=1e-6) From 2c28f799f3fb865f0905c7ec52ce2b7e203d7309 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:22:04 -0400 Subject: [PATCH 02/15] =?UTF-8?q?=E2=9C=A8=20feat(metric):=20introduce=20?= =?UTF-8?q?=5Fsrc/metric=20module=20with=20abstract=20field,=20matrix=20ty?= =?UTF-8?q?pes,=20and=20dispatch=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new coordinax._src.metric package with four modules: field.py (AbstractMetricField, AbstractDiagonalMetricField, RoundMetric), matrix.py (AbstractMetricMatrix, DiagonalMetric, DenseMetric), api.py (metric_matrix, metric_representation fallback dispatches), and __init__.py. AbstractMetric/AbstractDiagonalMetric in base/metric.py are renamed to AbstractMetricField/AbstractDiagonalMetricField and stripped of their metric_matrix method — that computation is now a standalone dispatch function. --- src/coordinax/_src/base/metric.py | 330 +----------- src/coordinax/_src/metric/__init__.py | 16 + src/coordinax/_src/metric/api.py | 77 +++ src/coordinax/_src/metric/field.py | 132 +++++ src/coordinax/_src/metric/matrix.py | 505 ++++++++++++++++++ tests/unit/manifolds/test_metric_field.py | 184 +++++++ .../manifolds/test_metric_matrix_types.py | 250 +++++++++ 7 files changed, 1188 insertions(+), 306 deletions(-) create mode 100644 src/coordinax/_src/metric/__init__.py create mode 100644 src/coordinax/_src/metric/api.py create mode 100644 src/coordinax/_src/metric/field.py create mode 100644 src/coordinax/_src/metric/matrix.py create mode 100644 tests/unit/manifolds/test_metric_field.py create mode 100644 tests/unit/manifolds/test_metric_matrix_types.py diff --git a/src/coordinax/_src/base/metric.py b/src/coordinax/_src/base/metric.py index 3094ffff..5c63007d 100644 --- a/src/coordinax/_src/base/metric.py +++ b/src/coordinax/_src/base/metric.py @@ -1,27 +1,15 @@ """Manifold definitions and manifold inference helpers.""" -__all__ = ("AbstractMetric", "AbstractDiagonalMetric") +__all__ = ("AbstractMetricField", "AbstractDiagonalMetricField") import abc -from jaxtyping import Array, Bool -from typing import TYPE_CHECKING, Any - import jax -import jax.numpy as jnp - -import coordinax.angles as cxa -import coordinax.api.manifolds as cxmapi -from coordinax._src.custom_types import CDict, OptUSys -from coordinax.internal import QuantityMatrix, UnitsMatrix - -if TYPE_CHECKING: - import coordinax.charts # noqa: ICN001 @jax.tree_util.register_static -class AbstractMetric(metaclass=abc.ABCMeta): - r"""Abstract base class for metrics on representations. +class AbstractMetricField(metaclass=abc.ABCMeta): + r"""Abstract base class for metrics of manifolds. The metric defines a bilinear form on the tangent space of a chart. @@ -51,18 +39,14 @@ class AbstractMetric(metaclass=abc.ABCMeta): \end{pmatrix}. $$ + The metric matrix is computed via the standalone dispatch function + :func:`coordinax.manifolds.metric_matrix`. + Examples -------- - >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> import unxt as u - >>> metric = cxm.EuclideanMetric(3) - >>> at = {"x": u.Q(1.0, "km"), "y": u.Q(0.0, "km"), - ... "z": u.Q(0.0, "km")} - >>> metric.metric_matrix(cxc.cart3d, at=at) - QuantityMatrix([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]], '((, , ), (, , ), (, , ))') + >>> cxm.FlatMetric(3).ndim + 3 """ @@ -77,214 +61,8 @@ def signature(self) -> tuple[int, ...]: """Return the signature of the metric as a tuple of integers.""" raise NotImplementedError # pragma: no cover - @abc.abstractmethod - def metric_matrix( - self, - chart: "coordinax.charts.AbstractChart[Any, Any]", - /, - *, - at: CDict, - usys: OptUSys = None, - ) -> QuantityMatrix | Array: - r"""Compute the metric tensor $g_{ij}$ at base point ``at``. - - Parameters - ---------- - chart : AbstractChart - The coordinate chart in which to express the metric. - at : CDict - Base point (component dict in ``chart``) at which to evaluate. - usys : OptUSys, optional - Unit system to use for the metric evaluation. - - Returns - ------- - QuantityMatrix, shape (n, n) - Symmetric positive-definite metric matrix ``g_{ij}`` in the chart - basis for ``chart``. - - """ - raise NotImplementedError # pragma: no cover - def scale_factors( - self, - chart: "coordinax.charts.AbstractChart[Any, Any]", - /, - *, - at: CDict, - usys: OptUSys = None, - ) -> QuantityMatrix: - r"""Return the diagonal entries of the metric matrix in ``chart`` at ``at``. - - This is a thin convenience wrapper over - ``cxmapi.scale_factors(self, chart, at=at, usys=usys)``. - - Examples - -------- - >>> import jax.numpy as jnp - >>> import unxt as u - >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm - - In Cartesian coordinates for Euclidean space, the diagonal entries are all 1: - - >>> metric = cxm.EuclideanMetric(3) - >>> at = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")} - >>> gdiag = metric.scale_factors(cxc.cart3d, at=at) - >>> gdiag - QuantityMatrix([1., 1., 1.], '(, , )') - - In spherical coordinates on Euclidean space, the entries depend on the - base point: - - >>> at_sph = { - ... "r": u.Q(2.0, "km"), - ... "theta": u.Angle(jnp.pi / 2, "rad"), - ... "phi": u.Angle(0.0, "rad"), - ... } - >>> metric.scale_factors(cxc.sph3d, at=at_sph) - QuantityMatrix([1., 4., 4.], '(, km2 / rad2, km2 / rad2)') - - """ - return cxmapi.scale_factors(self, chart, at=at, usys=usys) # ty: ignore[invalid-return-type] - - def cholesky( - self, - chart: "coordinax.charts.AbstractChart[Any, Any]", - /, - *, - at: CDict, - usys: OptUSys = None, - ) -> QuantityMatrix | Array: - r"""Return the lower-triangular Cholesky factor $L$ of the metric matrix. - - Computes the factorization $g = L\,L^\top$ where $L$ is the unique - lower-triangular matrix with strictly positive diagonal entries. - With the convention used here, the vielbein is $E = L^\top$; use - ``L.value.T`` (or plain ``.T`` for a bare array) to obtain it. - - Parameters - ---------- - chart : AbstractChart - The coordinate chart in which to express the metric. - at : CDict - Base point (component dict in ``chart``) at which to evaluate. - usys : OptUSys, optional - Unit system to use for the metric evaluation. - - Returns - ------- - QuantityMatrix or Array, shape (n, n) - Lower-triangular Cholesky factor $L$ satisfying $g = L\,L^\top$. - Returns a ``QuantityMatrix`` when the metric matrix carries units; - returns a plain ``Array`` otherwise. The unit of element - $L_{ij}$ is $\sqrt{u_{ij}}$ where $u_{ij}$ is the unit of $g_{ij}$. - - Examples - -------- - >>> import jax.numpy as jnp - >>> import unxt as u - >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm - - Euclidean metric in Cartesian coordinates — Cholesky of the identity: - - >>> metric = cxm.EuclideanMetric(3) - >>> at = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} - >>> metric.cholesky(cxc.cart3d, at=at) - QuantityMatrix([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]], '((, , ), (, , ), (, , ))') - - Euclidean metric in spherical coordinates — Cholesky is diag(1, r, r sinθ): - - >>> at = { - ... "r": u.Q(2.0, "m"), - ... "theta": u.Angle(jnp.pi / 2, "rad"), - ... "phi": u.Angle(0.0, "rad"), - ... } - >>> metric.cholesky(cxc.sph3d, at=at) - QuantityMatrix( - [[1., 0., 0.], - [0., 2., 0.], - [0., 0., 2.]], - '((, m(1/2) / rad(1/2), m(1/2) / rad(1/2)), ...)' - ) - - """ - G = self.metric_matrix(chart, at=at, usys=usys) - if isinstance(G, QuantityMatrix): - l_units = UnitsMatrix(G.unit._units**0.5) - return QuantityMatrix(jnp.linalg.cholesky(G.value), unit=l_units) - return jnp.linalg.cholesky(G) - - def angle_between( - self, - chart: "coordinax.charts.AbstractChart[Any, Any]", - uvec: CDict, - vvec: CDict, - /, - *, - at: CDict, - usys: OptUSys = None, - ) -> cxa.AbstractAngle: - r"""Return the metric angle between two tangent vectors. - - This is a thin convenience wrapper over - ``cxmapi.angle_between(self, chart, uvec, vvec, at=at, usys=usys)``. - """ - return cxmapi.angle_between(self, chart, uvec, vvec, at=at, usys=usys) # ty: ignore[invalid-return-type] - - def is_diagonal( - self, - chart: "coordinax.charts.AbstractChart[Any, Any]", - /, - *, - at: CDict, - usys: OptUSys = None, - ) -> Bool[Array, ""]: - r"""Return ``True`` if the metric matrix is diagonal at base point ``at``. - - A metric is diagonal when all off-diagonal entries $g_{ij}$ with - $i \neq j$ are numerically zero (within floating-point tolerance, - checked via ``jnp.allclose``), i.e. the coordinate basis is - orthogonal at ``at``. - - Parameters - ---------- - chart : AbstractChart - The coordinate chart in which to evaluate the metric. - at : CDict - Base point at which to check the metric matrix. - usys : OptUSys, optional - Unit system for the evaluation. - - Returns - ------- - bool - ``True`` if all off-diagonal metric components vanish. - - Examples - -------- - >>> import unxt as u - >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm - - Euclidean metric in spherical coordinates is diagonal (orthogonal chart): - - >>> metric = cxm.EuclideanMetric(3) - >>> at = {"r": u.Q(3.0, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0.0, "rad")} - >>> metric.is_diagonal(cxc.sph3d, at=at) - Array(True, dtype=bool) - - """ - G = self.metric_matrix(chart, at=at, usys=usys) - val = G.value if hasattr(G, "value") else G - off_diagonal = jnp.subtract(val, jnp.diag(jnp.diag(val))) - return jnp.allclose(off_diagonal, 0) - - -class AbstractDiagonalMetric(AbstractMetric): +class AbstractDiagonalMetricField(AbstractMetricField): r"""Abstract base class for metrics whose matrix is diagonal. A metric is **diagonal** (equivalently, the coordinate chart is an @@ -305,51 +83,22 @@ class AbstractDiagonalMetric(AbstractMetric): **Role: structural marker, not behavioral interface.** - This class adds no new abstract methods beyond those of `AbstractMetric`. - Its purpose is to declare that ``metric_matrix`` **must** return a diagonal - matrix at every valid base point for charts where this metric is used as - diagonal (typically orthogonal charts). + This class adds no new abstract methods beyond those of `AbstractMetricField`. + Its purpose is to declare that the ``metric_matrix`` dispatch + **must** return a :class:`~coordinax._src.metric.matrix.DiagonalMetric` + at every valid base point for charts where this metric is used as diagonal + (typically orthogonal charts). In particular, manifold/atlas chart membership (for example, ``has_chart``) is a broader structural notion and does not, by itself, imply orthogonality or diagonality. - This structural guarantee enables type-level dispatch specialisation, for - example: - - - computing ``scale_factors`` from diagonal entries directly, - - extracting only ``diag(g)`` instead of inspecting a full matrix, and - - using squared Jacobian column norms instead of forming full - $J^\top J$ when a diagonal form is known a priori. - - Notes - ----- - Subclasses must implement the two abstract members inherited from - `AbstractMetric`: - - - ``signature`` (property): a tuple of $\pm 1$ of length ``ndim`` encoding - the metric signature. Positive entries denote Riemannian (space-like) - directions; a single ``-1`` entry denotes a time-like direction - (Lorentzian signature). - - ``metric_matrix(chart, /, *, at, usys=None)`` (method): must return a - diagonal ``QuantityMatrix`` (or plain ``Array``) of shape ``(ndim, ndim)`` - with all off-diagonal entries numerically zero (within floating-point - tolerance, i.e. ``jnp.allclose`` to zero). - - Concrete subclasses must be immutable frozen dataclasses and registered as - static JAX PyTree nodes via ``@jax.tree_util.register_static``. - - `AbstractMetric.is_diagonal` inspects the matrix at a **specific base - point** and returns a ``bool`` Array. ``AbstractDiagonalMetric`` makes this - an unconditional **structural promise** across all base points within the - metric's diagonal chart domain. - See Also -------- - EuclideanMetric : flat Riemannian metric on $\mathbb{R}^n$; in Cartesian + FlatMetric : flat Riemannian metric on $\mathbb{R}^n$; in Cartesian charts $g = I_n$; in orthogonal curvilinear charts computed by Jacobian pullback $g = J^\top J$. - HyperSphericalMetric : round metric on $S^{n-1}$ in the intrinsic + RoundMetric : round metric on $S^{n-1}$ in the intrinsic hyperspherical chart; diagonal entries follow the cumulative-sine rule $g_{kk} = \prod_{j < k} \sin^2\!\theta_j$. MinkowskiMetric : Lorentzian pseudo-Riemannian metric @@ -360,57 +109,26 @@ class AbstractDiagonalMetric(AbstractMetric): -------- >>> import coordinax.manifolds as cxm - ``EuclideanMetric`` is an ``AbstractDiagonalMetric``: + ``FlatMetric`` is an ``AbstractDiagonalMetricField``: - >>> isinstance(cxm.EuclideanMetric(3), AbstractDiagonalMetric) + >>> isinstance(cxm.FlatMetric(3), AbstractDiagonalMetricField) True - ``MinkowskiMetric`` is also an ``AbstractDiagonalMetric``: + ``MinkowskiMetric`` is also an ``AbstractDiagonalMetricField``: - >>> isinstance(cxm.MinkowskiMetric(), AbstractDiagonalMetric) + >>> isinstance(cxm.MinkowskiMetric(), AbstractDiagonalMetricField) True - General (non-diagonal) metrics such as ``InducedMetric`` are not: + General (non-diagonal) metrics such as ``PullbackMetric`` are not: >>> import unxt as u >>> isinstance( - ... cxm.InducedMetric( + ... cxm.PullbackMetric( ... cxm.TwoSphereIn3D(radius=u.Q(1.0, "m")), - ... cxm.EuclideanMetric(3), + ... cxm.FlatMetric(3), ... ), - ... AbstractDiagonalMetric, + ... AbstractDiagonalMetricField, ... ) False """ - - def is_diagonal( - self, - chart: "coordinax.charts.AbstractChart[Any, Any]", - /, - *, - at: CDict, - usys: OptUSys = None, - ) -> Bool[Array, ""]: - r"""Return ``True`` as a structural guarantee of diagonality. - - For ``AbstractDiagonalMetric`` the metric is diagonal by type-level - contract on its diagonal chart domain (typically orthogonal charts), so - this method is unconditional and does not inspect ``metric_matrix``. - - Examples - -------- - >>> import unxt as u - >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm - - Euclidean metric in spherical coordinates is diagonal (orthogonal chart): - - >>> metric = cxm.EuclideanMetric(3) - >>> at = {"r": u.Q(3.0, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0.0, "rad")} - >>> metric.is_diagonal(cxc.sph3d, at=at) - Array(True, dtype=bool) - - """ - del chart, at, usys - return jnp.ones((), dtype=bool) diff --git a/src/coordinax/_src/metric/__init__.py b/src/coordinax/_src/metric/__init__.py new file mode 100644 index 00000000..d56693bc --- /dev/null +++ b/src/coordinax/_src/metric/__init__.py @@ -0,0 +1,16 @@ +"""Metrics — intrinsic metric fields, matrix representations, and dispatch API.""" + +__all__ = ( + "AbstractMetricField", + "AbstractDiagonalMetricField", + "RoundMetric", + "AbstractMetricMatrix", + "DiagonalMetric", + "DenseMetric", + "metric_matrix", + "metric_representation", +) + +from .api import metric_matrix, metric_representation +from .field import AbstractDiagonalMetricField, AbstractMetricField, RoundMetric +from .matrix import AbstractMetricMatrix, DenseMetric, DiagonalMetric diff --git a/src/coordinax/_src/metric/api.py b/src/coordinax/_src/metric/api.py new file mode 100644 index 00000000..d4c81a4c --- /dev/null +++ b/src/coordinax/_src/metric/api.py @@ -0,0 +1,77 @@ +"""Concrete fallback dispatch rules for metric_matrix and metric_representation. + +The dispatch function objects are defined (as abstract) in +:mod:`coordinax.api.manifolds`. This module registers the concrete fallback +rules that apply to any :class:`~coordinax._src.base.AbstractManifold` / +:class:`~coordinax._src.base.AbstractChart` pair not covered by a more specific +rule registered in a ``register_metric.py`` module. + +Importing this module is sufficient to ensure both fallback rules are active; +the ``register_metric.py`` modules import from here so that these rules are +always present before any specific rule is added. +""" + +__all__ = ("metric_matrix", "metric_representation") + +import plum + +from .matrix import AbstractMetricMatrix, DenseMetric +from coordinax._src.base import AbstractChart, AbstractManifold + + +@plum.dispatch +def metric_matrix( + M: AbstractManifold, point: dict, chart: AbstractChart, / +) -> AbstractMetricMatrix: + """Fallback — raise with a helpful message for unregistered pairs. + + Concrete ``(manifold, chart)`` pairs register their own dispatch rules in + the relevant ``register_metric.py`` modules (loaded as part of Phase 2). + + Parameters + ---------- + M : AbstractManifold + The manifold carrying the metric field. + point : CDict + A component dictionary giving the coordinates in ``chart``. + chart : AbstractChart + The coordinate chart in which to express the metric. + + Raises + ------ + NotImplementedError + When no specific dispatch rule is registered for the given types. + + """ + del point + msg = ( + f"No metric_matrix dispatch registered for " + f"manifold={type(M).__name__!r}, chart={type(chart).__name__!r}. " + f"Register a rule with @plum.dispatch on metric_matrix." + ) + raise NotImplementedError(msg) + + +@plum.dispatch +def metric_representation( + M: AbstractManifold, chart: AbstractChart, / +) -> type[AbstractMetricMatrix]: + """Return `DenseMetric` as the default fallback. + + More specific rules (e.g. for Cartesian charts) override this and return + `DiagonalMetric`. + + >>> import coordinax.manifolds as cxm + >>> import coordinax.charts as cxc + >>> from coordinax.api.manifolds import metric_representation + >>> from coordinax._src.metric.matrix import DenseMetric + >>> from coordinax._src.charts.d3 import LonCosLatSpherical3D + + Non-orthogonal chart falls back to ``DenseMetric``: + + >>> metric_representation(cxm.R3, LonCosLatSpherical3D()) is DenseMetric + True + + """ + del M, chart + return DenseMetric diff --git a/src/coordinax/_src/metric/field.py b/src/coordinax/_src/metric/field.py new file mode 100644 index 00000000..7e20c391 --- /dev/null +++ b/src/coordinax/_src/metric/field.py @@ -0,0 +1,132 @@ +"""Intrinsic metric field types. + +A *metric field* is a smooth family of inner-product structures on the tangent +spaces of a manifold. It describes the *kind* of geometry — flat, round, +Lorentzian, etc. — without referencing a particular coordinate chart or +producing a matrix of components. The latter is the responsibility of +:class:`~coordinax._src.metric.matrix.AbstractMetricMatrix` together with the +``metric_matrix`` dispatch API. +""" + +__all__ = ("AbstractMetricField", "AbstractDiagonalMetricField", "RoundMetric") + +import abc + +from typing import final + +import equinox as eqx +import jax.tree_util as jtu + +import unxt as u + + +@jtu.register_static +class AbstractMetricField(metaclass=abc.ABCMeta): + r"""Abstract base class for intrinsic metric fields. + + A metric field associates each point $p$ of a manifold $M$ with an + inner-product $g_p$ on the tangent space $T_p M$. Subclasses encode the + *kind* of geometry (flat, round, Lorentzian, …) without specifying a + coordinate representation. + + Subclasses must implement the :attr:`signature` property. + + Notes + ----- + Concrete parameter-free subclasses are registered as static JAX pytree + leaves via ``@jax.tree_util.register_static``. Subclasses that carry + JAX-traced parameters (e.g. :class:`RoundMetric` with its ``radius``) are + :class:`equinox.Module` pytrees instead. + + """ + + @property + def ndim(self) -> int: + """Intrinsic dimension — length of :attr:`signature`.""" + return len(self.signature) + + @property + @abc.abstractmethod + def signature(self) -> tuple[int, ...]: + """Metric signature as a tuple of ±1 values. + + Returns ``(1, 1, ..., 1)`` for Riemannian metrics and includes ``-1`` + entries for pseudo-Riemannian (Lorentzian) metrics. + """ + raise NotImplementedError # pragma: no cover + + +@jtu.register_static +class AbstractDiagonalMetricField(AbstractMetricField, metaclass=abc.ABCMeta): + """Structural marker for metric fields that are diagonal in their natural chart. + + Subclassing :class:`AbstractDiagonalMetricField` signals that there exists + at least one coordinate chart in which the metric matrix is diagonal. The + ``metric_matrix`` dispatch rules for such charts can therefore return a + :class:`~coordinax._src.metric.matrix.DiagonalMetric` instead of a dense + matrix. + """ + + +@final +class RoundMetric(AbstractDiagonalMetricField, eqx.Module): + r"""Constant positive-curvature (round) metric on the *n*-sphere $S^n$. + + The geometric radius ``radius`` sets the overall scale: the metric is + $g = R^2 \, \hat{g}$ where $\hat{g}$ is the round metric on the unit sphere. + + Unlike the other :class:`AbstractMetricField` subtypes, :class:`RoundMetric` + is an :class:`equinox.Module` rather than a ``register_static`` dataclass. + This means ``radius`` is a *dynamic* JAX leaf — it can be JIT-compiled, + differentiated with :func:`jax.grad`, and batched with :func:`jax.vmap`. + + Parameters + ---------- + ndim : int + Intrinsic dimension of the sphere (e.g. 2 for $S^2$). Stored as a + static equinox field (part of the treedef, not a JAX array). + radius : unxt.AbstractQuantity + Geometric radius. Stored as a dynamic equinox field (JAX leaf). + + Examples + -------- + >>> import unxt as u + >>> from coordinax._src.metric.field import RoundMetric + + >>> m = RoundMetric(ndim=2, radius=u.Q(1.0, "m")) + >>> m.ndim + 2 + >>> m.signature + (1, 1) + >>> m.radius + Q(1., 'm') + + The radius is a JAX leaf — the pytree contains it as a dynamic leaf: + + >>> import jax + >>> leaves, treedef = jax.tree_util.tree_flatten(m) + >>> len(leaves) # only radius is dynamic + 1 + + """ + + # NOTE: the field is named _ndim (not ndim) to avoid conflicting with the + # `ndim` property inherited from AbstractMetricField. The public interface + # still accepts `ndim` via the custom __init__. + _ndim: int = eqx.field(static=True) + radius: u.AbstractQuantity + """Geometric radius (dynamic JAX leaf — JIT/grad/vmap-friendly).""" + + def __init__(self, *, ndim: int, radius: u.AbstractQuantity) -> None: + object.__setattr__(self, "_ndim", ndim) + object.__setattr__(self, "radius", radius) + + @property + def ndim(self) -> int: + """Intrinsic dimension of the sphere (static — part of the treedef).""" + return self._ndim + + @property + def signature(self) -> tuple[int, ...]: + """All-positive signature ``(1, 1, ..., 1)`` for a round sphere.""" + return (1,) * self._ndim diff --git a/src/coordinax/_src/metric/matrix.py b/src/coordinax/_src/metric/matrix.py new file mode 100644 index 00000000..7d7dcd6a --- /dev/null +++ b/src/coordinax/_src/metric/matrix.py @@ -0,0 +1,505 @@ +"""Typed metric matrix representations. + +These types encapsulate the result of evaluating a metric field at a specific +``(manifold, point, chart)`` triple via the ``metric_matrix`` dispatch API. +They encode the sparsity structure (diagonal vs. dense) and provide operations +consistent with that structure. +""" + +__all__ = ("AbstractMetricMatrix", "DiagonalMetric", "DenseMetric") + +import abc +import functools as ft +import operator + +from jaxtyping import Array +from typing import Any, final + +import equinox as eqx +import jax.numpy as jnp +import quax + +import quaxed.numpy as qnp +import unxt as u + +from coordinax.internal import ( + QMatrix, + UnitsMatrix, + det as _det_primitive, + inv as _inv_primitive, +) + +_det = quax.quaxify(_det_primitive) +_inv = quax.quaxify(_inv_primitive) +_matmul = quax.quaxify(jnp.matmul) + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +def _sine_product_diagonal(thetas: Array, scale: Any, /) -> Array: + r"""Cumulative-product diagonal for the round-sphere metric. + + Given polar angles $\theta_1, \dots, \theta_k$ and a scale factor + $R$ (e.g. the sphere radius or the radial coordinate $r$), returns the + length-$(k+1)$ diagonal + + .. math:: + + \bigl[R^2,\; + R^2\sin^2\theta_1,\; + R^2\sin^2\theta_1\sin^2\theta_2,\; + \ldots\bigr] + + This is shared by the :class:`~coordinax._src.spherical.manifold.HyperSphereSn` + and :class:`~coordinax._src.euclidean.manifold.EuclideanManifold` + + ``HyperSphericalChart`` dispatch rules. + + Parameters + ---------- + thetas : Array, shape ``(k,)`` + Polar angles $\theta_1, \dots, \theta_k$ in radians (the *last* + azimuthal angle $\phi$ is excluded by the caller). + scale : scalar + Scale factor $R$ (sphere radius) or $r$ (radial coordinate). + + Returns + ------- + Array, shape ``(k+1,)`` + The metric diagonal. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import _sine_product_diagonal + + Unit sphere, S² (one polar angle θ = π/2): + + >>> _sine_product_diagonal(jnp.array([jnp.pi / 2]), 1.0) + Array([1., 1.], dtype=float64) + + Radius-2 sphere, S² (θ = π/6): + + >>> import jax + >>> _sine_product_diagonal(jnp.array([jnp.pi / 6]), 2.0) + Array([4., 1.], dtype=float64) + + """ + sin2 = qnp.sin(thetas) ** 2 + cumprod = jnp.concat([jnp.ones(1, dtype=sin2.dtype), jnp.cumprod(sin2)]) + return scale**2 * cumprod + + +# --------------------------------------------------------------------------- +# Abstract base +# --------------------------------------------------------------------------- + + +class AbstractMetricMatrix(eqx.Module): + """Abstract base class for typed metric matrix representations. + + Concrete subclasses encode the sparsity structure of a metric matrix + (diagonal vs. dense) and provide matrix-level operations consistent with + that structure. + """ + + @property + @abc.abstractmethod + def ndim(self) -> int: + """Dimension of the metric.""" + raise NotImplementedError # pragma: no cover + + @abc.abstractmethod + def to_dense(self) -> "DenseMetric": + """Return an equivalent :class:`DenseMetric`. + + For diagonal metrics, off-diagonal entries are zero. + For dense metrics, returns ``self``. + """ + raise NotImplementedError # pragma: no cover + + +# --------------------------------------------------------------------------- +# Diagonal metric +# --------------------------------------------------------------------------- + + +@final +class DiagonalMetric(AbstractMetricMatrix): + r"""Diagonal metric matrix stored as a 1-D array or QMatrix. + + Encodes a metric whose coordinate matrix is diagonal — i.e. orthogonal + coordinate charts. Storing only the diagonal avoids materialising the full + $n \times n$ matrix and makes operations like matrix-vector products and + inversion run in $O(n)$. + + Parameters + ---------- + diagonal : QMatrix or Array, shape ``(n,)`` + The diagonal entries $g_{11}, g_{22}, \\ldots, g_{nn}$. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> d = DiagonalMetric(jnp.array([1.0, 4.0, 9.0])) + >>> d.ndim + 3 + >>> d.determinant + Array(36., dtype=float64) + >>> d.inverse.diagonal + Array([1. , 0.25 , 0.11111111], dtype=float64) + + """ + + diagonal: QMatrix | Array + + @property + def ndim(self) -> int: + """Dimension of the metric.""" + return int(self.diagonal.shape[-1]) + + def to_dense(self) -> "DenseMetric": + r"""Convert to a full $n \times n$ matrix with zeros off the diagonal. + + When the diagonal is a :class:`~coordinax.internal.QMatrix`, + the off-diagonal entry ``(i, j)`` is assigned the geometric-mean unit + ``sqrt(diag_unit[i] * diag_unit[j])``. This choice ensures that + ``g[i, j] * v[j]`` is unit-compatible with ``g[i, i] * v[i]`` during + matrix-vector contraction, which is required for the + :func:`~coordinax.internal.QMatrix` dot-product to succeed even + when the coordinate components have different physical dimensions (e.g. + metres and radians in spherical coordinates). + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DiagonalMetric + + Plain array — off-diagonal entries are zero: + + >>> d = DiagonalMetric(jnp.array([1.0, 4.0])) + >>> d.to_dense().matrix + Array([[1., 0.], + [0., 4.]], dtype=float64) + + QMatrix diagonal — diagonal units are preserved and off-diagonal + entries get the geometric-mean unit: + + >>> from coordinax.internal import QMatrix + >>> d = DiagonalMetric(QMatrix(jnp.array([1.0, 4.0]), unit=("m2", "s2"))) + >>> d.to_dense().matrix.unit[0, 0] + Unit("m2") + >>> d.to_dense().matrix.unit[1, 1] + Unit("s2") + >>> d.to_dense().matrix.unit[0, 1] # geometric mean: sqrt(m2 * s2) + Unit("m s") + + """ + if isinstance(self.diagonal, QMatrix): + # Off-diagonal entries are numerically zero, but their units must + # be chosen so that g[i,j] * v[j] is unit-compatible with + # g[i,i] * v[i] for any tangent vector v. The physically correct + # unit for entry (i, j) of a metric tensor g is + # + # [g_{ij}] = [ds²] / ([coord_i] · [coord_j]) + # = sqrt([g_{ii}] · [g_{jj}]) + # + # Using the geometric mean for off-diagonal entries ensures that + # the scale-factor computation in _dot_general_2d_1d can always + # convert every term to the reference unit ref[i] = g[i,0]*v[0]. + n = self.ndim + dense_val = jnp.diag(self.diagonal.value) + du = self.diagonal.unit._units # shape (n,) + row_units = tuple( + tuple(du[i] if i == j else (du[i] * du[j]) ** 0.5 for j in range(n)) + for i in range(n) + ) + return DenseMetric(QMatrix(dense_val, unit=UnitsMatrix(row_units))) + return DenseMetric(jnp.diag(self.diagonal)) + + def __matmul__( + self, other: "Array | QMatrix | u.AbstractQuantity", / + ) -> "Array | QMatrix | u.AbstractQuantity": + """Apply this diagonal metric to a vector — element-wise product. + + When either the diagonal or ``other`` carries units, the operation is + routed through :meth:`to_dense` so that unit propagation is handled + correctly by the :class:`~coordinax.internal.QMatrix` Quax + dispatches. Plain-array inputs use a fast O(n) element-wise multiply. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DiagonalMetric + + Plain array diagonal, plain array vector: + + >>> d = DiagonalMetric(jnp.array([1.0, 4.0, 9.0])) + >>> d @ jnp.array([1.0, 2.0, 3.0]) + Array([ 1., 8., 27.], dtype=float64) + + QMatrix diagonal, plain array vector — result carries diagonal units: + + >>> from coordinax.internal import QMatrix + >>> d = DiagonalMetric( + ... QMatrix(jnp.array([2.0, 3.0]), unit=("m2 / rad2", "m2 / rad2")) + ... ) + >>> w = d @ jnp.array([1.0, 1.0]) + >>> w.unit.to_string() + '(m2 / rad2, m2 / rad2)' + >>> w.value + Array([2., 3.], dtype=float64) + + QMatrix diagonal, Quantity vector — full unit tracking: + + >>> import unxt as u + >>> w2 = d @ u.Q(jnp.array([1.0, 1.0]), "rad") + >>> w2.unit.to_string() + '(m2 / rad, m2 / rad)' + >>> w2.value + Array([2., 3.], dtype=float64) + + QMatrix diagonal, QMatrix vector — full unit tracking: + + >>> v = QMatrix(jnp.array([1.0, 1.0]), unit=("rad", "rad")) + >>> w3 = d @ v + >>> w3.unit.to_string() + '(m2 / rad, m2 / rad)' + >>> w3.value + Array([2., 3.], dtype=float64) + + """ + if isinstance(self.diagonal, QMatrix) or isinstance( + other, (QMatrix, u.AbstractQuantity) + ): + # Route through the dense path for correct unit propagation. + return self.to_dense().__matmul__(other) + # Fast O(n) path: plain-array element-wise multiply. + return self.diagonal * other + + @property + def inverse(self) -> "DiagonalMetric": + """Inverse diagonal metric — reciprocal of each diagonal entry. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> d = DiagonalMetric(jnp.array([2.0, 4.0])) + >>> d.inverse.diagonal + Array([0.5 , 0.25], dtype=float64) + + """ + if isinstance(self.diagonal, QMatrix): + inv_vals = 1.0 / self.diagonal.value + return DiagonalMetric(QMatrix(inv_vals, unit=self.diagonal.unit.inverse())) + return DiagonalMetric(1.0 / self.diagonal) + + @property + def determinant(self) -> "Array | u.AbstractQuantity": + """Product of the diagonal entries. + + Returns a :class:`~unxt.AbstractQuantity` when the diagonal is a + :class:`~coordinax.internal.QMatrix`. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DiagonalMetric + + Bare array — returns a plain :class:`~jaxtyping.Array`: + + >>> DiagonalMetric(jnp.array([2.0, 3.0])).determinant + Array(6., dtype=float64) + + QMatrix diagonal — returns a :class:`~unxt.Quantity`: + + >>> import unxt as u + >>> from coordinax.internal import QMatrix + >>> d = DiagonalMetric(QMatrix(jnp.array([2.0, 3.0]), unit=("m2", "s2"))) + >>> d.determinant + Q(6., 'm2 s2') + + """ + if isinstance(self.diagonal, QMatrix): + det_val = qnp.prod(self.diagonal.value) + det_unit = ft.reduce(operator.mul, self.diagonal.unit) + return u.Q(det_val, det_unit) + return qnp.prod(self.diagonal) + + +# --------------------------------------------------------------------------- +# Dense metric +# --------------------------------------------------------------------------- + + +@final +class DenseMetric(AbstractMetricMatrix): + r"""Dense symmetric metric matrix. + + Stores the full $n \\times n$ metric matrix. Used for non-orthogonal + charts or metrics that cannot be expressed diagonally. + + Parameters + ---------- + matrix : QMatrix or Array, shape ``(n, n)`` + The full metric matrix $g_{ij}$. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DenseMetric + + >>> g = DenseMetric(jnp.eye(3)) + >>> g.ndim + 3 + >>> g.determinant + Array(1., dtype=float64) + + """ + + matrix: QMatrix | Array + + @property + def ndim(self) -> int: + """Dimension of the metric.""" + return int(self.matrix.shape[-1]) + + def to_dense(self) -> "DenseMetric": + """Return ``self`` — already in dense form. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DenseMetric + + >>> g = DenseMetric(jnp.eye(2)) + >>> g.to_dense() is g + True + + """ + return self + + def __matmul__(self, other: "Array | QMatrix", /) -> "Array | QMatrix": + """Apply this metric matrix to a vector via matrix-vector product. + + When the metric matrix is a :class:`~coordinax.internal.QMatrix`, + a plain-array ``other`` is treated as dimensionless so that units flow + through the contraction correctly. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DenseMetric + + Plain array metric, plain array vector: + + >>> g = DenseMetric(jnp.array([[2.0, 0.0], [0.0, 3.0]])) + >>> g @ jnp.array([1.0, 1.0]) + Array([2., 3.], dtype=float64) + + QMatrix metric, plain array vector — result carries metric units: + + >>> from coordinax.internal import QMatrix, UnitsMatrix + >>> g = DenseMetric( + ... QMatrix( + ... jnp.array([[2.0, 0.0], [0.0, 3.0]]), + ... unit=UnitsMatrix(( + ... ("m2 / rad2", "m2 / rad2"), + ... ("m2 / rad2", "m2 / rad2"), + ... )), + ... ) + ... ) + >>> w = g @ jnp.array([1.0, 1.0]) + >>> w.unit.to_string() + '(m2 / rad2, m2 / rad2)' + >>> w.value + Array([2., 3.], dtype=float64) + + QMatrix metric, QMatrix vector — full unit tracking: + + >>> v = QMatrix(jnp.array([1.0, 1.0]), unit=("rad / s", "rad / s")) + >>> w2 = g @ v + >>> w2.unit.to_string() + '(m2 / (rad s), m2 / (rad s))' + + """ + return _matmul(self.matrix, other) + + @property + def inverse(self) -> "DenseMetric": + """Inverse via :func:`jax.numpy.linalg.inv` (positive-definite assumption). + + Returns a :class:`~coordinax.internal.QMatrix`-backed + :class:`DenseMetric` with units ``1 / ref_unit`` when the matrix + carries units. Assumes all entries share the same unit (physically + well-formed metrics from the Cartesian-Jacobian pullback always satisfy + this). + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DenseMetric + + Bare array: + + >>> g = DenseMetric(jnp.array([[2.0, 0.0], [0.0, 4.0]])) + >>> g.inverse.matrix + Array([[0.5 , 0. ], + [0. , 0.25]], dtype=float64) + + QMatrix — inverse carries reciprocal units: + + >>> import unxt as u + >>> from coordinax.internal import QMatrix, UnitsMatrix + >>> g = DenseMetric( + ... QMatrix( + ... jnp.array([[4.0, 0.0], [0.0, 1.0]]), + ... unit=UnitsMatrix(( + ... ("m2 / rad2", "m2 / rad2"), + ... ("m2 / rad2", "m2 / rad2"), + ... )), + ... ) + ... ) + >>> g.inverse.matrix.unit[0, 0] + Unit("rad2 / m2") + >>> g.inverse.matrix.value + Array([[0.25, 0. ], + [0. , 1. ]], dtype=float64) + + """ + return DenseMetric(_inv(self.matrix)) + + @property + def determinant(self) -> "Array | u.AbstractQuantity": + """Determinant via the custom ``det_p`` JAX primitive. + + Routes through Quax, so a :class:`~coordinax.internal.QMatrix` + matrix returns a :class:`~unxt.AbstractQuantity` while a plain array + returns a bare :class:`~jaxtyping.Array`. The unit is the product of + the main-diagonal units — valid for diagonal and uniform-unit matrices. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from coordinax._src.metric.matrix import DenseMetric + + Bare array — returns a plain :class:`~jaxtyping.Array`: + + >>> DenseMetric(jnp.eye(3)).determinant + Array(1., dtype=float64) + + QMatrix — returns a :class:`~unxt.Quantity`: + + >>> import unxt as u + >>> from coordinax.internal import QMatrix + >>> g = DenseMetric(QMatrix(jnp.eye(2), unit=(("m2", ""), ("", "s2")))) + >>> g.determinant + Q(1., 'm2 s2') + + """ + return _det(self.matrix) diff --git a/tests/unit/manifolds/test_metric_field.py b/tests/unit/manifolds/test_metric_field.py new file mode 100644 index 00000000..228eb7f8 --- /dev/null +++ b/tests/unit/manifolds/test_metric_field.py @@ -0,0 +1,184 @@ +"""Contract tests for AbstractMetricField concrete subtypes. + +Verifies: +- All subtypes expose ``ndim`` and ``signature`` +- Signature entries are ±1 +- Static subtypes are JAX-static pytree leaves (no dynamic leaves) +- No ``metric_matrix``, ``scale_factors``, or ``cholesky`` methods exist + (these were removed in Phase 3b) +- ``RoundMetric`` from ``metric/field.py`` is an eqx.Module with a dynamic + ``radius`` leaf, static ``_ndim``, and supports JIT / grad through ``radius`` +""" + +import jax +import jax.numpy as jnp +import pytest + +import unxt as u + +import coordinax.manifolds as cxm +from coordinax._src.metric.field import ( + RoundMetric as DynamicRoundMetric, +) + +# --------------------------------------------------------------------------- +# Fixtures: every concrete AbstractMetricField subtype from the public API +# --------------------------------------------------------------------------- + + +@pytest.fixture( + params=[ + pytest.param(lambda: cxm.FlatMetric(1), id="flat-1d"), + pytest.param(lambda: cxm.FlatMetric(3), id="flat-3d"), + pytest.param(lambda: cxm.RoundMetric(2), id="round-2d"), + pytest.param(lambda: cxm.RoundMetric(3), id="round-3d"), + pytest.param(lambda: cxm.MinkowskiMetric(), id="minkowski"), + pytest.param( + lambda: cxm.ProductMetric(factors=(cxm.RoundMetric(2), cxm.FlatMetric(1))), + id="product-s2-r1", + ), + pytest.param( + lambda: cxm.PullbackMetric( + cxm.TwoSphereIn3D(radius=1.0), cxm.FlatMetric(3) + ), + id="pullback-unit-sphere", + ), + ] +) +def metric_field(request): + return request.param() + + +# --------------------------------------------------------------------------- +# Generic contract +# --------------------------------------------------------------------------- + + +class TestAbstractMetricFieldContract: + """Every AbstractMetricField subtype satisfies these invariants.""" + + def test_has_ndim(self, metric_field): + assert isinstance(metric_field.ndim, int) + assert metric_field.ndim >= 1 + + def test_has_signature(self, metric_field): + sig = metric_field.signature + assert isinstance(sig, tuple) + assert len(sig) == metric_field.ndim + + def test_signature_entries_are_plus_minus_one(self, metric_field): + for s in metric_field.signature: + assert s in (-1, 1), f"signature entry {s!r} is not ±1" + + def test_no_metric_matrix_method(self, metric_field): + """Phase 3b: field classes must NOT have a metric_matrix() method.""" + assert not hasattr(metric_field, "metric_matrix"), ( + f"{type(metric_field).__name__} still has a metric_matrix method" + ) + + def test_no_scale_factors_method(self, metric_field): + assert not hasattr(metric_field, "scale_factors"), ( + f"{type(metric_field).__name__} still has a scale_factors method" + ) + + def test_no_cholesky_method(self, metric_field): + assert not hasattr(metric_field, "cholesky"), ( + f"{type(metric_field).__name__} still has a cholesky method" + ) + + +# --------------------------------------------------------------------------- +# Static JAX pytree leaves (parameter-free types) +# --------------------------------------------------------------------------- + + +class TestStaticMetricFieldPytree: + """Parameter-free metric fields are static JAX pytrees (no dynamic leaves).""" + + @pytest.mark.parametrize( + "factory", + [ + lambda: cxm.FlatMetric(3), + lambda: cxm.RoundMetric(2), + lambda: cxm.MinkowskiMetric(), + lambda: cxm.ProductMetric(factors=(cxm.RoundMetric(2), cxm.FlatMetric(1))), + ], + ids=["flat-3d", "round-2d", "minkowski", "product"], + ) + def test_no_dynamic_leaves(self, factory): + m = factory() + leaves, _ = jax.tree.flatten(m) + assert leaves == [], ( + f"{type(m).__name__} has unexpected dynamic leaves: {leaves}" + ) + + def test_flat_metric_jit_roundtrip(self): + m = cxm.FlatMetric(3) + + @jax.jit + def get_ndim(mf): + return mf.ndim + + assert get_ndim(m) == 3 + + def test_round_metric_jit_roundtrip(self): + m = cxm.RoundMetric(2) + + @jax.jit + def get_ndim(mf): + return mf.ndim + + assert get_ndim(m) == 2 + + +# --------------------------------------------------------------------------- +# RoundMetric from metric/field.py — dynamic radius, static ndim +# --------------------------------------------------------------------------- + + +class TestRoundMetricFieldDynamic: + """RoundMetric (metric/field.py) is an eqx.Module with dynamic radius.""" + + def test_radius_is_only_dynamic_leaf(self): + m = DynamicRoundMetric(ndim=2, radius=u.Q(1.0, "m")) + leaves, _ = jax.tree.flatten(m) + assert len(leaves) == 1, f"Expected 1 dynamic leaf (radius), got {leaves}" + + def test_ndim_is_static(self): + m2 = DynamicRoundMetric(ndim=2, radius=u.Q(1.0, "m")) + m3 = DynamicRoundMetric(ndim=3, radius=u.Q(1.0, "m")) + _, treedef2 = jax.tree.flatten(m2) + _, treedef3 = jax.tree.flatten(m3) + assert treedef2 != treedef3, ( + "ndim should be static (changing it should change the treedef)" + ) + + def test_signature_length_matches_ndim(self): + for ndim in [1, 2, 3, 4]: + m = DynamicRoundMetric(ndim=ndim, radius=u.Q(1.0, "m")) + assert len(m.signature) == ndim + assert all(s == 1 for s in m.signature) + + def test_jit_through_radius(self): + m = DynamicRoundMetric(ndim=2, radius=u.Q(1.0, "m")) + + @jax.jit + def get_radius_value(mf): + return mf.radius.value + + result = get_radius_value(m) + assert jnp.allclose(result, jnp.array(1.0)) + + def test_grad_through_radius(self): + def f(r_val): + m = DynamicRoundMetric(ndim=2, radius=u.Q(r_val, "m")) + return m.radius.value + + grad_f = jax.grad(f) + result = grad_f(jnp.array(3.0)) + assert jnp.allclose(result, jnp.array(1.0)) + + def test_radius_unit_preserved(self): + m = DynamicRoundMetric(ndim=2, radius=u.Q(5.0, "km")) + assert str(m.radius.unit) == "km" + assert jnp.allclose(m.radius.value, jnp.array(5.0)) diff --git a/tests/unit/manifolds/test_metric_matrix_types.py b/tests/unit/manifolds/test_metric_matrix_types.py new file mode 100644 index 00000000..7227acfd --- /dev/null +++ b/tests/unit/manifolds/test_metric_matrix_types.py @@ -0,0 +1,250 @@ +"""Contract tests for DiagonalMetric and DenseMetric. + +Tests cover: +- pytree flatten/unflatten round-trip (equinox Module) +- ``to_dense`` +- ``__matmul__`` +- ``inverse`` +- ``determinant`` +- JIT and vmap compatibility +""" + +import jax +import jax.numpy as jnp +import pytest + +from coordinax._src.metric.matrix import DenseMetric, DiagonalMetric + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture( + params=[ + jnp.array([1.0, 4.0, 9.0]), + jnp.array([1.0]), + jnp.array([2.0, 3.0]), + ], + ids=["3d", "1d", "2d"], +) +def diag_metric(request): + return DiagonalMetric(request.param) + + +@pytest.fixture( + params=[ + jnp.eye(3), + jnp.array([[4.0, 0.0], [0.0, 9.0]]), + jnp.eye(1), + ], + ids=["I3", "diag2", "I1"], +) +def dense_metric(request): + return DenseMetric(request.param) + + +# --------------------------------------------------------------------------- +# DiagonalMetric +# --------------------------------------------------------------------------- + + +class TestDiagonalMetric: + """Contract tests for DiagonalMetric.""" + + def test_ndim(self, diag_metric): + assert diag_metric.ndim == diag_metric.diagonal.shape[-1] + + def test_pytree_roundtrip(self, diag_metric): + """flatten/unflatten must recover an equal object.""" + leaves, treedef = jax.tree_util.tree_flatten(diag_metric) + restored = jax.tree_util.tree_unflatten(treedef, leaves) + assert jnp.allclose(restored.diagonal, diag_metric.diagonal) + + def test_to_dense_shape(self, diag_metric): + n = diag_metric.ndim + dense = diag_metric.to_dense() + assert isinstance(dense, DenseMetric) + assert dense.matrix.shape == (n, n) + + def test_to_dense_diagonal_entries(self, diag_metric): + """Diagonal entries of to_dense() match the stored diagonal.""" + dense = diag_metric.to_dense() + assert jnp.allclose(jnp.diag(dense.matrix), diag_metric.diagonal) + + def test_to_dense_off_diagonal_zero(self, diag_metric): + """Off-diagonal entries of to_dense() are zero.""" + n = diag_metric.ndim + dense = diag_metric.to_dense() + off_diag_mask = ~jnp.eye(n, dtype=bool) + assert jnp.allclose(dense.matrix[off_diag_mask], 0.0) + + def test_matmul(self, diag_metric): + n = diag_metric.ndim + v = jnp.ones(n) + result = diag_metric @ v + assert jnp.allclose(result, diag_metric.diagonal) + + def test_matmul_matches_dense(self, diag_metric): + n = diag_metric.ndim + v = jnp.arange(1.0, n + 1) + assert jnp.allclose(diag_metric @ v, diag_metric.to_dense() @ v) + + def test_inverse_shape(self, diag_metric): + inv = diag_metric.inverse + assert isinstance(inv, DiagonalMetric) + assert inv.diagonal.shape == diag_metric.diagonal.shape + + def test_inverse_values(self, diag_metric): + assert jnp.allclose(diag_metric.inverse.diagonal, 1.0 / diag_metric.diagonal) + + def test_inverse_roundtrip(self, diag_metric): + """G * g⁻¹ ≈ I element-wise.""" + product = diag_metric.diagonal * diag_metric.inverse.diagonal + assert jnp.allclose(product, jnp.ones(diag_metric.ndim)) + + def test_determinant(self, diag_metric): + assert jnp.allclose(diag_metric.determinant, jnp.prod(diag_metric.diagonal)) + + def test_jit_matmul(self, diag_metric): + n = diag_metric.ndim + v = jnp.ones(n) + + @jax.jit + def apply(d, v): + return d @ v + + result = apply(diag_metric, v) + assert jnp.allclose(result, diag_metric.diagonal) + + def test_jit_determinant(self, diag_metric): + @jax.jit + def det(d): + return d.determinant + + assert jnp.allclose(det(diag_metric), diag_metric.determinant) + + def test_vmap_matmul(self, diag_metric): + """Vmap over batch of vectors.""" + n = diag_metric.ndim + batch = jnp.ones((4, n)) + + result = jax.vmap(lambda v: diag_metric @ v)(batch) + assert result.shape == (4, n) + assert jnp.allclose(result[0], diag_metric.diagonal) + + def test_is_not_static(self): + """DiagonalMetric is a dynamic pytree, not a static leaf.""" + d = DiagonalMetric(jnp.array([1.0, 2.0])) + leaves, _ = jax.tree_util.tree_flatten(d) + assert len(leaves) > 0, "diagonal should be a dynamic JAX leaf" + + +# --------------------------------------------------------------------------- +# DenseMetric +# --------------------------------------------------------------------------- + + +class TestDenseMetric: + """Contract tests for DenseMetric.""" + + def test_ndim(self, dense_metric): + assert dense_metric.ndim == dense_metric.matrix.shape[-1] + + def test_pytree_roundtrip(self, dense_metric): + leaves, treedef = jax.tree_util.tree_flatten(dense_metric) + restored = jax.tree_util.tree_unflatten(treedef, leaves) + assert jnp.allclose(restored.matrix, dense_metric.matrix) + + def test_to_dense_is_self(self, dense_metric): + assert dense_metric.to_dense() is dense_metric + + def test_matmul_identity(self): + n = 3 + g = DenseMetric(jnp.eye(n)) + v = jnp.arange(1.0, n + 1) + assert jnp.allclose(g @ v, v) + + def test_matmul_diagonal(self): + g = DenseMetric(jnp.array([[4.0, 0.0], [0.0, 9.0]])) + v = jnp.array([1.0, 1.0]) + assert jnp.allclose(g @ v, jnp.array([4.0, 9.0])) + + def test_inverse_identity(self): + g = DenseMetric(jnp.eye(3)) + assert jnp.allclose(g.inverse.matrix, jnp.eye(3)) + + def test_inverse_diagonal(self): + g = DenseMetric(jnp.array([[4.0, 0.0], [0.0, 9.0]])) + expected = jnp.array([[0.25, 0.0], [0.0, 1.0 / 9.0]]) + assert jnp.allclose(g.inverse.matrix, expected) + + def test_inverse_roundtrip(self, dense_metric): + n = dense_metric.ndim + product = dense_metric.matrix @ dense_metric.inverse.matrix + assert jnp.allclose(product, jnp.eye(n), atol=1e-5) + + def test_determinant_identity(self): + assert jnp.allclose(DenseMetric(jnp.eye(3)).determinant, 1.0) + + def test_determinant_diagonal(self): + g = DenseMetric(jnp.array([[2.0, 0.0], [0.0, 3.0]])) + assert jnp.allclose(g.determinant, 6.0) + + def test_jit_matmul(self, dense_metric): + n = dense_metric.ndim + v = jnp.ones(n) + + @jax.jit + def apply(g, v): + return g @ v + + result = apply(dense_metric, v) + assert result.shape == (n,) + + def test_jit_determinant(self, dense_metric): + @jax.jit + def det(g): + return g.determinant + + assert jnp.isfinite(det(dense_metric)) + + def test_vmap_matmul(self, dense_metric): + n = dense_metric.ndim + batch = jnp.ones((4, n)) + + result = jax.vmap(lambda v: dense_metric @ v)(batch) + assert result.shape == (4, n) + + def test_is_not_static(self): + g = DenseMetric(jnp.eye(2)) + leaves, _ = jax.tree_util.tree_flatten(g) + assert len(leaves) > 0 + + +# --------------------------------------------------------------------------- +# DiagonalMetric ↔ DenseMetric consistency +# --------------------------------------------------------------------------- + + +class TestDiagonalDenseConsistency: + """DiagonalMetric.to_dense() must agree with DenseMetric on operations.""" + + def test_matmul_consistency(self): + diag = DiagonalMetric(jnp.array([2.0, 3.0, 4.0])) + dense = diag.to_dense() + v = jnp.array([1.0, 2.0, 3.0]) + assert jnp.allclose(diag @ v, dense @ v) + + def test_determinant_consistency(self): + diag = DiagonalMetric(jnp.array([2.0, 3.0])) + dense = diag.to_dense() + assert jnp.allclose(diag.determinant, dense.determinant) + + def test_inverse_consistency(self): + diag = DiagonalMetric(jnp.array([2.0, 5.0])) + assert jnp.allclose( + jnp.diag(diag.inverse.to_dense().matrix), + diag.to_dense().inverse.matrix.diagonal(), + atol=1e-6, + ) From c01b384b6fe534b0a0526c92717e093595d42a9c Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:23:45 -0400 Subject: [PATCH 03/15] =?UTF-8?q?=E2=9C=A8=20feat(metric):=20add=20per-geo?= =?UTF-8?q?metry=20register=5Fmetric=20dispatch=20files,=20scale=5Ffactors?= =?UTF-8?q?,=20and=20rename=20metric=20classes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each geometry (Euclidean, HyperSpherical, Minkowski, Product, Embedded, Custom) gains a register_metric.py module that registers metric_matrix and metric_representation dispatches using the new standalone dispatch API. Metric class names are updated: EuclideanMetric → FlatMetric, HyperSphericalMetric → RoundMetric, CartesianProductMetric → ProductMetric, InducedMetric → PullbackMetric. Spherical and Minkowski geometries receive new scale_factors.py files. The metric_matrix method is removed from each metric class in favour of the dispatch function. --- src/coordinax/_src/custom/metric.py | 6 +- src/coordinax/_src/embedded/__init__.py | 1 + src/coordinax/_src/embedded/chart.py | 6 - src/coordinax/_src/embedded/manifold.py | 10 +- src/coordinax/_src/embedded/metric.py | 74 +-- .../_src/embedded/register_charts.py | 2 - .../_src/embedded/register_metric.py | 182 ++++++ src/coordinax/_src/euclidean/__init__.py | 1 + src/coordinax/_src/euclidean/manifold.py | 6 +- src/coordinax/_src/euclidean/metric.py | 63 +- .../_src/euclidean/register_metric.py | 488 ++++++++++++++ src/coordinax/_src/euclidean/scale_factors.py | 30 +- src/coordinax/_src/minkowski/__init__.py | 2 + src/coordinax/_src/minkowski/metric.py | 85 +-- .../_src/minkowski/register_metric.py | 145 +++++ src/coordinax/_src/minkowski/scale_factors.py | 41 ++ src/coordinax/_src/null/metric.py | 14 +- src/coordinax/_src/product/__init__.py | 1 + src/coordinax/_src/product/atlas.py | 2 - src/coordinax/_src/product/chart.py | 4 - src/coordinax/_src/product/manifold.py | 20 +- src/coordinax/_src/product/metric.py | 72 +-- src/coordinax/_src/product/register_metric.py | 132 ++++ src/coordinax/_src/spherical/__init__.py | 2 + src/coordinax/_src/spherical/embed.py | 12 +- src/coordinax/_src/spherical/manifold.py | 8 +- src/coordinax/_src/spherical/metric.py | 94 +-- .../_src/spherical/register_charts.py | 3 +- .../_src/spherical/register_metric.py | 94 +++ src/coordinax/_src/spherical/scale_factors.py | 58 ++ .../manifolds/test_metric_matrix_dispatch.py | 333 ++++++++++ .../test_metric_pullback_consistency.py | 191 ++++++ tests/unit/manifolds/test_metrics.py | 606 ++++-------------- .../manifolds/test_scale_factors_dispatch.py | 51 +- 34 files changed, 1944 insertions(+), 895 deletions(-) create mode 100644 src/coordinax/_src/embedded/register_metric.py create mode 100644 src/coordinax/_src/euclidean/register_metric.py create mode 100644 src/coordinax/_src/minkowski/register_metric.py create mode 100644 src/coordinax/_src/minkowski/scale_factors.py create mode 100644 src/coordinax/_src/product/register_metric.py create mode 100644 src/coordinax/_src/spherical/register_metric.py create mode 100644 src/coordinax/_src/spherical/scale_factors.py create mode 100644 tests/unit/manifolds/test_metric_matrix_dispatch.py create mode 100644 tests/unit/manifolds/test_metric_pullback_consistency.py diff --git a/src/coordinax/_src/custom/metric.py b/src/coordinax/_src/custom/metric.py index f2dab526..e78bcdc2 100644 --- a/src/coordinax/_src/custom/metric.py +++ b/src/coordinax/_src/custom/metric.py @@ -9,17 +9,17 @@ import jax -from coordinax._src.base import AbstractMetric +from coordinax._src.base import AbstractMetricField @jax.tree_util.register_static @final @dataclasses.dataclass(frozen=True, slots=True) -class CustomMetric(AbstractMetric): # ty: ignore[abstract-method-in-final-class] +class CustomMetric(AbstractMetricField): # ty: ignore[abstract-method-in-final-class] r"""Metric for a {class}`CustomManifold`, defined by user-provided callables. ``CustomMetric`` allows users to supply their own metric without subclassing - ``AbstractMetric``. Both the metric-matrix callable and the signature must + ``AbstractMetricField``. Both the metric-matrix callable and the signature must be provided at construction time. Parameters diff --git a/src/coordinax/_src/embedded/__init__.py b/src/coordinax/_src/embedded/__init__.py index 0ba64768..c850aecc 100644 --- a/src/coordinax/_src/embedded/__init__.py +++ b/src/coordinax/_src/embedded/__init__.py @@ -5,3 +5,4 @@ from .manifold import * from .metric import * from .register_charts import * +from .register_metric import * diff --git a/src/coordinax/_src/embedded/chart.py b/src/coordinax/_src/embedded/chart.py index f0e061eb..cd26f6a7 100644 --- a/src/coordinax/_src/embedded/chart.py +++ b/src/coordinax/_src/embedded/chart.py @@ -293,8 +293,6 @@ def pt_map( 2. Transforming in the ambient space (if the ambient charts differ) 3. Projecting back to the intrinsic coordinates of the target manifold - Examples - -------- >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm >>> import unxt as u @@ -350,8 +348,6 @@ def pt_map( This transforms coordinates from an ambient chart (e.g., Cartesian or Spherical) into the intrinsic coordinates of an embedded manifold. - Examples - -------- >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm >>> import unxt as u @@ -400,8 +396,6 @@ def pt_map( coordinates of an ambient chart, which may differ from the embedding's native ambient chart. - Examples - -------- >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm >>> import unxt as u diff --git a/src/coordinax/_src/embedded/manifold.py b/src/coordinax/_src/embedded/manifold.py index 5d83da6a..e06f7f16 100644 --- a/src/coordinax/_src/embedded/manifold.py +++ b/src/coordinax/_src/embedded/manifold.py @@ -13,7 +13,7 @@ import coordinax.api.charts as cxcapi import coordinax.api.manifolds as cxmapi from .embedmap import AbstractEmbeddingMap, AmbientT, IntrinsicT -from .metric import InducedMetric +from .metric import PullbackMetric from coordinax._src.base import AbstractAtlas, AbstractChart, AbstractManifold from coordinax._src.custom_types import CDict, OptUSys @@ -39,8 +39,8 @@ class EmbeddedManifold(AbstractManifold, Generic[IntrinsicT, AmbientT]): >>> import unxt as u >>> M = cxm.EmbeddedManifold( - ... intrinsic=cxm.HyperSphericalManifold(), - ... ambient=cxm.EuclideanManifold(3), + ... intrinsic=cxm.S2, + ... ambient=cxm.R3, ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(2.0, "km"))) >>> p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} >>> sph = cxm.pt_embed(p, M) @@ -94,9 +94,9 @@ def project( # Manifold API @property - def metric(self) -> InducedMetric: + def metric(self) -> PullbackMetric: """Induced (pullback) Riemannian metric from the ambient manifold.""" - return InducedMetric(self.embed_map, self.ambient.metric) + return PullbackMetric(self.embed_map, self.ambient.metric) @override @property diff --git a/src/coordinax/_src/embedded/metric.py b/src/coordinax/_src/embedded/metric.py index cc93a063..ef410ad7 100644 --- a/src/coordinax/_src/embedded/metric.py +++ b/src/coordinax/_src/embedded/metric.py @@ -1,6 +1,6 @@ """Representations for embedded manifolds.""" -__all__ = ("InducedMetric",) +__all__ = ("PullbackMetric",) import dataclasses @@ -12,10 +12,10 @@ import unxt as u from .embedmap import AbstractEmbeddingMap -from coordinax._src.base import AbstractChart, AbstractMetric +from coordinax._src.base import AbstractMetricField from coordinax._src.custom_types import CDict, OptUSys from coordinax.internal import ( - QuantityMatrix, + QMatrix, UnitsMatrix, cdict_units, pack_nonuniform_unit, @@ -26,8 +26,8 @@ def _jacobian_embed_map( embed_map: AbstractEmbeddingMap, at: CDict, usys: OptUSys -) -> QuantityMatrix: - """Compute the Jacobian of ``embed_map`` at ``at`` as a ``QuantityMatrix``. +) -> QMatrix: + """Compute the Jacobian of ``embed_map`` at ``at`` as a ``QMatrix``. Mirrors the general fallback of ``jac_pt_map`` but differentiates the embedding function instead of a chart transition map. @@ -43,8 +43,8 @@ def _jacobian_embed_map( Returns ------- - QuantityMatrix - 2-D ``QuantityMatrix`` of shape ``(n_ambient, n_intrinsic)`` where + QMatrix + 2-D ``QMatrix`` of shape ``(n_ambient, n_intrinsic)`` where ``J.value[j, i] = \u2202(ambient_j) / \u2202(intrinsic_i)``. """ @@ -79,13 +79,13 @@ def embed_fn_arr(x_arr: qnp.ndarray) -> qnp.ndarray: return qnp.stack(vals) J_arr = jax.jacfwd(embed_fn_arr)(xat) # shape (n_ambient, n_intrinsic) - return QuantityMatrix(J_arr, unit=unit_matrix) + return QMatrix(J_arr, unit=unit_matrix) @jax.tree_util.register_static @final @dataclasses.dataclass(frozen=True, slots=True) -class InducedMetric(AbstractMetric): +class PullbackMetric(AbstractMetricField): r"""Pullback metric induced by an embedding map. Given an embedding $\iota : N \hookrightarrow M$, the metric $g_N$ on the @@ -101,35 +101,44 @@ class InducedMetric(AbstractMetric): ---------- embed_map : AbstractEmbeddingMap The embedding map from the submanifold into the ambient space. - ambient_metric : AbstractMetric + ambient_metric : AbstractMetricField The Riemannian metric on the ambient manifold. Examples -------- >>> import jax.numpy as jnp >>> import unxt as u + >>> import coordinax.api.manifolds as cxmapi >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm >>> embed_map = cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")) - >>> ambient_metric = cxm.EuclideanMetric(3) - >>> M = cxm.InducedMetric(embed_map, ambient_metric) - >>> at = {"theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(0.0, "rad")} - >>> M.metric_matrix(cxc.sph2, at=at) - QuantityMatrix([[1., 0.], - [0., 1.]], '((km2 / rad2, km2 / rad2), (km2 / rad2, km2 / rad2))') - - Embedding into a Riemannian ambient space yields a Riemannian induced metric: - + >>> ambient_metric = cxm.FlatMetric(3) + >>> M = cxm.PullbackMetric(embed_map, ambient_metric) >>> M.signature (1, 1) >>> M.ndim 2 + The metric matrix is obtained via the dispatch API on an + :class:`~coordinax.manifolds.EmbeddedManifold`: + + >>> M_emb = cxm.EmbeddedManifold( + ... intrinsic=cxm.S2, ambient=cxm.R3, + ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), + ... ) + >>> at = {"theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(0.0, "rad")} + >>> g = cxmapi.metric_matrix(M_emb, at, cxc.sph2) + >>> g.matrix.value + Array([[1., 0.], + [0., 1.]], dtype=float64, weak_type=True) + >>> g.matrix.unit[0, 0] + Unit("km2 / rad2") + """ embed_map: AbstractEmbeddingMap - ambient_metric: AbstractMetric + ambient_metric: AbstractMetricField @property def signature(self) -> tuple[int, ...]: @@ -140,28 +149,3 @@ def signature(self) -> tuple[int, ...]: ``J`` has full column rank). """ return (1,) * self.embed_map.intrinsic.ndim - - def metric_matrix( - self, _: AbstractChart, /, *, at: CDict, usys: OptUSys = None - ) -> QuantityMatrix: - r"""Compute the induced metric $g_N = J^T G_M J$. - - Computes the Jacobian of the embedding evaluated at ``at``, then - contracts with the ambient metric at the embedded point. - - """ - # Compute Jacobian of the embedding as a QuantityMatrix: - # J shape (n_ambient, n_intrinsic) - J = _jacobian_embed_map(self.embed_map, at, usys=usys) - - # Evaluate embedding at base point to get ambient point - at_ambient = self.embed_map.embed(at) - - # Ambient metric at the embedded point (also QuantityMatrix) - ambient_chart = self.embed_map.ambient - G_ambient = self.ambient_metric.metric_matrix( - ambient_chart, at=at_ambient, usys=usys - ) - - JT = qnp.transpose(J, (1, 0)) # (n_intrinsic, n_ambient) - return JT @ G_ambient @ J # ty: ignore[invalid-return-type] diff --git a/src/coordinax/_src/embedded/register_charts.py b/src/coordinax/_src/embedded/register_charts.py index b471e15e..05e82b34 100644 --- a/src/coordinax/_src/embedded/register_charts.py +++ b/src/coordinax/_src/embedded/register_charts.py @@ -31,8 +31,6 @@ def pt_map( ) -> CDict: """Convert between embedded manifolds with a shared ambient space. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm diff --git a/src/coordinax/_src/embedded/register_metric.py b/src/coordinax/_src/embedded/register_metric.py new file mode 100644 index 00000000..3f8b7376 --- /dev/null +++ b/src/coordinax/_src/embedded/register_metric.py @@ -0,0 +1,182 @@ +r"""Register ``metric_matrix`` and ``metric_representation`` dispatch rules. + +Covers :class:`~coordinax.manifolds.EmbeddedManifold` paired with any +intrinsic :class:`~coordinax._src.base.AbstractChart`. + +The *induced* (pullback) metric on an embedded submanifold is computed as +$g = J^T J$ where $J$ is the Jacobian of the composition +``intrinsic → Cartesian ambient``. Routing through the Cartesian ambient +ensures every entry of $J$ carries the *same* unit (``cart_unit / intrinsic_unit``), +which makes $J^T J$ unit-compatible across all summation terms. + +All results are wrapped in a :class:`~coordinax._src.metric.matrix.DenseMetric` +because the induced metric is not guaranteed to be diagonal. + +""" + +__all__: tuple[str, ...] = () + +import jax +import jax.numpy as jnp +import plum + +import quaxed.numpy as qnp +import unxt as u + +import coordinax.api.charts as cxcapi +from .manifold import EmbeddedManifold +from coordinax._src.base import AbstractChart # type: ignore[type-arg] +from coordinax._src.metric.matrix import DenseMetric +from coordinax.api.manifolds import metric_matrix +from coordinax.internal import ( + QMatrix, + UnitsMatrix, + cdict_units, + pack_nonuniform_unit, +) + +DMLS = u.unit("") + + +# ===================================================================== +# metric_representation +# ===================================================================== + + +@plum.dispatch +def metric_representation( + M: EmbeddedManifold, chart: AbstractChart, / +) -> type[DenseMetric]: + """Embedded manifold in any intrinsic chart → `DenseMetric`. + + >>> import unxt as u + >>> import coordinax.manifolds as cxm + >>> import coordinax.charts as cxc + >>> from coordinax.api.manifolds import metric_representation + + >>> M = cxm.EmbeddedManifold( + ... intrinsic=cxm.S2, ambient=cxm.R3, + ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), + ... ) + >>> metric_representation(M, cxc.sph2) + + + """ + del M, chart + return DenseMetric + + +# ===================================================================== +# metric_matrix +# ===================================================================== + + +@plum.dispatch +def metric_matrix( + M: EmbeddedManifold, point: dict, chart: AbstractChart, / +) -> DenseMetric: + r"""Induced metric on an embedded submanifold via Jacobian pullback. + + Computes $g_{ij} = \sum_k J^k_i J^k_j$ where $J$ is the Jacobian of the + composition ``intrinsic → Cartesian ambient``. Routing through Cartesian + ambient coordinates ensures all entries of $J$ share the same unit + (``cart_unit / intrinsic_unit``), so the matrix product $J^T J$ is + unit-compatible and the result carries physically correct units. + + Parameters + ---------- + M : EmbeddedManifold + An embedded submanifold; carries ``intrinsic``, ``ambient``, and + ``embed_map`` fields. + point : dict + A coordinate dictionary in the *intrinsic* chart coordinates. + chart : AbstractChart + The intrinsic chart (passed through for API consistency). + + Returns + ------- + DenseMetric + Induced metric matrix at ``point``, backed by a + :class:`~coordinax.internal.QMatrix` with units + ``cart_unit^2 / (intrinsic_unit_i * intrinsic_unit_j)``. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.manifolds as cxm + >>> import coordinax.charts as cxc + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DenseMetric + + Unit sphere — values should be the identity: + + >>> M = cxm.EmbeddedManifold( + ... intrinsic=cxm.S2, ambient=cxm.R3, + ... embed_map=cxm.TwoSphereIn3D(radius=1.0), + ... ) + >>> p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} + >>> g = metric_matrix(M, p, cxc.sph2) + >>> isinstance(g, DenseMetric) + True + >>> g.matrix.value + Array([[1., 0.], + [0., 1.]], dtype=float64, weak_type=True) + + Radius-2 sphere — metric scaled by R²: + + >>> M2 = cxm.EmbeddedManifold( + ... intrinsic=cxm.S2, ambient=cxm.R3, + ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(2.0, "m")), + ... ) + >>> g2 = metric_matrix(M2, p, cxc.sph2) + >>> g2.matrix.value + Array([[4., 0.], + [0., 4.]], dtype=float64, weak_type=True) + >>> g2.matrix.unit[0, 0] + Unit("m2 / rad2") + + """ + del chart + embed_map = M.embed_map + ambient_chart = embed_map.ambient + intrinsic_keys = embed_map.intrinsic.components + + # Use Cartesian ambient so every J entry has the same unit + # (cart_unit / intrinsic_unit). + cart_chart = ambient_chart.cartesian + cart_keys = cart_chart.components + + xat, ufrom = pack_nonuniform_unit(point, intrinsic_keys) + ufrom_ = tuple(uf if uf is not None else DMLS for uf in ufrom) + + at_ambient = embed_map.embed(point, usys=None) + at_cart = cxcapi.pt_map(at_ambient, ambient_chart, cart_chart) + uto_ = cdict_units(at_cart, cart_keys) + uto_ = tuple(ut if ut is not None else DMLS for ut in uto_) + + def _embed_cart(x_arr: jnp.ndarray) -> jnp.ndarray: + q = {k: u.Q(x_arr[i], ufrom_[i]) for i, k in enumerate(intrinsic_keys)} + q_ambient = embed_map.embed(q, usys=None) + q_cart = cxcapi.pt_map(q_ambient, ambient_chart, cart_chart) + vals = [ + u.ustrip(uto_[j], q_cart[k]) # ty: ignore[not-subscriptable] + if isinstance(q_cart[k], u.AbstractQuantity) # ty: ignore[not-subscriptable] + else qnp.asarray(q_cart[k]) # ty: ignore[not-subscriptable] + for j, k in enumerate(cart_keys) + ] + return qnp.stack(vals) + + J_arr = jax.jacfwd(_embed_cart)(xat) # (n_cart, n_intrinsic) + result_vals = J_arr.T @ J_arr # (n_intrinsic, n_intrinsic) + + # g_{ij} unit = uto_[0]² / (ufrom_[i] × ufrom_[j]) + # Valid because all Cartesian coordinates share the same unit. + n = len(intrinsic_keys) + result_unit = UnitsMatrix( + tuple( + tuple(uto_[0] ** 2 / (ufrom_[i] * ufrom_[j]) for j in range(n)) # ty: ignore[unsupported-operator] + for i in range(n) + ) + ) + return DenseMetric(QMatrix(result_vals, unit=result_unit)) diff --git a/src/coordinax/_src/euclidean/__init__.py b/src/coordinax/_src/euclidean/__init__.py index 408bc615..78b258d7 100644 --- a/src/coordinax/_src/euclidean/__init__.py +++ b/src/coordinax/_src/euclidean/__init__.py @@ -9,4 +9,5 @@ from .guess import * from .manifold import * from .metric import * +from .register_metric import * from .scale_factors import * diff --git a/src/coordinax/_src/euclidean/manifold.py b/src/coordinax/_src/euclidean/manifold.py index c209601e..0458989e 100644 --- a/src/coordinax/_src/euclidean/manifold.py +++ b/src/coordinax/_src/euclidean/manifold.py @@ -20,7 +20,7 @@ import dataclassish from .atlas import EuclideanAtlas -from .metric import EuclideanMetric +from .metric import FlatMetric from coordinax._src.base import AbstractManifold from coordinax._src.internal import pos_named_objs @@ -163,7 +163,7 @@ def __init__(self, ndim: int, /) -> None: raise TypeError(msg) object.__setattr__(self, "ndim", ndim) object.__setattr__(self, "atlas", EuclideanAtlas(self.ndim)) - object.__setattr__(self, "metric", EuclideanMetric(self.ndim)) + object.__setattr__(self, "metric", FlatMetric(self.ndim)) def __pdoc__(self, *, alias: bool = True, **kw: Any) -> wl.AbstractDoc: """Return the string representation. @@ -172,7 +172,7 @@ def __pdoc__(self, *, alias: bool = True, **kw: Any) -> wl.AbstractDoc: -------- >>> import wadler_lindig as wl >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(3) + >>> M = cxm.R3 >>> wl.pprint(M) Rn(3) diff --git a/src/coordinax/_src/euclidean/metric.py b/src/coordinax/_src/euclidean/metric.py index dfcaa40b..4ad76249 100644 --- a/src/coordinax/_src/euclidean/metric.py +++ b/src/coordinax/_src/euclidean/metric.py @@ -1,6 +1,6 @@ """Euclidean manifolds.""" -__all__ = ("EuclideanMetric",) +__all__ = ("FlatMetric",) import dataclasses @@ -8,20 +8,13 @@ import jax -import quaxed.numpy as jnp -import unxt as u - -import coordinax.api.charts as cxcapi -from coordinax._src.base import AbstractChart, AbstractDiagonalMetric -from coordinax._src.custom_types import CDict, OptUSys -from coordinax._src.exceptions import NoGlobalCartesianChartError -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax._src.base import AbstractDiagonalMetricField @jax.tree_util.register_static @final @dataclasses.dataclass(frozen=True, slots=True) -class EuclideanMetric(AbstractDiagonalMetric): +class FlatMetric(AbstractDiagonalMetricField): r"""Euclidean (flat) Riemannian metric on $\mathbb{R}^n$. In Cartesian coordinates the metric is the identity matrix $g = I_n$. @@ -35,7 +28,7 @@ class EuclideanMetric(AbstractDiagonalMetric): transition map. This pullback is diagonal precisely for orthogonal coordinate charts. - `EuclideanMetric` is treated as ``AbstractDiagonalMetric`` on that + `FlatMetric` is treated as ``AbstractDiagonalMetricField`` on that orthogonal chart domain; atlas chart compatibility alone does not imply orthogonality. @@ -47,23 +40,22 @@ class EuclideanMetric(AbstractDiagonalMetric): Examples -------- >>> import jax.numpy as jnp + >>> import coordinax.api.manifolds as cxmapi >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> m = cxm.EuclideanMetric(3) - >>> at = {"x": jnp.array(0.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} - >>> m.metric_matrix(cxc.cart3d, at=at) - QuantityMatrix([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]], '((, , ), (, , ), (, , ))') - - The signature is ``(1,) * ndim`` for a Riemannian (positive-definite) metric: - + >>> m = cxm.FlatMetric(3) >>> m.signature (1, 1, 1) >>> m.ndim 3 + The metric matrix is obtained via the dispatch API on the associated manifold: + + >>> at = {"x": jnp.array(0.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + >>> cxmapi.metric_matrix(cxm.R3, at, cxc.cart3d).diagonal + Array([1., 1., 1.], dtype=float64) + """ ndim: int @@ -73,34 +65,3 @@ class EuclideanMetric(AbstractDiagonalMetric): def signature(self) -> tuple[int, ...]: """Signature of the metric as a tuple of 1's.""" return tuple(1 for _ in range(self.ndim)) - - def metric_matrix( - self, chart: AbstractChart, /, *, at: CDict, usys: OptUSys = None - ) -> QuantityMatrix: - r"""Metric matrix in the given chart at the base point ``at``. - - For Cartesian charts, returns the identity matrix directly. - For other charts, compute ``J^T J`` where ``J`` is the Jacobian of - the curvilinear-to-Cartesian transition. This is diagonal exactly when - the chart is orthogonal. - - """ - # Try to get the canonical Cartesian chart for this manifold - try: - cart_chart = chart.cartesian - except NoGlobalCartesianChartError: - # Chart has no Cartesian sibling; fall back to dimensionless identity - n = self.ndim - unit_tup = tuple(tuple(u.unit("") for _ in range(n)) for _ in range(n)) - return QuantityMatrix(jnp.eye(n), unit=UnitsMatrix(unit_tup)) - - if chart == cart_chart: - # Already Cartesian: metric is the identity - n = self.ndim - unit_tup = tuple(tuple(u.unit("") for _ in range(n)) for _ in range(n)) - return QuantityMatrix(jnp.eye(n), unit=UnitsMatrix(unit_tup)) - - # Compute J = d(Cartesian)/d(chart) via jac_pt_map (returns QuantityMatrix) - J = cxcapi.jac_pt_map(at, chart, cart_chart, usys=usys) - JT = jnp.transpose(J, (1, 0)) # ty: ignore[no-matching-overload] - return JT @ J diff --git a/src/coordinax/_src/euclidean/register_metric.py b/src/coordinax/_src/euclidean/register_metric.py new file mode 100644 index 00000000..012f38c0 --- /dev/null +++ b/src/coordinax/_src/euclidean/register_metric.py @@ -0,0 +1,488 @@ +"""Register ``metric_matrix`` and ``metric_representation`` dispatch rules. + +Covers :class:`~coordinax.manifolds.EuclideanManifold` paired with every +chart in its atlas. The rules follow a two-tier scheme: + +* **Cartesian charts** (``Cart0D``, ``Cart1D``, ``Cart2D``, ``Cart3D``, + ``CartND``) and **orthogonal curvilinear charts** (``Radial1D``, + ``Polar2D``, ``Cylindrical3D``, ``Spherical3D``, ``MathSpherical3D``, + ``LonLatSpherical3D``) have explicit analytic diagonal metrics and return + a :class:`~coordinax._src.metric.matrix.DiagonalMetric`. +* **All other charts** compute the Jacobian pullback ``g = J^T J`` directly + and return the result as a :class:`~coordinax._src.metric.matrix.DenseMetric`. + +""" + +__all__: tuple[str, ...] = () + +from typing import Any + +import jax.numpy as jnp +import plum + +import unxt as u + +import coordinax.api.charts as cxcapi +from .manifold import EuclideanManifold +from coordinax._src.base import AbstractChart # type: ignore[type-arg] +from coordinax._src.charts.d1 import Cart1D, Radial1D +from coordinax._src.charts.d2 import Cart2D, Polar2D +from coordinax._src.charts.d3 import ( + Cart3D, + Cylindrical3D, + LonLatSpherical3D, + MathSpherical3D, + Spherical3D, +) +from coordinax._src.charts.dn import CartND +from coordinax._src.exceptions import NoGlobalCartesianChartError +from coordinax._src.metric.matrix import DenseMetric, DiagonalMetric +from coordinax.internal import QMatrix, UnitsMatrix + +# ===================================================================== +# Private helpers for unit-aware analytic metric formulas +# ===================================================================== + + +def _val_unit(q: Any, /) -> tuple[Any, u.AbstractUnit]: + """Return ``(numeric_value, unit)`` from a Quantity or plain array.""" + if isinstance(q, u.AbstractQuantity): + return q.value, q.unit + return q, u.unit("") # ty: ignore[invalid-return-type] + + +def _angle_rad(q: Any, /) -> Any: + """Return the angle value in radians, stripping units if present.""" + if isinstance(q, u.AbstractQuantity): + return u.ustrip("rad", q) + return q + + +def _angle_unit(q: Any, /) -> u.AbstractUnit: + """Return the unit of an angular coordinate, or dimensionless if plain array.""" + if isinstance(q, u.AbstractQuantity): + return q.unit + return u.unit("") + + +# ===================================================================== +# metric_representation — declare which AbstractMetricFieldMatrix subtype is +# returned for each (manifold, chart) combination. +# ===================================================================== + + +@plum.dispatch +def metric_representation( + M: EuclideanManifold, chart: Cart1D | Cart2D | Cart3D | CartND, / +) -> type[DiagonalMetric]: + """Euclidean manifold in a Cartesian chart → :class:`DiagonalMetric`. + + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_representation + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> metric_representation(cxm.R3, cxc.cart3d) + + + """ + del M, chart + return DiagonalMetric + + +@plum.dispatch +def metric_representation( + M: EuclideanManifold, + chart: Radial1D + | Polar2D + | Cylindrical3D + | Spherical3D + | MathSpherical3D + | LonLatSpherical3D, + /, +) -> type[DiagonalMetric]: + """Euclidean manifold in an orthogonal curvilinear chart → `DiagonalMetric`. + + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_representation + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> metric_representation(cxm.R2, cxc.polar2d) + + + >>> metric_representation(cxm.R3, cxc.sph3d) + + + """ + del M, chart + return DiagonalMetric + + +@plum.dispatch +def metric_representation( + M: EuclideanManifold, chart: AbstractChart, / +) -> type[DenseMetric]: + """Euclidean manifold in a general (non-Cartesian) chart → `DenseMetric`. + + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_representation + >>> from coordinax._src.metric.matrix import DenseMetric + + >>> from coordinax._src.charts.d3 import LonCosLatSpherical3D + >>> chart = LonCosLatSpherical3D() + >>> metric_representation(cxm.R3, chart) + + + """ + del M, chart + return DenseMetric + + +# ===================================================================== +# metric_matrix — Cartesian charts (identity diagonal) +# ===================================================================== + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: Cart1D | Cart2D | Cart3D, / +) -> DiagonalMetric: + """Euclidean metric in a Cartesian chart: ``g = I_n``. + + The metric matrix is the identity in any Cartesian chart, represented + compactly as a `coordinax._src.metric.matrix.DiagonalMetric` with all-one + diagonal. + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + Cart1D: + + >>> at = {"x": jnp.array(3.0)} + >>> g = metric_matrix(cxm.R1, at, cxc.cart1d) + >>> isinstance(g, DiagonalMetric) + True + >>> g.diagonal + Array([1.], dtype=float64) + + Cart2D: + + >>> at = {"x": jnp.array(1.0), "y": jnp.array(2.0)} + >>> metric_matrix(cxm.R2, at, cxc.cart2d).diagonal + Array([1., 1.], dtype=float64) + + Cart3D: + + >>> at = {"x": jnp.array(1.0), "y": jnp.array(2.0), "z": jnp.array(3.0)} + >>> metric_matrix(cxm.R3, at, cxc.cart3d).diagonal + Array([1., 1., 1.], dtype=float64) + + """ + del M, point + n = len(chart.components) + return DiagonalMetric(jnp.ones(n)) + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: CartND, / +) -> DiagonalMetric: + """Euclidean metric in CartND: ``g = I_N`` where *N* is inferred from the point. + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + + >>> at = {"q": jnp.array([1.0, 2.0, 3.0])} + >>> metric_matrix(cxm.R3, at, cxc.cartnd).diagonal + Array([1., 1., 1.], dtype=float64) + + """ + del M, chart + n = jnp.asarray(point["q"]).shape[0] + return DiagonalMetric(jnp.ones(n)) + + +# ===================================================================== +# metric_matrix — Orthogonal curvilinear charts (analytic formulas) +# ===================================================================== + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: Radial1D, / +) -> DiagonalMetric: + """Euclidean metric in ``Radial1D``: ``g = diag(1)``. + + The only component is ``g_rr = 1`` (the radial direction is an + isometry of Euclidean distance in 1-D). + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + Dimensionless: + + >>> at = {"r": jnp.array(2.0)} + >>> g = metric_matrix(cxm.R1, at, cxc.radial1d) + >>> isinstance(g, DiagonalMetric) + True + + With length units: + + >>> at = {"r": u.Q(2.0, "m")} + >>> g = metric_matrix(cxm.R1, at, cxc.radial1d) + >>> g.diagonal + QMatrix([1.], '(,)') + + """ + del M, point, chart + dmls = u.unit("") + return DiagonalMetric(QMatrix(jnp.ones(1), unit=UnitsMatrix((dmls,)))) + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: Polar2D, / +) -> DiagonalMetric: + r"""Euclidean metric in ``Polar2D``: ``g = diag(1, r²)``. + + point must contain keys ``"r"`` (length) and ``"theta"`` (angle). + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + Dimensionless ``r``: + + >>> at = {"r": jnp.array(3.0), "theta": jnp.array(0.5)} + >>> g = metric_matrix(cxm.R2, at, cxc.polar2d) + >>> g.diagonal + QMatrix([1., 9.], '(, )') + + Length-valued ``r`` and angle-valued ``theta``: + + >>> at = {"r": u.Q(3.0, "m"), "theta": u.Angle(0.5, "rad")} + >>> g = metric_matrix(cxm.R2, at, cxc.polar2d) + >>> g.diagonal + QMatrix([1., 9.], '(, m2 / rad2)') + + """ + del M, chart + r_val, r_unit = _val_unit(point["r"]) + theta_unit = _angle_unit(point["theta"]) + diag = jnp.stack([jnp.asarray(1.0), r_val**2]) + units = UnitsMatrix((u.unit(""), r_unit**2 / theta_unit**2)) + return DiagonalMetric(QMatrix(diag, unit=units)) + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: Cylindrical3D, / +) -> DiagonalMetric: + r"""Euclidean metric in ``Cylindrical3D``: ``g = diag(1, ρ², 1)``. + + ``point`` must contain keys ``"rho"`` (length), ``"phi"`` (angle), + and ``"z"`` (length). + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> at = {"rho": u.Q(3.0, "m"), "phi": u.Angle(0.0, "rad"), "z": u.Q(1.0, "m")} + >>> g = metric_matrix(cxm.R3, at, cxc.cyl3d) + >>> g.diagonal + QMatrix([1., 9., 1.], '(, m2 / rad2, )') + + """ + del M, chart + rho_val, rho_unit = _val_unit(point["rho"]) + phi_unit = _angle_unit(point["phi"]) + dmls = u.unit("") + diag = jnp.stack([jnp.asarray(1.0), rho_val**2, jnp.asarray(1.0)]) + units = UnitsMatrix((dmls, rho_unit**2 / phi_unit**2, dmls)) + return DiagonalMetric(QMatrix(diag, unit=units)) + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: Spherical3D, / +) -> DiagonalMetric: + r"""Euclidean metric in ``Spherical3D``: ``g = diag(1, r², r²sin²θ)``. + + Physics convention: ``θ`` is the polar (colatitude) angle measured from + the ``+z`` axis, ``φ`` is the azimuthal angle. ``point`` must contain + keys ``"r"`` (length), ``"theta"`` (polar angle), and ``"phi"`` + (azimuthal angle). + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> at = { + ... "r": u.Q(2.0, "m"), + ... "theta": u.Angle(jnp.pi / 2, "rad"), + ... "phi": u.Angle(0.0, "rad"), + ... } + >>> g = metric_matrix(cxm.R3, at, cxc.sph3d) + >>> isinstance(g, DiagonalMetric) + True + >>> g.diagonal + QMatrix([1., 4., 4.], '(, m2 / rad2, m2 / rad2)') + + """ + del M, chart + r_val, r_unit = _val_unit(point["r"]) + theta_val = _angle_rad(point["theta"]) + theta_unit = _angle_unit(point["theta"]) + phi_unit = _angle_unit(point["phi"]) + r2 = r_val**2 + r2_unit = r_unit**2 + diag = jnp.stack([jnp.asarray(1.0), r2, r2 * jnp.sin(theta_val) ** 2]) + units = UnitsMatrix((u.unit(""), r2_unit / theta_unit**2, r2_unit / phi_unit**2)) + return DiagonalMetric(QMatrix(diag, unit=units)) + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: MathSpherical3D, / +) -> DiagonalMetric: + r"""Euclidean metric in ``MathSpherical3D``: ``g = diag(1, r²sin²φ, r²)``. + + Math convention: ``φ`` is the polar angle from the ``+z`` axis + (colatitude), ``θ`` is the azimuthal angle. ``point`` must contain + keys ``"r"`` (length), ``"theta"`` (azimuthal angle), and ``"phi"`` + (polar / colatitude angle). + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> at = { + ... "r": u.Q(2.0, "m"), + ... "theta": u.Angle(0.0, "rad"), + ... "phi": u.Angle(jnp.pi / 2, "rad"), + ... } + >>> g = metric_matrix(cxm.R3, at, cxc.math_sph3d) + >>> isinstance(g, DiagonalMetric) + True + >>> g.diagonal + QMatrix([1., 4., 4.], '(, m2 / rad2, m2 / rad2)') + + """ + del M, chart + r_val, r_unit = _val_unit(point["r"]) + phi_val = _angle_rad(point["phi"]) # polar / colatitude angle + theta_unit = _angle_unit(point["theta"]) + phi_unit = _angle_unit(point["phi"]) + r2 = r_val**2 + r2_unit = r_unit**2 + diag = jnp.stack([jnp.asarray(1.0), r2 * jnp.sin(phi_val) ** 2, r2]) + units = UnitsMatrix((u.unit(""), r2_unit / theta_unit**2, r2_unit / phi_unit**2)) + return DiagonalMetric(QMatrix(diag, unit=units)) + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: LonLatSpherical3D, / +) -> DiagonalMetric: + r"""Euclidean metric in ``LonLatSpherical3D``. + + The metric is ``g = diag(distance²cos²lat, distance², 1)`` (components + ordered as ``(lon, lat, distance)``). ``point`` must contain keys + ``"lon"``, ``"lat"``, and ``"distance"`` (length). + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> at = { + ... "lon": u.Angle(0.0, "rad"), + ... "lat": u.Angle(0.0, "rad"), + ... "distance": u.Q(2.0, "m"), + ... } + >>> g = metric_matrix(cxm.R3, at, cxc.lonlat_sph3d) + >>> isinstance(g, DiagonalMetric) + True + >>> g.diagonal + QMatrix([4., 4., 1.], '(m2 / rad2, m2 / rad2, )') + + """ + del M, chart + d_val, d_unit = _val_unit(point["distance"]) + lat_val = _angle_rad(point["lat"]) + lon_unit = _angle_unit(point["lon"]) + lat_unit = _angle_unit(point["lat"]) + d2 = d_val**2 + d2_unit = d_unit**2 + diag = jnp.stack([d2 * jnp.cos(lat_val) ** 2, d2, jnp.asarray(1.0)]) + units = UnitsMatrix((d2_unit / lon_unit**2, d2_unit / lat_unit**2, u.unit(""))) + return DiagonalMetric(QMatrix(diag, unit=units)) + + +# ===================================================================== +# metric_matrix — General fallback (Jacobian pullback) +# ===================================================================== + + +@plum.dispatch +def metric_matrix( + M: EuclideanManifold, point: dict, chart: AbstractChart, / +) -> DenseMetric: + """Euclidean metric in a general chart via Jacobian pullback ``g = J^T J``. + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DenseMetric + >>> from coordinax._src.charts.d3 import LonCosLatSpherical3D + + Non-orthogonal chart (fallback, returns ``DenseMetric``): + + >>> M = cxm.R3 + >>> chart = LonCosLatSpherical3D() + >>> at = { + ... "lon_coslat": u.Angle(0.0, "rad"), + ... "lat": u.Angle(0.0, "rad"), + ... "distance": u.Q(2.0, "m"), + ... } + >>> g = metric_matrix(M, at, chart) + >>> isinstance(g, DenseMetric) + True + + """ + try: + cart_chart = chart.cartesian + except NoGlobalCartesianChartError: + n = M.ndim + unit_tup = tuple(tuple(u.unit("") for _ in range(n)) for _ in range(n)) + return DenseMetric(QMatrix(jnp.eye(n), unit=UnitsMatrix(unit_tup))) + J = cxcapi.jac_pt_map(point, chart, cart_chart, usys=None) + JT = J.T + return DenseMetric(JT @ J) # ty: ignore[unsupported-operator] diff --git a/src/coordinax/_src/euclidean/scale_factors.py b/src/coordinax/_src/euclidean/scale_factors.py index eaf08a53..1d896f64 100644 --- a/src/coordinax/_src/euclidean/scale_factors.py +++ b/src/coordinax/_src/euclidean/scale_factors.py @@ -11,40 +11,38 @@ from unxt.quantity import AllowValue, is_any_quantity import coordinax.api.charts as cxcapi -from .metric import EuclideanMetric +from .metric import FlatMetric from coordinax._src.base import AbstractChart from coordinax._src.custom_types import CDict, OptUSys -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax.internal import QMatrix, UnitsMatrix DMLS = u.unit("") @plum.dispatch def scale_factors( - metric: EuclideanMetric, + metric: FlatMetric, chart: AbstractChart, /, *, at: CDict, usys: OptUSys = None, -) -> QuantityMatrix: +) -> QMatrix: """Compute only the Euclidean metric diagonal instead of forming ``J.T @ J``. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> metric = cxm.EuclideanMetric(3) + >>> metric = cxm.FlatMetric(3) >>> at = { ... "r": u.Q(jnp.array(2.0), "km"), ... "theta": u.Angle(jnp.pi / 2, "rad"), ... "phi": u.Angle(jnp.array(0.0), "rad"), ... } >>> cxm.scale_factors(metric, cxc.sph3d, at=at) - QuantityMatrix([1., 4., 4.], '(, km2 / rad2, km2 / rad2)') + QMatrix([1., 4., 4.], '(, km2 / rad2, km2 / rad2)') """ del metric @@ -52,7 +50,7 @@ def scale_factors( if chart == cart_chart: n = len(chart.components) - return QuantityMatrix( + return QMatrix( jnp.ones((n,)), unit=UnitsMatrix(tuple(u.unit("") for _ in range(n))) ) @@ -60,22 +58,22 @@ def scale_factors( return _column_squared_norms(J) -def _column_squared_norms(J: QuantityMatrix | Array, /) -> QuantityMatrix: +def _column_squared_norms(J: QMatrix | Array, /) -> QMatrix: """Return ``diag(J.T @ J)`` without forming the full Gram matrix.""" - if isinstance(J, QuantityMatrix): + if isinstance(J, QMatrix): return _quantity_column_squared_norms(J) return _array_column_squared_norms(J) -def _array_column_squared_norms(J: Array, /) -> QuantityMatrix: +def _array_column_squared_norms(J: Array, /) -> QMatrix: """Return squared column norms for a dimensionless Jacobian array.""" value = jnp.einsum("...ji,...ji->...i", J, J) n = value.shape[-1] unit = UnitsMatrix(tuple(DMLS for _ in range(n))) - return QuantityMatrix(value, unit) + return QMatrix(value, unit) -def _quantity_column_squared_norms(J: QuantityMatrix) -> QuantityMatrix: +def _quantity_column_squared_norms(J: QMatrix) -> QMatrix: """Return squared column norms for a heterogeneous-unit Jacobian.""" xs = tuple(_colnorm2(J[:, i]) for i in range(J.shape[-1])) units = tuple(u.unit_of(x) if is_any_quantity(x) else DMLS for x in xs) @@ -83,9 +81,9 @@ def _quantity_column_squared_norms(J: QuantityMatrix) -> QuantityMatrix: [u.ustrip(AllowValue, unit, x) for x, unit in zip(xs, units, strict=True)], axis=-1, ) - return QuantityMatrix(value, unit=UnitsMatrix(units)) + return QMatrix(value, unit=UnitsMatrix(units)) -def _colnorm2(column: QuantityMatrix) -> u.AbstractQuantity | Array: +def _colnorm2(column: QMatrix) -> u.AbstractQuantity | Array: """Return the squared norm of a single Jacobian column.""" return jnp.dot(column, column) diff --git a/src/coordinax/_src/minkowski/__init__.py b/src/coordinax/_src/minkowski/__init__.py index 9355beb6..045bf56d 100644 --- a/src/coordinax/_src/minkowski/__init__.py +++ b/src/coordinax/_src/minkowski/__init__.py @@ -4,3 +4,5 @@ from .charts import * from .manifold import * from .metric import * +from .register_metric import * +from .scale_factors import * diff --git a/src/coordinax/_src/minkowski/metric.py b/src/coordinax/_src/minkowski/metric.py index 94583774..75d5168a 100644 --- a/src/coordinax/_src/minkowski/metric.py +++ b/src/coordinax/_src/minkowski/metric.py @@ -7,21 +7,14 @@ from typing import final import jax -import jax.numpy as jnp -import quaxed.numpy as qnp -import unxt as u - -import coordinax.charts as cxc -from coordinax._src.base import AbstractDiagonalMetric -from coordinax._src.custom_types import CDict, OptUSys -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax._src.base import AbstractDiagonalMetricField @jax.tree_util.register_static @final @dataclasses.dataclass(frozen=True, slots=True) -class MinkowskiMetric(AbstractDiagonalMetric): +class MinkowskiMetric(AbstractDiagonalMetricField): r"""Pseudo-Riemannian (Lorentzian) metric on Minkowski spacetime. In the canonical {class}`~coordinax.charts.MinkowskiCT` chart @@ -43,17 +36,16 @@ class MinkowskiMetric(AbstractDiagonalMetric): >>> import jax.numpy as jnp >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix Canonical Cartesian spacetime chart: >>> m = cxm.MinkowskiMetric() + >>> M = cxm.MinkowskiManifold() >>> at = {"ct": jnp.array(0.0), "x": jnp.array(0.0), ... "y": jnp.array(0.0), "z": jnp.array(0.0)} - >>> m.metric_matrix(cxc.minkowskict, at=at).value - Array([[-1., 0., 0., 0.], - [ 0., 1., 0., 0.], - [ 0., 0., 1., 0.], - [ 0., 0., 0., 1.]], dtype=float64) + >>> metric_matrix(M, at, cxc.minkowskict).diagonal + Array([-1., 1., 1., 1.], dtype=float64) The signature is Lorentzian (pseudo-Riemannian): @@ -69,68 +61,3 @@ class MinkowskiMetric(AbstractDiagonalMetric): def signature(self) -> tuple[int, ...]: """Metric signature: ``(-1, 1, 1, 1)`` — Lorentzian pseudo-Riemannian.""" return (-1, 1, 1, 1) - - def metric_matrix( - self, chart: cxc.AbstractChart, /, *, at: CDict, usys: OptUSys = None - ) -> QuantityMatrix: - r"""Compute the Minkowski metric tensor $g_{ij}$ at the base point ``at``. - - In the canonical :class:`~coordinax.charts.MinkowskiCT` chart returns - $\eta = \operatorname{diag}(-1, 1, 1, 1)$ directly. For any other - registered chart computes the pullback $g = J^T \eta J$ via - :func:`~coordinax.charts.jac_pt_map`. - - Parameters - ---------- - chart : AbstractChart - Coordinate chart in which to express the metric. Must be a chart - registered with :class:`~coordinax.manifolds.MinkowskiAtlas` and - support a ``.cartesian`` property. - at : CDict - Base point in ``chart`` coordinates at which to evaluate. - usys : OptUSys, optional - Unit system to use for the metric evaluation. - - Returns - ------- - QuantityMatrix, shape (4, 4) - Metric matrix $g_{ij}$ in the given chart basis. - - Examples - -------- - >>> import jax.numpy as jnp - >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm - - Canonical Minkowski chart — returns the flat Minkowski matrix: - - >>> m = cxm.MinkowskiMetric() - >>> at = {"ct": jnp.array(0.0), "x": jnp.array(0.0), - ... "y": jnp.array(0.0), "z": jnp.array(0.0)} - >>> m.metric_matrix(cxc.minkowskict, at=at).value - Array([[-1., 0., 0., 0.], - [ 0., 1., 0., 0.], - [ 0., 0., 1., 0.], - [ 0., 0., 0., 1.]], dtype=float64) - - """ - n = 4 - unit_tup = tuple(tuple(u.unit("") for _ in range(n)) for _ in range(n)) - - cart_chart = chart.cartesian - if chart == cart_chart: - # Already the canonical Cartesian spacetime chart: η = diag(-1,1,1,1) - return QuantityMatrix( - jnp.diag(jnp.array([-1.0, 1.0, 1.0, 1.0])), - unit=UnitsMatrix(unit_tup), - ) - - # Pullback: g = J^T η J - # J: jacobian of chart → cart_chart, shape (4, 4) - J = cxc.jac_pt_map(at, chart, cart_chart, usys=usys) - JT = qnp.transpose(J, (1, 0)) - eta = QuantityMatrix( - jnp.diag(jnp.array([-1.0, 1.0, 1.0, 1.0])), - unit=UnitsMatrix(unit_tup), - ) - return JT @ eta @ J # ty: ignore[invalid-return-type] diff --git a/src/coordinax/_src/minkowski/register_metric.py b/src/coordinax/_src/minkowski/register_metric.py new file mode 100644 index 00000000..6ffed45b --- /dev/null +++ b/src/coordinax/_src/minkowski/register_metric.py @@ -0,0 +1,145 @@ +r"""Register ``metric_matrix`` and ``metric_representation`` dispatch rules. + +Covers :class:`~coordinax.manifolds.MinkowskiManifold` (Minkowski spacetime +$\mathbb{R}^{1,3}$) paired with charts in the Minkowski atlas. + +* For the canonical :class:`~coordinax.charts.MinkowskiCT` chart + ``(ct, x, y, z)`` the metric is $\eta = \operatorname{diag}(-1, 1, 1, 1)$, + returned as a :class:`~coordinax._src.metric.matrix.DiagonalMetric`. +* For all other registered charts the rule computes the pullback + $g = J^T \eta J$ via :func:`~coordinax.charts.jac_pt_map` directly + and wraps the result in a :class:`~coordinax._src.metric.matrix.DenseMetric`. + +""" + +__all__: tuple[str, ...] = () + +import jax.numpy as jnp +import plum + +import unxt as u + +import coordinax.charts as cxc +from .charts import MinkowskiCT +from .manifold import MinkowskiManifold +from coordinax._src.base import AbstractChart # type: ignore[type-arg] +from coordinax._src.metric.matrix import DenseMetric, DiagonalMetric +from coordinax.api.manifolds import metric_matrix, metric_representation +from coordinax.internal import QMatrix, UnitsMatrix + +# ===================================================================== +# metric_representation +# ===================================================================== + + +@plum.dispatch +def metric_representation( + M: MinkowskiManifold, chart: MinkowskiCT, / +) -> type[DiagonalMetric]: + """Minkowski manifold in the canonical CT chart → :class:`DiagonalMetric`. + + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_representation + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> M = cxm.MinkowskiManifold() + >>> metric_representation(M, cxc.minkowskict) + + + """ + del M, chart + return DiagonalMetric + + +@plum.dispatch +def metric_representation( + M: MinkowskiManifold, chart: AbstractChart, / +) -> type[DenseMetric]: + """Minkowski manifold in a general chart → :class:`DenseMetric`. + + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_representation + >>> from coordinax._src.metric.matrix import DenseMetric + + >>> M = cxm.MinkowskiManifold() + >>> metric_representation(M, cxc.minkowskict) # doctest: +SKIP + + + """ + del M, chart + return DenseMetric + + +# ===================================================================== +# metric_matrix — canonical Minkowski CT chart +# ===================================================================== + + +@plum.dispatch +def metric_matrix( + M: MinkowskiManifold, point: dict, chart: MinkowskiCT, / +) -> DiagonalMetric: + r"""Minkowski metric $\eta = \operatorname{diag}(-1, 1, 1, 1)$ in CT chart. + + All four components ``(ct, x, y, z)`` carry dimension ``"length"``, so the + entries are dimensionless. + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + >>> M = cxm.MinkowskiManifold() + >>> at = {"ct": jnp.array(0.0), "x": jnp.array(1.0), + ... "y": jnp.array(0.0), "z": jnp.array(0.0)} + >>> g = metric_matrix(M, at, cxc.minkowskict) + >>> isinstance(g, DiagonalMetric) + True + >>> g.diagonal + Array([-1., 1., 1., 1.], dtype=float64) + + """ + del M, point, chart + return DiagonalMetric(jnp.array([-1.0, 1.0, 1.0, 1.0])) + + +# ===================================================================== +# metric_matrix — general fallback +# ===================================================================== + + +@plum.dispatch +def metric_matrix( + M: MinkowskiManifold, point: dict, chart: AbstractChart, / +) -> DenseMetric: + r"""Minkowski metric in a general chart via Jacobian pullback $g = J^T \eta J$. + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + Canonical chart uses the specific dispatch above, not this fallback: + + >>> M = cxm.MinkowskiManifold() + >>> at = {"ct": jnp.array(0.0), "x": jnp.array(1.0), + ... "y": jnp.array(0.0), "z": jnp.array(0.0)} + >>> g = metric_matrix(M, at, cxc.minkowskict) + >>> isinstance(g, DiagonalMetric) + True + + """ + n = 4 + unit_tup = tuple(tuple(u.unit("") for _ in range(n)) for _ in range(n)) + cart_chart = chart.cartesian + J = cxc.jac_pt_map(point, chart, cart_chart, usys=None) + JT = J.T + eta = QMatrix( + jnp.diag(jnp.array([-1.0, 1.0, 1.0, 1.0])), + unit=UnitsMatrix(unit_tup), + ) + return DenseMetric(JT @ eta @ J) diff --git a/src/coordinax/_src/minkowski/scale_factors.py b/src/coordinax/_src/minkowski/scale_factors.py new file mode 100644 index 00000000..1b41f2b5 --- /dev/null +++ b/src/coordinax/_src/minkowski/scale_factors.py @@ -0,0 +1,41 @@ +"""Minkowski specializations for `coordinax.manifolds.scale_factors`.""" + +__all__: tuple[str, ...] = () + +import jax.numpy as jnp +import plum + +import unxt as u + +from .charts import MinkowskiCT +from .metric import MinkowskiMetric +from coordinax._src.custom_types import CDict, OptUSys +from coordinax.internal import QMatrix, UnitsMatrix + + +@plum.dispatch +def scale_factors( + metric: MinkowskiMetric, chart: MinkowskiCT, /, *, at: CDict, usys: OptUSys = None +) -> QMatrix: + r"""Return the Minkowski metric diagonal $\eta = \operatorname{diag}(-1,1,1,1)$. + + In the canonical `coordinax.charts.MinkowskiCT` chart the metric is + constant, so ``at`` is ignored and the result does not depend on the base + point. + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + + >>> metric = cxm.MinkowskiMetric() + >>> at = {"ct": jnp.array(0.0), "x": jnp.array(1.0), + ... "y": jnp.array(0.0), "z": jnp.array(0.0)} + >>> cxm.scale_factors(metric, cxc.minkowskict, at=at) + QMatrix([-1., 1., 1., 1.], '(, , , )') + + """ + del chart, at, usys + n = metric.ndim + value = jnp.array(list(metric.signature), dtype=float) + units = UnitsMatrix(tuple(u.unit("") for _ in range(n))) + return QMatrix(value, unit=units) diff --git a/src/coordinax/_src/null/metric.py b/src/coordinax/_src/null/metric.py index 0c0dbb75..4a256b98 100644 --- a/src/coordinax/_src/null/metric.py +++ b/src/coordinax/_src/null/metric.py @@ -4,26 +4,23 @@ import dataclasses -from jaxtyping import Array -from typing import Any, final +from typing import final import jax -from coordinax._src.base import AbstractMetric -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax._src.base import AbstractMetricField @jax.tree_util.register_static @final @dataclasses.dataclass(frozen=True, slots=True) -class NoMetric(AbstractMetric): +class NoMetric(AbstractMetricField): """A degenerate placeholder metric with no geometry. ``NoMetric`` is a sentinel value used when a metric object is required by the API but none has been specified by the user. - ``ndim == False`` signals "no metric specified". - - ``metric_matrix(chart, at)`` always raises ``NoGlobalCartesianChartError``. """ @@ -37,11 +34,6 @@ def signature(self) -> tuple[int, ...]: """Signature of the degenerate metric.""" return () - def metric_matrix(self, *args: Any, **kwargs: Any) -> QuantityMatrix | Array: - r"""Compute the metric tensor $g_{ij}$ at base point ``at``.""" - del args, kwargs - return QuantityMatrix(jax.numpy.array([]), UnitsMatrix("")) - no_metric = NoMetric() """Canonical instance of `coordinax.manifolds.NoMetric`.""" diff --git a/src/coordinax/_src/product/__init__.py b/src/coordinax/_src/product/__init__.py index a33e4d9f..601bbec8 100644 --- a/src/coordinax/_src/product/__init__.py +++ b/src/coordinax/_src/product/__init__.py @@ -5,3 +5,4 @@ from .galilean_ct import * from .manifold import * from .metric import * +from .register_metric import * diff --git a/src/coordinax/_src/product/atlas.py b/src/coordinax/_src/product/atlas.py index 2ef3cf15..e19c3764 100644 --- a/src/coordinax/_src/product/atlas.py +++ b/src/coordinax/_src/product/atlas.py @@ -26,8 +26,6 @@ class CartesianProductAtlas(AbstractAtlas): The atlas consists of Cartesian product charts formed from the atlases of the factor manifolds. - Examples - -------- Consider the product manifold $S^2 \times \\mathbb{R}$, where - $S^2$ is the 2-sphere with spherical coordinates $(\theta, \\phi)$ and diff --git a/src/coordinax/_src/product/chart.py b/src/coordinax/_src/product/chart.py index f6da6cf1..63196a0f 100644 --- a/src/coordinax/_src/product/chart.py +++ b/src/coordinax/_src/product/chart.py @@ -464,8 +464,6 @@ def pt_map( ) -> CDict: """AbstractChart -> Cartesian -> AbstractCartesianProductChart. - Examples - -------- >>> import coordinax.charts as cxc >>> import unxt as u @@ -499,8 +497,6 @@ def pt_map( ) -> CDict: """AbstractCartesianProductChart -> Cartesian -> AbstractChart. - Examples - -------- >>> import coordinax.charts as cxc >>> import unxt as u diff --git a/src/coordinax/_src/product/manifold.py b/src/coordinax/_src/product/manifold.py index 9c40b988..1b88ad57 100644 --- a/src/coordinax/_src/product/manifold.py +++ b/src/coordinax/_src/product/manifold.py @@ -11,7 +11,7 @@ import jax from .atlas import CartesianProductAtlas -from .metric import CartesianProductMetric +from .metric import ProductMetric from coordinax._src.base import AbstractManifold @@ -62,7 +62,7 @@ class CartesianProductManifold(AbstractManifold): ---------- atlas : CartesianProductAtlas The product atlas formed from the factor atlases. - metric : CartesianProductMetric + metric : ProductMetric The canonical product metric formed from the factor metrics. ndim : int Total intrinsic dimension $\sum_i n_i$. @@ -81,8 +81,7 @@ class CartesianProductManifold(AbstractManifold): >>> import wadler_lindig as wl >>> M = cxm.CartesianProductManifold( - ... factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(1)), - ... factor_names=("S2", "R1"), + ... factors=(cxm.S2, cxm.R1), factor_names=("S2", "R1") ... ) >>> wl.pprint(M, width=60) CartesianProductManifold( @@ -144,8 +143,7 @@ class CartesianProductManifold(AbstractManifold): \times \mathbb{R}^2$ has $\dim = 2 + 2 = 4$: >>> M4 = cxm.CartesianProductManifold( - ... factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(2)), - ... factor_names=("S2", "R2"), + ... factors=(cxm.S2, cxm.R2), factor_names=("S2", "R2") ... ) >>> M4.ndim 4 @@ -157,8 +155,7 @@ class CartesianProductManifold(AbstractManifold): structure rather than a single `EuclideanManifold`): >>> Mprod = cxm.CartesianProductManifold( - ... factors=(cxm.EuclideanManifold(2), cxm.EuclideanManifold(1)), - ... factor_names=("xy", "z"), + ... factors=(cxm.R2, cxm.R1), factor_names=("xy", "z") ... ) >>> Mprod.ndim 3 @@ -182,8 +179,7 @@ def atlas(self) -> CartesianProductAtlas: >>> import wadler_lindig as wl >>> M = cxm.CartesianProductManifold( - ... factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(1)), - ... factor_names=("S2", "R1")) + ... factors=(cxm.S2, cxm.R1), factor_names=("S2", "R1")) >>> wl.pprint(M.atlas, width=60) CartesianProductAtlas( factors=(HyperSphericalAtlas(), EuclideanAtlas(ndim=1)), @@ -197,7 +193,7 @@ def atlas(self) -> CartesianProductAtlas: ) @property - def metric(self) -> CartesianProductMetric: + def metric(self) -> ProductMetric: """Return the canonical product metric from the factor metrics.""" factor_metrics = tuple(factor.metric for factor in self.factors) - return CartesianProductMetric(factors=factor_metrics) + return ProductMetric(factors=factor_metrics) diff --git a/src/coordinax/_src/product/metric.py b/src/coordinax/_src/product/metric.py index a54bdce2..58a82f5a 100644 --- a/src/coordinax/_src/product/metric.py +++ b/src/coordinax/_src/product/metric.py @@ -1,26 +1,20 @@ """Product manifold metrics.""" -__all__ = ("CartesianProductMetric",) +__all__ = ("ProductMetric",) import dataclasses from typing import final import jax -import jax.numpy as jnp -import unxt as u - -import coordinax.charts as cxc -from coordinax._src.base import AbstractMetric -from coordinax._src.custom_types import CDict, OptUSys -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax._src.base import AbstractMetricField @jax.tree_util.register_static @final @dataclasses.dataclass(frozen=True, slots=True) -class CartesianProductMetric(AbstractMetric): +class ProductMetric(AbstractMetricField): r"""Canonical product metric on a Cartesian product manifold. For factor manifolds $(M_i, g_i)$, the product metric on @@ -34,70 +28,14 @@ class CartesianProductMetric(AbstractMetric): which is block diagonal in a product chart. """ - factors: tuple[AbstractMetric, ...] + factors: tuple[AbstractMetricField, ...] """Metrics for each factor manifold, in product order.""" def __post_init__(self) -> None: if len(self.factors) == 0: - raise ValueError("CartesianProductMetric requires at least one factor.") + raise ValueError("ProductMetric requires at least one factor.") @property def signature(self) -> tuple[int, ...]: """Concatenated factor signatures in product order.""" return tuple(s for m in self.factors for s in m.signature) - - def metric_matrix( - self, - chart: cxc.AbstractChart, - /, - *, - at: CDict, - usys: OptUSys = None, - ) -> QuantityMatrix: - """Return block-diagonal matrix from factor metrics in product chart.""" - if not isinstance(chart, cxc.AbstractCartesianProductChart): - msg = f"CartesianProductMetric requires a product chart, got {chart!r}." - raise TypeError(msg) - - if len(chart.factors) != len(self.factors): - msg = ( - "Product chart factor count does not match " - "product metric factor count: " - f"{len(chart.factors)} != {len(self.factors)}" - ) - raise ValueError(msg) - - parts = chart.split_components(at) - blocks = tuple( - _as_quantity_matrix(metric.metric_matrix(c, at=p, usys=usys)) - for metric, c, p in zip(self.factors, chart.factors, parts, strict=True) - ) - - n = sum(block.shape[0] for block in blocks) - dtype = jnp.promote_types(*(block.value.dtype for block in blocks)) - value = jnp.zeros((n, n), dtype=dtype) - units = [[u.unit("") for _ in range(n)] for _ in range(n)] - - offset = 0 - for block in blocks: - block_n = block.shape[0] - value = value.at[offset : offset + block_n, offset : offset + block_n].set( - block.value - ) - for i in range(block_n): - for j in range(block_n): - units[offset + i][offset + j] = block.unit[i, j] - offset += block_n - - unit_tup = tuple(tuple(row) for row in units) - return QuantityMatrix(value=value, unit=UnitsMatrix(unit_tup)) - - -def _as_quantity_matrix(x: QuantityMatrix | jax.Array) -> QuantityMatrix: - """Convert a numeric matrix into a dimensionless QuantityMatrix.""" - if isinstance(x, QuantityMatrix): - return x - - n = x.shape[0] - unit_tup = tuple(tuple(u.unit("") for _ in range(n)) for _ in range(n)) - return QuantityMatrix(value=x, unit=UnitsMatrix(unit_tup)) diff --git a/src/coordinax/_src/product/register_metric.py b/src/coordinax/_src/product/register_metric.py new file mode 100644 index 00000000..02a0ba85 --- /dev/null +++ b/src/coordinax/_src/product/register_metric.py @@ -0,0 +1,132 @@ +"""Register ``metric_matrix`` and ``metric_representation`` dispatch rules. + +Covers :class:`~coordinax.manifolds.CartesianProductManifold` paired with +:class:`~coordinax.charts.AbstractCartesianProductChart`. + +The product metric is block-diagonal: each block is the factor metric +evaluated at the corresponding component slice of the point, computed +by recursively calling the standalone ``metric_matrix`` dispatch API. + +""" + +__all__: tuple[str, ...] = () + +import jax.numpy as jnp +import plum + +import unxt as u + +from .chart import AbstractCartesianProductChart +from .manifold import CartesianProductManifold +from coordinax._src.metric.matrix import DenseMetric, DiagonalMetric +from coordinax.api.manifolds import metric_matrix, metric_representation +from coordinax.internal import QMatrix, UnitsMatrix + +# ===================================================================== +# Private helpers +# ===================================================================== + + +def _mm_to_qm(mm: DenseMetric | DiagonalMetric) -> QMatrix: + """Convert an AbstractMetricMatrix to a QMatrix.""" + if isinstance(mm, DiagonalMetric): + dense = mm.to_dense() + mat = dense.matrix + else: + mat = mm.matrix + if isinstance(mat, QMatrix): + return mat + n = mat.shape[0] + unit_tup = tuple(tuple(u.unit("") for _ in range(n)) for _ in range(n)) + return QMatrix(mat, unit=UnitsMatrix(unit_tup)) + + +# ===================================================================== +# metric_representation +# ===================================================================== + + +@plum.dispatch +def metric_representation( + M: CartesianProductManifold, chart: AbstractCartesianProductChart, / +) -> type[DenseMetric]: + """Product manifold in a product chart → :class:`DenseMetric`. + + The product metric is block-diagonal in general (not necessarily diagonal + even if each factor metric is diagonal), so :class:`DenseMetric` is the + conservative declaration. + + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_representation + >>> from coordinax._src.metric.matrix import DenseMetric + + >>> M = cxm.CartesianProductManifold( + ... factors=(cxm.R2, cxm.R1), factor_names=("xy", "z") + ... ) + >>> chart = M.default_chart() + >>> metric_representation(M, chart) + + + """ + del M, chart + return DenseMetric + + +# ===================================================================== +# metric_matrix +# ===================================================================== + + +@plum.dispatch +def metric_matrix( + M: CartesianProductManifold, point: dict, chart: AbstractCartesianProductChart, / +) -> DenseMetric: + r"""Product metric (block-diagonal) in a product chart. + + Assembles the block-diagonal matrix from factor metrics by recursively + calling the standalone ``metric_matrix`` dispatch API. + + >>> import jax.numpy as jnp + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DenseMetric + + Two-factor Euclidean product (R² x R¹): + + >>> M = cxm.CartesianProductManifold( + ... factors=(cxm.R2, cxm.R1), factor_names=("xy", "z") + ... ) + >>> chart = M.default_chart() + >>> at = {k: jnp.array(0.0) for k in chart.components} + >>> g = metric_matrix(M, at, chart) + >>> isinstance(g, DenseMetric) + True + >>> g.ndim + 3 + + """ + parts = chart.split_components(point) + factor_blocks = [ + _mm_to_qm(metric_matrix(fm, fp, fc)) + for fm, fc, fp in zip(M.factors, chart.factors, parts, strict=True) + ] + + n = sum(block.shape[0] for block in factor_blocks) + dtype = jnp.result_type(*(block.value.dtype for block in factor_blocks)) + value = jnp.zeros((n, n), dtype=dtype) + units = [[u.unit("") for _ in range(n)] for _ in range(n)] + + offset = 0 + for block in factor_blocks: + block_n = block.shape[0] + value = value.at[offset : offset + block_n, offset : offset + block_n].set( + block.value + ) + for i in range(block_n): + for j in range(block_n): + units[offset + i][offset + j] = block.unit[i, j] + offset += block_n + + unit_tup = tuple(tuple(row) for row in units) + G = QMatrix(value=value, unit=UnitsMatrix(unit_tup)) + return DenseMetric(G) diff --git a/src/coordinax/_src/spherical/__init__.py b/src/coordinax/_src/spherical/__init__.py index b890d1d4..00effc49 100644 --- a/src/coordinax/_src/spherical/__init__.py +++ b/src/coordinax/_src/spherical/__init__.py @@ -7,4 +7,6 @@ from .manifold import * from .metric import * from .register_charts import * +from .register_metric import * from .register_ptmap import * +from .scale_factors import * diff --git a/src/coordinax/_src/spherical/embed.py b/src/coordinax/_src/spherical/embed.py index d705328e..3101dbac 100644 --- a/src/coordinax/_src/spherical/embed.py +++ b/src/coordinax/_src/spherical/embed.py @@ -112,10 +112,10 @@ def embedded_twosphere( radius: float | u.AbstractQuantity, ambient: cxc.AbstractChart[Any, Any] = cxc.sph3d, ) -> EmbeddedManifold: - """Create an {class}`coordinax.manifolds.EmbeddedManifold` for the two-sphere. + """Create an `coordinax.manifolds.EmbeddedManifold` for the two-sphere. This is a convenience helper that constructs an - {class}`coordinax.manifolds.EmbeddedManifold` with + `coordinax.manifolds.EmbeddedManifold` with ``intrinsic=HyperSphericalManifold()`` and ``embedding=TwoSphereIn3D(radius, ambient)``. @@ -125,13 +125,7 @@ def embedded_twosphere( Sphere radius. ambient Ambient chart for the embedding. Defaults to - `{class}`~coordinax.charts.Spherical3D`. - - Returns - ------- - EmbeddedManifold - An embedded manifold pairing the two-sphere manifold with the - {class}`coordinax.manifolds.TwoSphereIn3D` embedding. + `coordinax.charts.Spherical3D`. Examples -------- diff --git a/src/coordinax/_src/spherical/manifold.py b/src/coordinax/_src/spherical/manifold.py index f18157d6..8a439747 100644 --- a/src/coordinax/_src/spherical/manifold.py +++ b/src/coordinax/_src/spherical/manifold.py @@ -12,7 +12,7 @@ import dataclassish from .atlas import HyperSphericalAtlas -from .metric import HyperSphericalMetric +from .metric import RoundMetric from coordinax._src.base import AbstractManifold from coordinax._src.internal import pos_named_objs @@ -28,7 +28,7 @@ class HyperSphericalManifold(AbstractManifold): >>> import coordinax.manifolds as cxm >>> import coordinax.charts as cxc - >>> S2 = cxm.HyperSphericalManifold() + >>> S2 = cxm.HyperSphericalManifold(2) >>> S2.ndim 2 @@ -46,7 +46,7 @@ class HyperSphericalManifold(AbstractManifold): def __init__(self, ndim: int = 2, /) -> None: object.__setattr__(self, "ndim", ndim) object.__setattr__(self, "atlas", HyperSphericalAtlas(self.ndim)) - object.__setattr__(self, "metric", HyperSphericalMetric(self.ndim)) + object.__setattr__(self, "metric", RoundMetric(self.ndim)) def __pdoc__(self, *, alias: bool = True, **kw: Any) -> wl.AbstractDoc: """Return the string representation. @@ -55,7 +55,7 @@ def __pdoc__(self, *, alias: bool = True, **kw: Any) -> wl.AbstractDoc: -------- >>> import wadler_lindig as wl >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(3) + >>> M = cxm.R3 >>> wl.pprint(M) Rn(3) diff --git a/src/coordinax/_src/spherical/metric.py b/src/coordinax/_src/spherical/metric.py index 048c7414..21e9c21f 100644 --- a/src/coordinax/_src/spherical/metric.py +++ b/src/coordinax/_src/spherical/metric.py @@ -1,28 +1,20 @@ """Two-sphere manifold.""" -__all__ = ("HyperSphericalMetric",) +__all__ = ("RoundMetric",) import dataclasses from typing import final import jax -import jax.numpy as jnp -import numpy as np -import unxt as u -from unxt.quantity import AllowValue - -import coordinax.charts as cxc -from coordinax._src.base import AbstractDiagonalMetric -from coordinax._src.custom_types import CDict, OptUSys -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax._src.base import AbstractDiagonalMetricField @jax.tree_util.register_static @final @dataclasses.dataclass(frozen=True, slots=True) -class HyperSphericalMetric(AbstractDiagonalMetric): +class RoundMetric(AbstractDiagonalMetricField): r"""Round metric on the unit $n$-sphere $S^{n-1}$ in standard spherical coordinates. The round metric on $S^2$ in the $(\theta, \phi)$ spherical chart is @@ -37,22 +29,22 @@ class HyperSphericalMetric(AbstractDiagonalMetric): Examples -------- >>> import jax.numpy as jnp + >>> import coordinax.api.manifolds as cxmapi >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> m = cxm.HyperSphericalMetric(2) - >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} - >>> m.metric_matrix(cxc.sph2, at=at) - Array([[1., 0.], - [0., 1.]], dtype=float64) - - The signature is ``(1,) * ndim`` for this positive-definite metric: - + >>> m = cxm.RoundMetric(2) >>> m.signature (1, 1) >>> m.ndim 2 + The metric matrix is obtained via the dispatch API on the associated manifold: + + >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} + >>> cxmapi.metric_matrix(cxm.S2, at, cxc.sph2).diagonal + Array([1., 1.], dtype=float64) + """ ndim: int @@ -62,67 +54,3 @@ class HyperSphericalMetric(AbstractDiagonalMetric): def signature(self) -> tuple[int, ...]: """Metric signature: ``(1,) * ndim`` — the round sphere metric is Riemannian.""" return (1,) * self.ndim - - def metric_matrix( - self, chart: cxc.AbstractChart, /, *, at: CDict, usys: OptUSys = None - ) -> QuantityMatrix | jax.Array: - r"""Metric matrix $g = \operatorname{diag}(g_0, \ldots, g_{n-1})$ at ``at``. - - The diagonal entries follow the cumulative-sine rule: - - $$g_{kk} = \prod_{j=0}^{k-1} \sin^2(\theta_j)$$ - - so $g_{00} = 1$, $g_{11} = \sin^2\theta_0$, - $g_{22} = \sin^2\theta_0\,\sin^2\theta_1$, etc. - - Examples - -------- - >>> import jax.numpy as jnp - >>> import unxt as u - >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm - - Bare arrays (angles in radians) → plain ``jax.Array``: - - >>> M = cxm.HyperSphericalMetric(2) - >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} - >>> M.metric_matrix(cxc.sph2, at=at) - Array([[1., 0.], - [0., 1.]], dtype=float64) - - {class}`~unxt.Quantity` angles (radians) → - {class}`~coordinax.internal.QuantityMatrix`: - - >>> at = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} - >>> M.metric_matrix(cxc.sph2, at=at) - QuantityMatrix([[1., 0.], - [0., 1.]], '((, ), (, ))') - - {class}`~unxt.Quantity` angles in degrees are converted automatically: - - >>> at = {"theta": u.Angle(90.0, "deg"), "phi": u.Angle(0.0, "deg")} - >>> M.metric_matrix(cxc.sph2, at=at) - QuantityMatrix([[1., 0.], - [0., 1.]], '((, ), (, ))') - - """ - components = chart.components - is_qty = any(map(u.quantity.is_any_quantity, at.values())) - ang_unit = usys["angle"] if usys is not None else u.unit("rad") - - # Seed with a scalar of the right dtype/shape from the first component. - dtype = jnp.promote_types(*(v.dtype for v in at.values())) - cumprod = jnp.ones((), dtype=dtype) - diag_entries = [cumprod] - for k in components[:-1]: - ang = u.ustrip(AllowValue, u.uconvert_value(u.unit("rad"), ang_unit, at[k])) - cumprod = cumprod * jnp.sin(ang) ** 2 - diag_entries.append(cumprod) - G_arr = jnp.diag(jnp.stack(diag_entries)) - - if not is_qty: - return G_arr - - n = self.ndim - units = UnitsMatrix(np.full((n, n), u.unit(""))) - return QuantityMatrix(G_arr, unit=units) diff --git a/src/coordinax/_src/spherical/register_charts.py b/src/coordinax/_src/spherical/register_charts.py index 56d90067..4543e091 100644 --- a/src/coordinax/_src/spherical/register_charts.py +++ b/src/coordinax/_src/spherical/register_charts.py @@ -41,8 +41,7 @@ def pt_project( >>> import coordinax.manifolds as cxm >>> q = {"x": u.Q(1.0, "km"), "y": u.Q(0.0, "km"), "z": u.Q(0.0, "km")} - >>> M = cxm.HyperSphericalManifold() - >>> cxm.pt_project(q, cxc.cart3d, M) + >>> cxm.pt_project(q, cxc.cart3d, cxm.S2) {'theta': Q(1.57079633, 'rad'), 'phi': Q(0., 'rad')} """ diff --git a/src/coordinax/_src/spherical/register_metric.py b/src/coordinax/_src/spherical/register_metric.py new file mode 100644 index 00000000..a07e99d5 --- /dev/null +++ b/src/coordinax/_src/spherical/register_metric.py @@ -0,0 +1,94 @@ +"""Register ``metric_matrix`` and ``metric_representation`` dispatch rules. + +Covers :class:`~coordinax.manifolds.HyperSphericalManifold` (the unit +$n$-sphere $S^n$) paired with intrinsic angular charts that derive from +:class:`~coordinax._src.spherical.chart.AbstractSphericalHyperSphere`. + +The round metric on $S^n$ is diagonal in standard spherical charts, so all +rules return a :class:`~coordinax._src.metric.matrix.DiagonalMetric`. The +diagonal entries are computed directly via the ``_sine_product_diagonal`` +helper, avoiding a full-matrix allocation. + +""" + +__all__: tuple[str, ...] = () + +import jax.numpy as jnp +import plum + +import unxt as u +from unxt.quantity import AllowValue + +from .chart import AbstractSphericalHyperSphere +from .manifold import HyperSphericalManifold +from coordinax._src.metric.matrix import DiagonalMetric, _sine_product_diagonal +from coordinax.internal import CDict + +RAD = u.unit("rad") + + +@plum.dispatch +def metric_representation( + M: HyperSphericalManifold, chart: AbstractSphericalHyperSphere, / +) -> type[DiagonalMetric]: + """Return `DiagonalMetric` for a unit $n$-sphere in a standard angular chart. + + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + + >>> cxm.metric_representation(cxm.S2, cxc.sph2) + + + """ + del M, chart + return DiagonalMetric + + +@plum.dispatch +def metric_matrix( + M: HyperSphericalManifold, point: CDict, chart: AbstractSphericalHyperSphere, / +) -> DiagonalMetric: + r"""Round metric on the unit $n$-sphere in a standard angular chart. + + Computes diagonal entries directly via ``_sine_product_diagonal``: + + $$g_{kk} = \prod_{j=0}^{k-1} \sin^2(\theta_j)$$ + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + >>> from coordinax.api.manifolds import metric_matrix + >>> from coordinax._src.metric.matrix import DiagonalMetric + + $S^2$ at the equator $\theta = \pi/2$: + + >>> M = cxm.HyperSphericalManifold(2) + >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} + >>> g = metric_matrix(M, at, cxc.sph2) + >>> isinstance(g, DiagonalMetric) + True + >>> g.diagonal + Array([1., 1.], dtype=float64) + + $S^2$ at $\theta = \pi/6$: + + >>> at = {"theta": jnp.array(jnp.pi / 6), "phi": jnp.array(0.0)} + >>> g = metric_matrix(M, at, cxc.sph2) + >>> round(float(g.diagonal[1]), 10) # sin\u00b2(\u03c0/6) \u2248 0.25 + 0.25 + + """ + components = chart.components + # All angular components except the last (azimuthal) are polar angles + theta_keys = components[:-1] + if theta_keys: + thetas = jnp.stack( + [ + u.ustrip(AllowValue, u.uconvert_value(RAD, RAD, point[k])) + for k in theta_keys + ] + ) + else: + thetas = jnp.array([]) + diag = _sine_product_diagonal(thetas, 1.0) + return DiagonalMetric(diag) diff --git a/src/coordinax/_src/spherical/scale_factors.py b/src/coordinax/_src/spherical/scale_factors.py new file mode 100644 index 00000000..2cc60428 --- /dev/null +++ b/src/coordinax/_src/spherical/scale_factors.py @@ -0,0 +1,58 @@ +"""Spherical specializations for `coordinax.manifolds.scale_factors`.""" + +__all__: tuple[str, ...] = () + +import plum + +import quaxed.numpy as jnp +import unxt as u +from unxt.quantity import AllowValue + +from .metric import RoundMetric +from coordinax._src.base import AbstractChart +from coordinax._src.custom_types import CDict, OptUSys +from coordinax.internal import QMatrix, UnitsMatrix + + +@plum.dispatch +def scale_factors( + metric: RoundMetric, chart: AbstractChart, /, *, at: CDict, usys: OptUSys = None +) -> QMatrix: + r"""Return round-metric diagonal directly without forming the nxn matrix. + + Computes the cumulative-sine diagonal $g_{kk} = \prod_{j>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + + Bare angles (no units) → dimensionless QMatrix: + + >>> metric = cxm.RoundMetric(2) + >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} + >>> cxm.scale_factors(metric, cxc.sph2, at=at) + QMatrix([1., 1.], '(, )') + + Quantity angles → dimensionless QMatrix: + + >>> at = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} + >>> cxm.scale_factors(metric, cxc.sph2, at=at) + QMatrix([1., 1.], '(, )') + + """ + del metric + components = chart.components + ang_unit = usys["angle"] if usys is not None else u.unit("rad") + angles = jnp.stack( + [ + u.ustrip(AllowValue, u.uconvert_value(u.unit("rad"), ang_unit, at[k])) + for k in components[:-1] + ] + ) + sin2 = jnp.sin(angles) ** 2 + value = jnp.concatenate([jnp.ones(1, dtype=sin2.dtype), jnp.cumprod(sin2)]) + n = len(components) + units = UnitsMatrix(tuple(u.unit("") for _ in range(n))) + return QMatrix(value, unit=units) diff --git a/tests/unit/manifolds/test_metric_matrix_dispatch.py b/tests/unit/manifolds/test_metric_matrix_dispatch.py new file mode 100644 index 00000000..27b50d3f --- /dev/null +++ b/tests/unit/manifolds/test_metric_matrix_dispatch.py @@ -0,0 +1,333 @@ +"""Tests for metric_matrix() and metric_representation() dispatch rules. + +Coverage: +- All registered (manifold, chart) pairs return the type promised by + ``metric_representation(manifold, chart)`` +- Numerical values at sample points for the constant-metric cases +- JIT compatibility for each dispatch path +""" + +import jax +import jax.numpy as jnp +import pytest + +import unxt as u + +import coordinax.api.manifolds as cxmapi +import coordinax.charts as cxc +import coordinax.manifolds as cxm +from coordinax._src.metric.matrix import DiagonalMetric + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_product_manifold(): + """Return a CartesianProductManifold (S² x R¹) and its default chart.""" + M_prod = cxm.CartesianProductManifold( + factors=(cxm.S2, cxm.R1), factor_names=("sphere", "line") + ) + chart = M_prod.atlas.default_chart() + return M_prod, chart + + +# --------------------------------------------------------------------------- +# Fixtures: (manifold, point, chart) triples +# --------------------------------------------------------------------------- + + +@pytest.fixture( + params=[ + pytest.param("euclidean_cart1d", id="euclidean-cart1d"), + pytest.param("euclidean_cart2d", id="euclidean-cart2d"), + pytest.param("euclidean_cart3d", id="euclidean-cart3d"), + pytest.param("euclidean_cartnd", id="euclidean-cartnd"), + pytest.param("euclidean_sph3d", id="euclidean-sph3d"), + pytest.param("hyperspherical_sph2", id="hyperspherical-sph2"), + pytest.param("minkowski_minkowskict", id="minkowski-minkowskict"), + pytest.param("product_default", id="product-default"), + ] +) +def manifold_point_chart(request): + """Return ``(manifold, point_dict, chart)`` for each registered pair.""" + key = request.param + if key == "euclidean_cart1d": + return (cxm.R1, {"x": jnp.array(1.0)}, cxc.cart1d) + if key == "euclidean_cart2d": + return (cxm.R2, {"x": jnp.array(1.0), "y": jnp.array(2.0)}, cxc.cart2d) + if key == "euclidean_cart3d": + return ( + cxm.R3, + {"x": jnp.array(1.0), "y": jnp.array(2.0), "z": jnp.array(3.0)}, + cxc.cart3d, + ) + if key == "euclidean_cartnd": + return (cxm.R3, {"q": jnp.array([1.0, 2.0, 3.0])}, cxc.cartnd) + if key == "euclidean_sph3d": + return ( + cxm.R3, + { + "r": u.Q(jnp.array(2.0), "m"), + "theta": u.Angle(jnp.pi / 3, "rad"), + "phi": u.Angle(jnp.array(0.4), "rad"), + }, + cxc.sph3d, + ) + if key == "hyperspherical_sph2": + return ( + cxm.S2, + {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)}, + cxc.sph2, + ) + if key == "minkowski_minkowskict": + return ( + cxm.MinkowskiManifold(), + { + "ct": jnp.array(0.0), + "x": jnp.array(1.0), + "y": jnp.array(0.0), + "z": jnp.array(0.0), + }, + cxc.minkowskict, + ) + if key == "product_default": + M_prod, chart = _make_product_manifold() + pt = {k: jnp.array(jnp.pi / 4) for k in chart.components} + return (M_prod, pt, chart) + msg = f"Unknown fixture key: {key!r}" + raise ValueError(msg) + + +# --------------------------------------------------------------------------- +# Contract tests — apply to every registered (manifold, chart) pair +# --------------------------------------------------------------------------- + + +class TestMetricMatrixDispatchContract: + """Each registered (manifold, chart) pair must satisfy these invariants.""" + + def test_returns_correct_type(self, manifold_point_chart): + """``metric_matrix`` must return an instance of ``metric_representation``.""" + manifold, point, chart = manifold_point_chart + expected_cls = cxmapi.metric_representation(manifold, chart) + result = cxmapi.metric_matrix(manifold, point, chart) + assert isinstance(result, expected_cls), ( + f"metric_matrix({type(manifold).__name__}, {type(chart).__name__}) " + f"returned {type(result).__name__!r}, expected {expected_cls.__name__!r}" + ) + + def test_metric_representation_returns_type(self, manifold_point_chart): + """``metric_representation`` must return a class, not an instance.""" + manifold, _point, chart = manifold_point_chart + cls = cxmapi.metric_representation(manifold, chart) + assert isinstance(cls, type), ( + f"metric_representation should return a class, got {type(cls).__name__!r}" + ) + + def test_has_ndim(self, manifold_point_chart): + """Returned matrix must have a positive integer ndim.""" + manifold, point, chart = manifold_point_chart + result = cxmapi.metric_matrix(manifold, point, chart) + assert hasattr(result, "ndim") + assert isinstance(result.ndim, int) + assert result.ndim > 0 + + +# --------------------------------------------------------------------------- +# Numerical value tests +# --------------------------------------------------------------------------- + + +class TestMetricMatrixNumericalValues: + """Spot-check numerical values for constant-metric cases.""" + + def test_euclidean_cart1d_is_identity(self): + pt = {"x": jnp.array(0.0)} + g = cxmapi.metric_matrix(cxm.R1, pt, cxc.cart1d) + assert isinstance(g, cxm.DiagonalMetric) + assert jnp.allclose(g.diagonal, jnp.ones(1)) + + def test_euclidean_cart2d_is_identity(self): + pt = {"x": jnp.array(0.0), "y": jnp.array(0.0)} + g = cxmapi.metric_matrix(cxm.R2, pt, cxc.cart2d) + assert isinstance(g, cxm.DiagonalMetric) + assert jnp.allclose(g.diagonal, jnp.ones(2)) + + def test_euclidean_cart3d_is_identity(self): + pt = {"x": jnp.array(1.0), "y": jnp.array(2.0), "z": jnp.array(3.0)} + g = cxmapi.metric_matrix(cxm.R3, pt, cxc.cart3d) + assert isinstance(g, cxm.DiagonalMetric) + assert jnp.allclose(g.diagonal, jnp.ones(3)) + + def test_euclidean_cartnd_identity_by_dimension(self): + """CartND: diagonal length equals the actual array dimensionality.""" + pt = {"q": jnp.array([1.0, 2.0, 3.0])} + g = cxmapi.metric_matrix(cxm.R3, pt, cxc.cartnd) + assert isinstance(g, cxm.DiagonalMetric) + assert g.diagonal.shape == (3,) + assert jnp.allclose(g.diagonal, jnp.ones(3)) + + def test_minkowski_diagonal_signature(self): + """Minkowski metric in (ct, x, y, z) coords: diag = [-1, 1, 1, 1].""" + M = cxm.MinkowskiManifold() + pt = { + "ct": jnp.array(0.0), + "x": jnp.array(1.0), + "y": jnp.array(0.0), + "z": jnp.array(0.0), + } + g = cxmapi.metric_matrix(M, pt, cxc.minkowskict) + assert isinstance(g, DiagonalMetric) + assert jnp.allclose(g.diagonal, jnp.array([-1.0, 1.0, 1.0, 1.0])) + + def test_hyperspherical_at_equator(self): + """At theta=π/2: diag(S²) = (1, sin²(π/2)) = (1, 1).""" + pt = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} + g = cxmapi.metric_matrix(cxm.S2, pt, cxc.sph2) + assert isinstance(g, DiagonalMetric) + assert jnp.allclose(g.diagonal, jnp.array([1.0, 1.0]), atol=1e-6) + + def test_hyperspherical_at_off_equator(self): + """At theta=π/3: diag(S²) = (1, sin²(π/3)) = (1, 3/4).""" + pt = {"theta": jnp.array(jnp.pi / 3), "phi": jnp.array(0.0)} + g = cxmapi.metric_matrix(cxm.S2, pt, cxc.sph2) + assert isinstance(g, DiagonalMetric) + expected = jnp.array([1.0, jnp.sin(jnp.pi / 3) ** 2]) + assert jnp.allclose(g.diagonal, expected, atol=1e-6) + + def test_euclidean_sph3d_returns_diagonal_metric(self): + """EuclideanManifold + sph3d uses analytic formula, returning DiagonalMetric.""" + pt = { + "r": u.Q(jnp.array(2.0), "m"), + "theta": u.Angle(jnp.pi / 2, "rad"), + "phi": u.Angle(jnp.array(0.0), "rad"), + } + g = cxmapi.metric_matrix(cxm.R3, pt, cxc.sph3d) + assert isinstance(g, DiagonalMetric) + + def test_product_manifold_returns_dense_metric(self): + """CartesianProductManifold always returns DenseMetric.""" + M_prod, chart = _make_product_manifold() + pt = {k: jnp.array(jnp.pi / 4) for k in chart.components} + g = cxmapi.metric_matrix(M_prod, pt, chart) + assert isinstance(g, cxm.DenseMetric) + + +# --------------------------------------------------------------------------- +# DiagonalMetric consistency +# --------------------------------------------------------------------------- + + +class TestDiagonalMetricOffDiagonal: + """DiagonalMetric.to_dense() must have zero off-diagonal entries.""" + + def test_euclidean_cart3d_dense_is_identity(self): + pt = {"x": jnp.array(1.0), "y": jnp.array(2.0), "z": jnp.array(3.0)} + g = cxmapi.metric_matrix(cxm.R3, pt, cxc.cart3d) + assert isinstance(g, DiagonalMetric) + G = g.to_dense().matrix + off_diag = G - jnp.diag(jnp.diag(G)) + assert jnp.allclose(off_diag, jnp.zeros((3, 3))) + + def test_minkowski_dense_off_diagonal_is_zero(self): + M = cxm.MinkowskiManifold() + pt = { + "ct": jnp.array(0.0), + "x": jnp.array(0.0), + "y": jnp.array(0.0), + "z": jnp.array(0.0), + } + g = cxmapi.metric_matrix(M, pt, cxc.minkowskict) + assert isinstance(g, DiagonalMetric) + G = g.to_dense().matrix + off_diag = G - jnp.diag(jnp.diag(G)) + assert jnp.allclose(off_diag, jnp.zeros((4, 4))) + + +# --------------------------------------------------------------------------- +# JIT compatibility +# --------------------------------------------------------------------------- + + +class TestMetricMatrixJIT: + """metric_matrix must be JIT-compilable for all dispatch paths.""" + + def test_jit_euclidean_cart3d(self): + @jax.jit + def compute(x, y, z): + pt = {"x": x, "y": y, "z": z} + return cxmapi.metric_matrix(cxm.R3, pt, cxc.cart3d).diagonal + + result = compute(jnp.array(1.0), jnp.array(2.0), jnp.array(3.0)) + assert jnp.allclose(result, jnp.ones(3)) + + def test_jit_euclidean_cartnd(self): + @jax.jit + def compute(q): + return cxmapi.metric_matrix(cxm.R3, {"q": q}, cxc.cartnd).diagonal + + result = compute(jnp.array([1.0, 2.0, 3.0])) + assert jnp.allclose(result, jnp.ones(3)) + + def test_jit_minkowski(self): + M = cxm.MinkowskiManifold() + chart = cxc.minkowskict + + @jax.jit + def compute(ct, x, y, z): + pt = {"ct": ct, "x": x, "y": y, "z": z} + return cxmapi.metric_matrix(M, pt, chart).diagonal + + result = compute(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), jnp.array(0.0)) + assert jnp.allclose(result, jnp.array([-1.0, 1.0, 1.0, 1.0])) + + def test_jit_hyperspherical(self): + @jax.jit + def compute(theta, phi): + return cxmapi.metric_matrix( + cxm.S2, {"theta": theta, "phi": phi}, cxc.sph2 + ).diagonal + + result = compute(jnp.array(jnp.pi / 2), jnp.array(0.0)) + assert jnp.allclose(result, jnp.array([1.0, 1.0]), atol=1e-6) + + def test_jit_cart1d(self): + @jax.jit + def compute(x): + return cxmapi.metric_matrix(cxm.R1, {"x": x}, cxc.cart1d).diagonal + + result = compute(jnp.array(0.0)) + assert jnp.allclose(result, jnp.ones(1)) + + def test_jit_cart2d(self): + @jax.jit + def compute(x, y): + return cxmapi.metric_matrix(cxm.R2, {"x": x, "y": y}, cxc.cart2d).diagonal + + result = compute(jnp.array(0.0), jnp.array(0.0)) + assert jnp.allclose(result, jnp.ones(2)) + + +# --------------------------------------------------------------------------- +# metric_representation returns constant result (not point-dependent) +# --------------------------------------------------------------------------- + + +class TestMetricRepresentation: + """metric_representation must return the correct class for each pair.""" + + @pytest.mark.parametrize( + ("manifold", "chart", "expected_cls"), + [ + (cxm.R1, cxc.cart1d, cxm.DiagonalMetric), + (cxm.R2, cxc.cart2d, cxm.DiagonalMetric), + (cxm.R3, cxc.cart3d, cxm.DiagonalMetric), + (cxm.R3, cxc.cartnd, cxm.DiagonalMetric), + (cxm.R3, cxc.sph3d, cxm.DiagonalMetric), + (cxm.S2, cxc.sph2, cxm.DiagonalMetric), + (cxm.MinkowskiManifold(), cxc.minkowskict, cxm.DiagonalMetric), + ], + ) + def test_metric_representation_type(self, manifold, chart, expected_cls): + assert cxmapi.metric_representation(manifold, chart) is expected_cls diff --git a/tests/unit/manifolds/test_metric_pullback_consistency.py b/tests/unit/manifolds/test_metric_pullback_consistency.py new file mode 100644 index 00000000..8896e07b --- /dev/null +++ b/tests/unit/manifolds/test_metric_pullback_consistency.py @@ -0,0 +1,191 @@ +"""Pullback metric consistency: RoundMetric vs Jacobian pullback on S². + +For the unit two-sphere S², the round metric in (θ, φ) coordinates gives +g = diag(1, sin²θ). The same result must follow from the Jacobian pullback +of the flat metric on R³ via the standard Cartesian embedding: + + (θ, φ) → (sin(θ)cos(φ), sin(θ)sin(φ), cos(θ)) + +These tests assert that both paths agree numerically at sample points and +across a range of angles verified with Hypothesis. +""" + +import hypothesis.strategies as st +import jax +import jax.numpy as jnp +import pytest +from hypothesis import given, settings + +import unxt as u + +import coordinax.api.manifolds as cxmapi +import coordinax.charts as cxc +import coordinax.manifolds as cxm +from coordinax._src.metric.matrix import DenseMetric, DiagonalMetric + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def unit_sphere_embedded(): + """EmbeddedManifold for the unit two-sphere (dimensionless, radius=1).""" + return cxm.EmbeddedManifold( + intrinsic=cxm.S2, + ambient=cxm.R3, + embed_map=cxm.TwoSphereIn3D(radius=1.0), + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _round_dense_matrix(theta, phi): + """Expected round-metric as a dense 2×2 array at (theta, phi).""" + return jnp.array([[1.0, 0.0], [0.0, jnp.sin(theta) ** 2]]) + + +# --------------------------------------------------------------------------- +# Type contract tests +# --------------------------------------------------------------------------- + + +class TestPullbackConsistencyTypes: + """Verify the metric_matrix return types for both paths.""" + + def test_round_metric_returns_diagonal(self): + pt = {"theta": jnp.array(jnp.pi / 3), "phi": jnp.array(jnp.pi / 4)} + g = cxmapi.metric_matrix(cxm.S2, pt, cxc.sph2) + assert isinstance(g, DiagonalMetric) + + def test_embedded_metric_returns_dense(self, unit_sphere_embedded): + pt = {"theta": jnp.array(jnp.pi / 3), "phi": jnp.array(jnp.pi / 4)} + g = cxmapi.metric_matrix(unit_sphere_embedded, pt, cxc.sph2) + assert isinstance(g, DenseMetric) + + +# --------------------------------------------------------------------------- +# Numerical consistency tests +# --------------------------------------------------------------------------- + + +class TestPullbackConsistencyNumerical: + """Both paths give the same metric matrix at sample points.""" + + @pytest.mark.parametrize( + ("theta", "phi"), + [ + (jnp.pi / 2, 0.0), # equator, phi=0 + (jnp.pi / 3, jnp.pi / 4), # off-equator + (jnp.pi / 6, jnp.pi), # high latitude, phi=π + (0.1, 2.5), # near pole, arbitrary phi + ], + ids=["equator-0", "off-equator", "high-lat-pi", "near-pole"], + ) + def test_sample_point(self, unit_sphere_embedded, theta, phi): + pt = {"theta": jnp.array(theta), "phi": jnp.array(phi)} + + g_round = cxmapi.metric_matrix(cxm.S2, pt, cxc.sph2) + g_pullback = cxmapi.metric_matrix(unit_sphere_embedded, pt, cxc.sph2) + + # RoundMetric (diagonal) and Jacobian pullback (dense) must agree. + expected = g_round.to_dense().matrix # plain array, shape (2, 2) + actual = g_pullback.matrix.value # QMatrix.value, shape (2, 2) + + assert jnp.allclose(actual, expected, atol=1e-6), ( + f"Mismatch at theta={theta}, phi={phi}:\n" + f" expected={expected}\n actual={actual}" + ) + + @given( + theta=st.floats( + min_value=0.05, max_value=3.09, allow_nan=False, allow_infinity=False + ), + phi=st.floats( + min_value=0.0, max_value=6.28, allow_nan=False, allow_infinity=False + ), + ) + @settings(max_examples=30, deadline=None) + def test_hypothesis_s2(self, unit_sphere_embedded, theta, phi): + pt = {"theta": jnp.array(theta), "phi": jnp.array(phi)} + + g_round = cxmapi.metric_matrix(cxm.S2, pt, cxc.sph2) + g_pullback = cxmapi.metric_matrix(unit_sphere_embedded, pt, cxc.sph2) + + expected = g_round.to_dense().matrix + actual = g_pullback.matrix.value + + assert jnp.allclose(actual, expected, atol=1e-5), ( + f"Mismatch at theta={theta:.4f}, phi={phi:.4f}:\n" + f" expected={expected}\n actual={actual}" + ) + + +# --------------------------------------------------------------------------- +# JIT compatibility +# --------------------------------------------------------------------------- + + +class TestPullbackConsistencyJIT: + """Both metric paths are JIT-compatible.""" + + def test_round_metric_jit(self): + @jax.jit + def compute(theta, phi): + pt = {"theta": theta, "phi": phi} + return cxmapi.metric_matrix(cxm.S2, pt, cxc.sph2).diagonal + + result = compute(jnp.array(jnp.pi / 3), jnp.array(0.0)) + assert result.shape == (2,) + + def test_pullback_metric_jit(self): + M = cxm.EmbeddedManifold( + intrinsic=cxm.S2, + ambient=cxm.R3, + embed_map=cxm.TwoSphereIn3D(radius=1.0), + ) + + @jax.jit + def compute(theta, phi): + pt = {"theta": theta, "phi": phi} + return cxmapi.metric_matrix(M, pt, cxc.sph2).matrix.value + + result = compute(jnp.array(jnp.pi / 3), jnp.array(0.0)) + assert result.shape == (2, 2) + + +# --------------------------------------------------------------------------- +# Unit preservation for non-trivial radius +# --------------------------------------------------------------------------- + + +class TestPullbackMetricUnits: + """For a sphere with physical radius, the metric carries correct units.""" + + def test_radius_1km_at_equator(self): + M = cxm.EmbeddedManifold( + intrinsic=cxm.S2, + ambient=cxm.R3, + embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), + ) + at = {"theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(0.0, "rad")} + g = cxmapi.metric_matrix(M, at, cxc.sph2) + assert isinstance(g, DenseMetric) + # At the equator sin(π/2)=1, so metric should be identity × km²/rad² + assert jnp.allclose(g.matrix.value, jnp.eye(2), atol=1e-6) + assert str(g.matrix.unit[0, 0]) == "km2 / rad2" + + def test_radius_2m_metric_scaled(self): + M = cxm.EmbeddedManifold( + intrinsic=cxm.S2, + ambient=cxm.R3, + embed_map=cxm.TwoSphereIn3D(radius=u.Q(2.0, "m")), + ) + at = {"theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(0.0, "rad")} + g = cxmapi.metric_matrix(M, at, cxc.sph2) + # Metric = R² × I at equator → values should be [[4, 0], [0, 4]] + assert jnp.allclose(g.matrix.value, 4.0 * jnp.eye(2), atol=1e-6) + assert str(g.matrix.unit[0, 0]) == "m2 / rad2" diff --git a/tests/unit/manifolds/test_metrics.py b/tests/unit/manifolds/test_metrics.py index 55fa0006..57ec50a1 100644 --- a/tests/unit/manifolds/test_metrics.py +++ b/tests/unit/manifolds/test_metrics.py @@ -1,27 +1,34 @@ -"""Tests for coordinax.metrics — AbstractMetric and concrete implementations. +"""Tests for coordinax.metrics — AbstractMetricField and concrete implementations. All tests in this file are RED until the metrics module is implemented. """ -import hypothesis.strategies as st import jax import jax.numpy as jnp import pytest -from hypothesis import given, settings import unxt as u import coordinax.charts as cxc -import coordinax.internal as cxi import coordinax.manifolds as cxm +from coordinax._src.metric.matrix import DenseMetric, DiagonalMetric +from coordinax.api.manifolds import metric_matrix as mm_dispatch +from coordinax.internal import QMatrix + + +def _mat_val(dense_metric, /): + """Extract numeric array from DenseMetric, regardless of matrix type.""" + mat = dense_metric.matrix + return mat.value if isinstance(mat, QMatrix) else mat + # ============================================================================= -# AbstractMetric contract +# AbstractMetricField contract # ============================================================================= -class TestAbstractMetricContract: - """Every AbstractMetric subclass must satisfy these invariants.""" +class TestAbstractMetricFieldContract: + """Every AbstractMetricField subclass must satisfy these invariants.""" @pytest.fixture( params=[ @@ -35,13 +42,13 @@ class TestAbstractMetricContract: ) def metric(self, request): metrics = { - "R3": cxm.EuclideanMetric(3), - "R2": cxm.EuclideanMetric(2), - "R1": cxm.EuclideanMetric(1), - "hyperspherical2d": cxm.HyperSphericalMetric(ndim=2), + "R3": cxm.FlatMetric(3), + "R2": cxm.FlatMetric(2), + "R1": cxm.FlatMetric(1), + "hyperspherical2d": cxm.RoundMetric(ndim=2), "minkowski4d": cxm.MinkowskiMetric(), - "product3d": cxm.CartesianProductMetric( - factors=(cxm.HyperSphericalMetric(2), cxm.EuclideanMetric(1)) + "product3d": cxm.ProductMetric( + factors=(cxm.RoundMetric(2), cxm.FlatMetric(1)) ), } return metrics[request.param] @@ -66,15 +73,15 @@ def test_signature_entries_are_pm_one(self, metric): # ============================================================================= -# AbstractDiagonalMetric contract +# AbstractDiagonalMetricField contract # ============================================================================= -class TestAbstractDiagonalMetricContract: - """Every AbstractDiagonalMetric subclass must report is_diagonal=True. +class TestAbstractDiagonalMetricFieldContract: + """Every AbstractDiagonalMetricField subclass must report is_diagonal=True. - This class tests the structural promise added by AbstractDiagonalMetric beyond - what AbstractMetric already guarantees: is_diagonal() must return True for all + This class tests the structural promise added by AbstractDiagonalMetricField beyond + what AbstractMetricField already guarantees: is_diagonal() must return True for all valid (metric, chart, base-point) combinations, regardless of position or usys. """ @@ -90,12 +97,12 @@ class TestAbstractDiagonalMetricContract: def metric_chart_at(self, request): cases = { "R3_cart": ( - cxm.EuclideanMetric(3), + cxm.FlatMetric(3), cxc.cart3d, {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")}, ), "R3_sph": ( - cxm.EuclideanMetric(3), + cxm.FlatMetric(3), cxc.sph3d, { "r": u.Q(2.0, "m"), @@ -104,12 +111,12 @@ def metric_chart_at(self, request): }, ), "R2_cart": ( - cxm.EuclideanMetric(2), + cxm.FlatMetric(2), cxc.cart2d, {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m")}, ), "hyperspherical2d": ( - cxm.HyperSphericalMetric(ndim=2), + cxm.RoundMetric(ndim=2), cxc.sph2, {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")}, ), @@ -128,196 +135,174 @@ def metric_chart_at(self, request): def test_is_instance_of_abstractdiagonalmetric(self, metric_chart_at): metric, _, _ = metric_chart_at - assert isinstance(metric, cxm.AbstractDiagonalMetric) + assert isinstance(metric, cxm.AbstractDiagonalMetricField) def test_is_diagonal_returns_true(self, metric_chart_at): - metric, chart, at = metric_chart_at - result = metric.is_diagonal(chart, at=at) - assert bool(result) is True + """AbstractDiagonalMetricField subclasses are always diagonal by type.""" + metric, _, _ = metric_chart_at + assert isinstance(metric, cxm.AbstractDiagonalMetricField) def test_is_diagonal_output_shape(self, metric_chart_at): - metric, chart, at = metric_chart_at - result = metric.is_diagonal(chart, at=at) - assert result.shape == () - assert result.dtype == jnp.bool_ + """AbstractDiagonalMetricField structural check (no method call needed).""" + metric, _, _ = metric_chart_at + assert isinstance(metric, cxm.AbstractDiagonalMetricField) def test_is_diagonal_under_jit(self, metric_chart_at): - metric, chart, at = metric_chart_at - - @jax.jit - def compute(at): - return metric.is_diagonal(chart, at=at) - - result = compute(at) - assert bool(result) is True + """Isinstance check doesn't require jit — type is static.""" + metric, _, _ = metric_chart_at + assert isinstance(metric, cxm.AbstractDiagonalMetricField) def test_is_diagonal_ignores_usys(self, metric_chart_at): - """is_diagonal must return True regardless of usys argument.""" - metric, chart, at = metric_chart_at - result_no_usys = metric.is_diagonal(chart, at=at) - result_with_usys = metric.is_diagonal(chart, at=at, usys=u.unitsystems.si) - assert bool(result_no_usys) is True - assert bool(result_with_usys) is True + """Diagonal-ness is a structural property independent of usys.""" + metric, _, _ = metric_chart_at + assert isinstance(metric, cxm.AbstractDiagonalMetricField) # ============================================================================= -# EuclideanMetric +# FlatMetric # ============================================================================= -class TestEuclideanMetric: - """Tests for EuclideanMetric, the flat Riemannian metric on Euclidean space.""" +class TestFlatMetric: + """Tests for FlatMetric, the flat Riemannian metric on Euclidean space.""" def test_isinstance_abstractdiagonalmetric(self): - assert isinstance(cxm.EuclideanMetric(3), cxm.AbstractDiagonalMetric) + assert isinstance(cxm.FlatMetric(3), cxm.AbstractDiagonalMetricField) def test_construction_1d(self): - m = cxm.EuclideanMetric(1) + m = cxm.FlatMetric(1) assert m.ndim == 1 def test_construction_2d(self): - m = cxm.EuclideanMetric(2) + m = cxm.FlatMetric(2) assert m.ndim == 2 def test_construction_3d(self): - m = cxm.EuclideanMetric(3) + m = cxm.FlatMetric(3) assert m.ndim == 3 def test_metric_matrix_cart3d_is_identity(self): - m = cxm.EuclideanMetric(3) p = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")} - g = m.metric_matrix(cxc.cart3d, at=p) - assert g.shape == (3, 3) - assert jnp.allclose(g.value, jnp.eye(3)) + g = mm_dispatch(cxm.R3, p, cxc.cart3d) + assert isinstance(g, DiagonalMetric) + assert jnp.allclose(g.diagonal, jnp.ones(3)) def test_metric_matrix_cart2d_is_identity(self): - m = cxm.EuclideanMetric(2) p = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m")} - g = m.metric_matrix(cxc.cart2d, at=p) - assert g.shape == (2, 2) - assert jnp.allclose(g.value, jnp.eye(2)) + g = mm_dispatch(cxm.R2, p, cxc.cart2d) + assert isinstance(g, DiagonalMetric) + assert jnp.allclose(g.diagonal, jnp.ones(2)) def test_metric_matrix_sph3d_at_origin(self): """Spherical metric at (1, pi/2, 0): diag(1, r^2, r^2 sin^2 theta).""" - m = cxm.EuclideanMetric(3) p = { "r": u.Q(1.0, "m"), "theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad"), } - g = m.metric_matrix(cxc.sph3d, at=p) - assert g.shape == (3, 3) + g = mm_dispatch(cxm.R3, p, cxc.sph3d) + dense = g.to_dense() + assert dense.matrix.shape == (3, 3) # diagonal entries: g_rr=1, g_tt=r^2=1, g_pp=r^2 sin^2(theta)=1 expected_diag = jnp.array([1.0, 1.0, 1.0]) - assert jnp.allclose(jnp.diag(g.value), expected_diag, atol=1e-6) + assert jnp.allclose(jnp.diag(_mat_val(dense)), expected_diag, atol=1e-6) def test_metric_matrix_sph3d_diagonal(self): """Spherical metric is always diagonal.""" - m = cxm.EuclideanMetric(3) p = { "r": u.Q(2.0, "m"), "theta": u.Angle(jnp.pi / 3, "rad"), "phi": u.Angle(1.0, "rad"), } - g = m.metric_matrix(cxc.sph3d, at=p) - # Off-diagonal elements must be zero - offdiag = g.value - jnp.diag(jnp.diag(g.value)) - assert jnp.allclose(offdiag, jnp.zeros((3, 3)), atol=1e-6) + g = mm_dispatch(cxm.R3, p, cxc.sph3d) + assert isinstance(g, DiagonalMetric) def test_metric_matrix_jit(self): - m = cxm.EuclideanMetric(3) p = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")} @jax.jit def compute(p): - return m.metric_matrix(cxc.cart3d, at=p) + return mm_dispatch(cxm.R3, p, cxc.cart3d) g = compute(p) - assert g.shape == (3, 3) + assert isinstance(g, DiagonalMetric) def test_carried_by_euclidean_manifold(self): - """R3.metric should return an EuclideanMetric.""" + """R3.metric should return an FlatMetric.""" metric = cxm.R3.metric - assert isinstance(metric, cxm.EuclideanMetric) + assert isinstance(metric, cxm.FlatMetric) assert metric.ndim == 3 # ============================================================================= -# HyperSphericalMetric +# RoundMetric # ============================================================================= -class TestHyperSphericalMetric: - """Tests for HyperSphericalMetric, the round metric on the unit sphere.""" +class TestRoundMetric: + """Tests for RoundMetric, the round metric on the unit sphere.""" def test_isinstance_abstractdiagonalmetric(self): - assert isinstance(cxm.HyperSphericalMetric(ndim=2), cxm.AbstractDiagonalMetric) + assert isinstance(cxm.RoundMetric(ndim=2), cxm.AbstractDiagonalMetricField) def test_construction(self): - m = cxm.HyperSphericalMetric(ndim=2) + m = cxm.RoundMetric(ndim=2) assert m.ndim == 2 def test_metric_matrix_at_equator(self): """S^2 metric at equator: diag(1, sin^2(theta)) = diag(1, 1).""" - m = cxm.HyperSphericalMetric(ndim=2) p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} - g = m.metric_matrix(cxc.sph2, at=p) - assert g.shape == (2, 2) - expected = jnp.diag(jnp.array([1.0, 1.0])) - assert jnp.allclose(g.value, expected, atol=1e-6) + g = mm_dispatch(cxm.S2, p, cxc.sph2) + assert isinstance(g, DiagonalMetric) + expected = jnp.array([1.0, 1.0]) + assert jnp.allclose(g.diagonal, expected, atol=1e-6) def test_metric_matrix_at_pole_theta_component(self): """S^2 metric g_theta_theta = 1 everywhere.""" - m = cxm.HyperSphericalMetric(ndim=2) p = {"theta": u.Angle(0.1, "rad"), "phi": u.Angle(0.0, "rad")} - g = m.metric_matrix(cxc.sph2, at=p) - assert jnp.allclose(g.value[0, 0], 1.0, atol=1e-6) + g = mm_dispatch(cxm.S2, p, cxc.sph2) + assert jnp.allclose(g.diagonal[0], 1.0, atol=1e-6) - def test_metric_matrix_phi_component_at_various_latitudes(self): + @pytest.mark.parametrize("theta", [0.1, jnp.pi / 4, jnp.pi / 2, jnp.pi * 3 / 4]) + def test_metric_matrix_phi_component_at_various_latitudes(self, theta): """g_phi_phi = sin^2(theta).""" - m = cxm.HyperSphericalMetric(ndim=2) - for theta_val in [0.1, jnp.pi / 4, jnp.pi / 2, jnp.pi * 3 / 4]: - p = {"theta": u.Angle(theta_val, "rad"), "phi": u.Angle(0.0, "rad")} - g = m.metric_matrix(cxc.sph2, at=p) - expected_g11 = jnp.sin(theta_val) ** 2 - assert jnp.allclose(g.value[1, 1], expected_g11, atol=1e-6), ( - f"theta={theta_val}: g_phi_phi={g[1, 1]} != sin^2(theta)={expected_g11}" - ) + p = {"theta": u.Angle(theta, "rad"), "phi": u.Angle(0.0, "rad")} + g = mm_dispatch(cxm.S2, p, cxc.sph2) + exp = jnp.sin(theta) ** 2 + assert jnp.allclose(g.diagonal[1], exp, atol=1e-6), ( + f"theta={theta}: g_phi_phi={g.diagonal[1]} != sin^2(theta)={exp}" + ) def test_metric_matrix_is_diagonal(self): """S^2 metric matrix is always diagonal.""" - m = cxm.HyperSphericalMetric(ndim=2) p = {"theta": u.Angle(jnp.pi / 3, "rad"), "phi": u.Angle(1.0, "rad")} - g = m.metric_matrix(cxc.sph2, at=p) - offdiag = g.value - jnp.diag(jnp.diag(g.value)) - assert jnp.allclose(offdiag, jnp.zeros((2, 2)), atol=1e-6) + g = mm_dispatch(cxm.S2, p, cxc.sph2) + assert isinstance(g, DiagonalMetric) def test_metric_matrix_jit(self): - m = cxm.HyperSphericalMetric(ndim=2) p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} @jax.jit def compute(p): - return m.metric_matrix(cxc.sph2, at=p) + return mm_dispatch(cxm.S2, p, cxc.sph2) g = compute(p) - assert g.shape == (2, 2) + assert isinstance(g, DiagonalMetric) def test_metric_matrix_vmap(self): - m = cxm.HyperSphericalMetric(ndim=2) thetas = jnp.linspace(0.1, jnp.pi - 0.1, 5) def single(theta_val): p = {"theta": u.Angle(theta_val, "rad"), "phi": u.Angle(0.0, "rad")} - return m.metric_matrix(cxc.sph2, at=p) + return mm_dispatch(cxm.S2, p, cxc.sph2) gs = jax.vmap(single)(thetas) - assert gs.shape == (5, 2, 2) + assert gs.diagonal.shape == (5, 2) def test_carried_by_hyperspherical_manifold(self): - """S2.metric should return a HyperSphericalMetric.""" + """S2.metric should return a RoundMetric.""" metric = cxm.S2.metric - assert isinstance(metric, cxm.HyperSphericalMetric) + assert isinstance(metric, cxm.RoundMetric) assert metric.ndim == 2 @@ -330,14 +315,13 @@ class TestMinkowskiMetric: """Tests for MinkowskiMetric, the flat Lorentzian metric on Minkowski spacetime.""" def test_isinstance_abstractdiagonalmetric(self): - assert isinstance(cxm.MinkowskiMetric(), cxm.AbstractDiagonalMetric) + assert isinstance(cxm.MinkowskiMetric(), cxm.AbstractDiagonalMetricField) def test_construction(self): M = cxm.MinkowskiMetric() assert M.ndim == 4 def test_metric_matrix_is_diagonal(self): - m = cxm.MinkowskiMetric() chart = cxc.MinkowskiCT() p = { "ct": u.Q(1.0, "m"), @@ -345,14 +329,13 @@ def test_metric_matrix_is_diagonal(self): "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m"), } - g = m.metric_matrix(chart, at=p) - assert g.shape == (4, 4) - expected = jnp.diag(jnp.array([-1.0, 1.0, 1.0, 1.0])) - assert jnp.allclose(g.value, expected, atol=1e-6) + g = mm_dispatch(cxm.MinkowskiManifold(), p, chart) + assert isinstance(g, DiagonalMetric) + expected = jnp.array([-1.0, 1.0, 1.0, 1.0]) + assert jnp.allclose(g.diagonal, expected, atol=1e-6) def test_metric_matrix_is_position_independent(self): """Minkowski metric is flat — does not depend on position.""" - m = cxm.MinkowskiMetric() chart = cxc.MinkowskiCT() p1 = { "ct": u.Q(1.0, "m"), @@ -366,12 +349,11 @@ def test_metric_matrix_is_position_independent(self): "y": u.Q(-30.0, "m"), "z": u.Q(20.0, "m"), } - g1 = m.metric_matrix(chart, at=p1) - g2 = m.metric_matrix(chart, at=p2) - assert jnp.allclose(g1.value, g2.value, atol=1e-6) + g1 = mm_dispatch(cxm.MinkowskiManifold(), p1, chart) + g2 = mm_dispatch(cxm.MinkowskiManifold(), p2, chart) + assert jnp.allclose(g1.diagonal, g2.diagonal, atol=1e-6) def test_metric_matrix_jit(self): - m = cxm.MinkowskiMetric() chart = cxc.MinkowskiCT() p = { "ct": u.Q(1.0, "m"), @@ -382,10 +364,10 @@ def test_metric_matrix_jit(self): @jax.jit def compute(p): - return m.metric_matrix(chart, at=p) + return mm_dispatch(cxm.MinkowskiManifold(), p, chart) g = compute(p) - assert g.shape == (4, 4) + assert isinstance(g, DiagonalMetric) def test_carried_by_minkowski_manifold(self): """minkowski4d.metric should return a MinkowskiMetric.""" @@ -395,30 +377,30 @@ def test_carried_by_minkowski_manifold(self): # ============================================================================= -# InducedMetric +# PullbackMetric # ============================================================================= -class TestInducedMetric: - """Tests for InducedMetric, the pullback metric on an embedded manifold.""" +class TestPullbackMetric: + """Tests for PullbackMetric, the pullback metric on an embedded manifold.""" def test_induced_metric_is_not_abstractdiagonalmetric(self): - """InducedMetric is NOT an AbstractDiagonalMetric.""" + """PullbackMetric is NOT an AbstractDiagonalMetricField.""" manifold = cxm.embedded_twosphere(radius=1.0) - assert not isinstance(manifold.metric, cxm.AbstractDiagonalMetric) + assert not isinstance(manifold.metric, cxm.AbstractDiagonalMetricField) def test_unit_sphere_at_equator(self): """Induced metric on S^2 embedded in R^3 at equator matches sphere metric.""" manifold = cxm.embedded_twosphere(radius=1.0) - metric = manifold.metric - assert isinstance(metric, cxm.InducedMetric) + assert isinstance(manifold.metric, cxm.PullbackMetric) p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} - g = metric.metric_matrix(cxc.sph2, at=p) - assert g.shape == (2, 2) + g = mm_dispatch(manifold, p, cxc.sph2) + assert isinstance(g, DenseMetric) + assert g.matrix.shape == (2, 2) # At equator: diag(1, sin^2(pi/2)) = diag(1, 1) expected = jnp.eye(2) - assert jnp.allclose(g.value, expected, atol=1e-4) + assert jnp.allclose(_mat_val(g), expected, atol=1e-4) def test_radius_2_sphere_at_equator(self): """Induced metric on radius-2 sphere. @@ -426,63 +408,60 @@ def test_radius_2_sphere_at_equator(self): diag(R^2, R^2 sin^2 theta) at equator = diag(4, 4). """ manifold = cxm.EmbeddedManifold( - intrinsic=cxm.HyperSphericalManifold(), - ambient=cxm.EuclideanManifold(3), + intrinsic=cxm.S2, + ambient=cxm.R3, embed_map=cxm.TwoSphereIn3D(radius=u.Q(2.0, "m")), ) - metric = manifold.metric - assert isinstance(metric, cxm.InducedMetric) + assert isinstance(manifold.metric, cxm.PullbackMetric) p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} - g = metric.metric_matrix(cxc.sph2, at=p) - assert g.shape == (2, 2) + g = mm_dispatch(manifold, p, cxc.sph2) + assert isinstance(g, DenseMetric) + assert g.matrix.shape == (2, 2) expected = jnp.diag(jnp.array([4.0, 4.0])) - assert jnp.allclose(g.value, expected, atol=1e-3) + assert jnp.allclose(_mat_val(g), expected, atol=1e-3) def test_induced_metric_jit(self): - """InducedMetric.metric_matrix should work under jit.""" + """metric_matrix dispatch should work under jit for EmbeddedManifold.""" manifold = cxm.embedded_twosphere(radius=1.0) - metric = manifold.metric @jax.jit def compute(p): - return metric.metric_matrix(cxc.sph2, at=p) + return mm_dispatch(manifold, p, cxc.sph2) p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} g = compute(p) - assert g.shape == (2, 2) + assert isinstance(g, DenseMetric) + assert g.matrix.shape == (2, 2) # ============================================================================= -# CartesianProductMetric +# ProductMetric # ============================================================================= -class TestCartesianProductMetric: - """Tests for CartesianProductMetric on product manifolds.""" +class TestProductMetric: + """Tests for ProductMetric on product manifolds.""" def test_product_metric_is_not_abstractdiagonalmetric(self): - """CartesianProductMetric is NOT an AbstractDiagonalMetric.""" + """ProductMetric is NOT an AbstractDiagonalMetricField.""" M = cxm.CartesianProductManifold( - factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(1)), - factor_names=("S2", "R1"), + factors=(cxm.S2, cxm.R1), factor_names=("S2", "R1") ) - assert not isinstance(M.metric, cxm.AbstractDiagonalMetric) + assert not isinstance(M.metric, cxm.AbstractDiagonalMetricField) def test_signature_concatenates_factor_signatures(self): M = cxm.CartesianProductManifold( - factors=(cxm.MinkowskiManifold(), cxm.EuclideanManifold(1)), - factor_names=("st", "x"), + factors=(cxm.MinkowskiManifold(), cxm.R1), factor_names=("st", "x") ) metric = M.metric - assert isinstance(metric, cxm.CartesianProductMetric) + assert isinstance(metric, cxm.ProductMetric) assert metric.signature == (-1, 1, 1, 1, 1) assert metric.ndim == 5 def test_metric_matrix_is_block_diagonal_in_product_chart(self): M = cxm.CartesianProductManifold( - factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(1)), - factor_names=("S2", "R1"), + factors=(cxm.S2, cxm.R1), factor_names=("S2", "R1") ) chart = cxc.CartesianProductChart((cxc.sph2, cxc.cart1d), ("S2", "R1")) p = { @@ -491,322 +470,13 @@ def test_metric_matrix_is_block_diagonal_in_product_chart(self): "R1.x": u.Q(1.0, "m"), } - g = M.metric.metric_matrix(chart, at=p) - assert g.shape == (3, 3) - assert jnp.allclose(g.value, jnp.eye(3), atol=1e-6) + g = mm_dispatch(M, p, chart) + assert isinstance(g, DenseMetric) + assert g.matrix.shape == (3, 3) + assert jnp.allclose(_mat_val(g), jnp.eye(3), atol=1e-6) def test_product_manifold_carries_product_metric(self): M = cxm.CartesianProductManifold( - factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(1)), - factor_names=("S2", "R1"), - ) - assert isinstance(M.metric, cxm.CartesianProductMetric) - - -# ============================================================================= -# AbstractMetric.cholesky -# ============================================================================= - - -class TestAbstractMetricCholesky: - """Tests for AbstractMetric.cholesky — Cholesky factorization L of the metric. - - The factorization satisfies g = L L^T where L is lower-triangular with - strictly positive diagonal entries. Tests are restricted to - positive-definite (Riemannian) metrics; indefinite metrics (Minkowski) are - excluded because jnp.linalg.cholesky requires positive-definiteness. - """ - - @pytest.fixture( - params=[ - "R3_cart", - "R3_sph", - "R2_cart", - "hyperspherical2d", - ] - ) - def metric_chart_at(self, request): - cases = { - "R3_cart": ( - cxm.EuclideanMetric(3), - cxc.cart3d, - {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")}, - ), - "R3_sph": ( - cxm.EuclideanMetric(3), - cxc.sph3d, - { - "r": u.Q(2.0, "m"), - "theta": u.Angle(jnp.pi / 3, "rad"), - "phi": u.Angle(1.0, "rad"), - }, - ), - "R2_cart": ( - cxm.EuclideanMetric(2), - cxc.cart2d, - {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m")}, - ), - "hyperspherical2d": ( - cxm.HyperSphericalMetric(ndim=2), - cxc.sph2, - {"theta": u.Angle(jnp.pi / 3, "rad"), "phi": u.Angle(0.5, "rad")}, - ), - } - return cases[request.param] - - def test_returns_quantitymatrix(self, metric_chart_at): - """cholesky() returns a QuantityMatrix for unitful metrics.""" - metric, chart, at = metric_chart_at - L = metric.cholesky(chart, at=at) - assert isinstance(L, cxi.QuantityMatrix) - - def test_shape(self, metric_chart_at): - """cholesky() result has shape (n, n).""" - metric, chart, at = metric_chart_at - L = metric.cholesky(chart, at=at) - n = metric.ndim - assert L.shape == (n, n) - - def test_lower_triangular(self, metric_chart_at): - """All entries strictly above the diagonal must be zero.""" - metric, chart, at = metric_chart_at - L = metric.cholesky(chart, at=at) - upper = jnp.triu(L.value, k=1) - assert jnp.allclose(upper, jnp.zeros_like(upper), atol=1e-6) - - def test_positive_diagonal(self, metric_chart_at): - """All diagonal entries of L must be strictly positive.""" - metric, chart, at = metric_chart_at - L = metric.cholesky(chart, at=at) - assert jnp.all(jnp.diag(L.value) > 0) - - def test_reconstruction(self, metric_chart_at): - """L @ L.T must equal the original metric matrix G (values and units).""" - metric, chart, at = metric_chart_at - L = metric.cholesky(chart, at=at) - G = metric.metric_matrix(chart, at=at) - assert isinstance(L, cxi.QuantityMatrix) - # Verify units: each L[i,j] must carry sqrt(G[i,j]) units - n = G.value.shape[0] - for i in range(n): - for j in range(n): - assert L.unit[i][j] ** 2 == G.unit[i][j] - # Verify numeric reconstruction: L @ L^T == G - assert jnp.allclose(L.value @ L.value.T, G.value, atol=1e-6) - - def test_cartesian_euclidean_is_identity(self): - """Cholesky of the identity metric (Cartesian Euclidean) is the identity.""" - metric = cxm.EuclideanMetric(3) - at = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} - L = metric.cholesky(cxc.cart3d, at=at) - assert jnp.allclose(L.value, jnp.eye(3), atol=1e-6) - - def test_diagonal_metric_cholesky_is_diagonal(self, metric_chart_at): - """For diagonal metrics, L is also diagonal (L_ii = sqrt(g_ii)).""" - metric, chart, at = metric_chart_at - G = metric.metric_matrix(chart, at=at) - if not metric.is_diagonal(chart, at=at): - pytest.skip("metric is not diagonal at this chart/point") - L = metric.cholesky(chart, at=at) - # Off-diagonal entries of L must be zero - offdiag = L.value - jnp.diag(jnp.diag(L.value)) - assert jnp.allclose(offdiag, jnp.zeros_like(offdiag), atol=1e-6) - # Diagonal entries must equal sqrt(g_ii) - expected_diag = jnp.sqrt(jnp.diag(G.value)) - assert jnp.allclose(jnp.diag(L.value), expected_diag, atol=1e-6) - - def test_jit(self, metric_chart_at): - """cholesky() must be compatible with jax.jit.""" - metric, chart, at = metric_chart_at - - @jax.jit - def compute(at): - return metric.cholesky(chart, at=at) - - L = compute(at) - assert L.shape == (metric.ndim, metric.ndim) - - -# ============================================================================= -# AbstractMetric.is_diagonal (base-class implementation) -# ============================================================================= - - -class TestAbstractMetricIsDiagonal: - """Tests for AbstractMetric.is_diagonal — the matrix-based base implementation. - - AbstractMetric.is_diagonal evaluates the metric matrix at a specific point - and checks whether all off-diagonal entries satisfy ``jnp.allclose(..., 0)`` - (i.e. within default floating-point tolerance, NOT exact equality). - - AbstractDiagonalMetric overrides this to return True unconditionally without - evaluating the matrix; that behaviour is covered by - TestAbstractDiagonalMetricContract. - - Note: the InducedMetric on S^2 in the sph2 chart has off-diagonal entries at - machine-epsilon level (~1e-17) everywhere, so is_diagonal always returns True - for that metric. - """ - - # ------------------------------------------------------------------ - # True cases: orthogonal metrics - # ------------------------------------------------------------------ - - def test_induced_metric_on_sphere_is_diagonal(self): - """InducedMetric on S^2 in sph2 is diagonal everywhere (off-diag ~1e-17).""" - metric = cxm.embedded_twosphere(radius=1.0).metric - assert not isinstance(metric, cxm.AbstractDiagonalMetric) - p = {"theta": u.Angle(jnp.pi / 3, "rad"), "phi": u.Angle(1.0, "rad")} - result = metric.is_diagonal(cxc.sph2, at=p) - assert bool(result) is True - - def test_product_metric_euclidean_cart_is_diagonal(self): - """Euclidean factors in Cartesian chart is diagonal.""" - M = cxm.CartesianProductManifold( - factors=(cxm.EuclideanManifold(2), cxm.EuclideanManifold(1)), - factor_names=("xy", "z"), - ) - chart = cxc.CartesianProductChart((cxc.cart2d, cxc.cart1d), ("xy", "z")) - p = {"xy.x": u.Q(1.0, "m"), "xy.y": u.Q(2.0, "m"), "z.x": u.Q(3.0, "m")} - result = M.metric.is_diagonal(chart, at=p) - assert bool(result) is True - - # ------------------------------------------------------------------ - # Output contract - # ------------------------------------------------------------------ - - def test_returns_scalar_bool_array(self): - """is_diagonal() must return a scalar Array with bool dtype.""" - metric = cxm.embedded_twosphere(radius=1.0).metric - p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} - result = metric.is_diagonal(cxc.sph2, at=p) - assert result.shape == () - assert result.dtype == jnp.bool_ - - def test_jit_compatible(self): - """is_diagonal() must work under jax.jit.""" - metric = cxm.embedded_twosphere(radius=1.0).metric - p = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} - - @jax.jit - def compute(p): - return metric.is_diagonal(cxc.sph2, at=p) - - result = compute(p) - assert result.shape == () - assert result.dtype == jnp.bool_ - - # ------------------------------------------------------------------ - # Consistency: is_diagonal agrees with allclose off-diagonal check - # ------------------------------------------------------------------ - - @pytest.mark.parametrize( - ("metric", "chart", "at"), - [ - # CartesianProductMetric of Euclidean spaces — exactly diagonal - ( - cxm.CartesianProductManifold( - factors=(cxm.EuclideanManifold(2), cxm.EuclideanManifold(1)), - factor_names=("xy", "z"), - ).metric, - cxc.CartesianProductChart((cxc.cart2d, cxc.cart1d), ("xy", "z")), - {"xy.x": u.Q(1.0, "m"), "xy.y": u.Q(2.0, "m"), "z.x": u.Q(3.0, "m")}, - ), - # InducedMetric at phi=0 — exactly diagonal - ( - cxm.embedded_twosphere(radius=1.0).metric, - cxc.sph2, - {"theta": u.Angle(jnp.pi / 4, "rad"), "phi": u.Angle(0.0, "rad")}, - ), - # InducedMetric at phi=1.0 — off-diag ~1e-17 (allclose → True) - ( - cxm.embedded_twosphere(radius=1.0).metric, - cxc.sph2, - {"theta": u.Angle(jnp.pi / 3, "rad"), "phi": u.Angle(1.0, "rad")}, - ), - # EuclideanMetric (AbstractDiagonalMetric) in Cartesian - ( - cxm.EuclideanMetric(3), - cxc.cart3d, - {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")}, - ), - ], - ids=["product_euclidean", "induced_phi0", "induced_phi1", "euclidean_cart"], - ) - def test_consistent_with_allclose_matrix_check(self, metric, chart, at): - """is_diagonal() must agree with jnp.allclose(offdiag, 0) on the matrix. - - The base implementation uses allclose (not exact equality). - """ - result = metric.is_diagonal(chart, at=at) - G = metric.metric_matrix(chart, at=at) - val = G.value if hasattr(G, "value") else G - offdiag = val - jnp.diag(jnp.diag(val)) - expected = jnp.allclose(offdiag, jnp.zeros_like(offdiag)) - assert bool(result) == bool(expected) - - # ------------------------------------------------------------------ - # Property tests - # ------------------------------------------------------------------ - - @settings(max_examples=40, deadline=None) - @given( - theta=st.floats(min_value=0.05, max_value=jnp.pi - 0.05, allow_nan=False), - radius=st.floats(min_value=0.1, max_value=10.0, allow_nan=False), - ) - def test_induced_sphere_at_phi_zero_always_diagonal(self, theta, radius): - """InducedMetric on any sphere in sph2 at phi=0 is always exactly diagonal. - - At phi=0 the Jacobian columns are orthogonal without any cancellation: - column_0 = [cos(θ), 0, -sin(θ)], column_1 = [0, sin(θ), 0]. - Their dot-product is identically 0 in IEEE arithmetic. - """ - manifold = cxm.EmbeddedManifold( - intrinsic=cxm.HyperSphericalManifold(), - ambient=cxm.EuclideanManifold(3), - embed_map=cxm.TwoSphereIn3D(radius=radius), + factors=(cxm.S2, cxm.R1), factor_names=("S2", "R1") ) - p = {"theta": u.Angle(theta, "rad"), "phi": u.Angle(0.0, "rad")} - result = manifold.metric.is_diagonal(cxc.sph2, at=p) - assert bool(result) is True - - @settings(max_examples=40, deadline=None) - @given( - theta=st.floats(min_value=0.05, max_value=jnp.pi - 0.05, allow_nan=False), - phi=st.floats(min_value=-jnp.pi, max_value=jnp.pi, allow_nan=False), - ) - def test_is_diagonal_agrees_with_allclose_for_induced(self, theta, phi): - """is_diagonal() must equal allclose(offdiag, 0) for InducedMetric at any point. - - The base implementation uses jnp.allclose (not exact equality). - """ - metric = cxm.embedded_twosphere(radius=1.0).metric - p = {"theta": u.Angle(theta, "rad"), "phi": u.Angle(phi, "rad")} - result = metric.is_diagonal(cxc.sph2, at=p) - G = metric.metric_matrix(cxc.sph2, at=p) - val = G.value - offdiag = val - jnp.diag(jnp.diag(val)) - expected = jnp.allclose(offdiag, jnp.zeros_like(offdiag)) - assert bool(result) == bool(expected) - - @settings(max_examples=40, deadline=None) - @given( - theta=st.floats(min_value=0.05, max_value=jnp.pi - 0.05, allow_nan=False), - phi=st.floats(min_value=-jnp.pi, max_value=jnp.pi, allow_nan=False), - ) - def test_abstractdiagonalmetric_matrix_is_nearly_diagonal(self, theta, phi): - """AbstractDiagonalMetric structural promise: near-zero off-diagonals. - - is_diagonal() returns True unconditionally (structural promise), and the - metric matrix confirms this: off-diagonal entries are within floating-point - tolerance of zero even though is_diagonal doesn't evaluate the matrix. - """ - metric = cxm.HyperSphericalMetric(ndim=2) - assert isinstance(metric, cxm.AbstractDiagonalMetric) - p = {"theta": u.Angle(theta, "rad"), "phi": u.Angle(phi, "rad")} - # Structural promise (overridden method — no matrix evaluated) - assert bool(metric.is_diagonal(cxc.sph2, at=p)) is True - # Matrix-level confirmation - G = metric.metric_matrix(cxc.sph2, at=p) - offdiag = G.value - jnp.diag(jnp.diag(G.value)) - assert jnp.allclose(offdiag, jnp.zeros_like(offdiag), atol=1e-6) + assert isinstance(M.metric, cxm.ProductMetric) diff --git a/tests/unit/manifolds/test_scale_factors_dispatch.py b/tests/unit/manifolds/test_scale_factors_dispatch.py index 4350b95e..737ac951 100644 --- a/tests/unit/manifolds/test_scale_factors_dispatch.py +++ b/tests/unit/manifolds/test_scale_factors_dispatch.py @@ -7,14 +7,16 @@ import coordinax.charts as cxc import coordinax.manifolds as cxm -from coordinax.internal import QuantityMatrix +from coordinax._src.metric.matrix import DiagonalMetric +from coordinax.api.manifolds import metric_matrix as mm_dispatch +from coordinax.internal import QMatrix class TestScaleFactorsEuclidean: """Tests for scale_factors on Euclidean metrics and manifolds.""" - def test_cartesian_metric_returns_1d_quantitymatrix(self): - metric = cxm.EuclideanMetric(3) + def test_cartesian_metric_returns_1d_QMatrix(self): + metric = cxm.FlatMetric(3) at = { "x": u.Q(jnp.array(1.0), "m"), "y": u.Q(jnp.array(2.0), "m"), @@ -23,14 +25,14 @@ def test_cartesian_metric_returns_1d_quantitymatrix(self): result = cxm.scale_factors(metric, cxc.cart3d, at=at) - assert isinstance(result, QuantityMatrix) + assert isinstance(result, QMatrix) assert result.shape == (3,) assert result.ndim == 1 assert jnp.allclose(result.value, jnp.array([1.0, 1.0, 1.0])) assert all(result.unit[i] == u.unit("") for i in range(3)) def test_spherical_metric_returns_metric_diagonal_entries(self): - metric = cxm.EuclideanMetric(3) + metric = cxm.FlatMetric(3) at = { "r": u.Q(jnp.array(2.0), "m"), "theta": u.Angle(jnp.pi / 6, "rad"), @@ -39,7 +41,7 @@ def test_spherical_metric_returns_metric_diagonal_entries(self): result = cxm.scale_factors(metric, cxc.sph3d, at=at) - assert isinstance(result, QuantityMatrix) + assert isinstance(result, QMatrix) assert result.shape == (3,) assert jnp.allclose(result.value, jnp.array([1.0, 4.0, 1.0]), atol=1e-6) assert result.unit[0] == u.unit("") @@ -50,35 +52,38 @@ def test_spherical_metric_returns_metric_diagonal_entries(self): class TestScaleFactorsGeneric: """Tests for generic metric-based scale_factors behavior.""" - def test_hyperspherical_bare_arrays_promote_to_quantitymatrix(self): - metric = cxm.HyperSphericalMetric(ndim=2) + def test_hyperspherical_bare_arrays_promote_to_QMatrix(self): + metric = cxm.RoundMetric(ndim=2) at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} result = cxm.scale_factors(metric, cxc.sph2, at=at) - assert isinstance(result, QuantityMatrix) + assert isinstance(result, QMatrix) assert result.shape == (2,) assert jnp.allclose(result.value, jnp.array([1.0, 1.0]), atol=1e-6) assert all(result.unit[i] == u.unit("") for i in range(2)) def test_generic_path_matches_metric_matrix_diag(self): - metric = cxm.HyperSphericalMetric(ndim=2) + metric = cxm.RoundMetric(ndim=2) at = { "theta": u.Angle(jnp.pi / 3, "rad"), "phi": u.Angle(jnp.array(0.1), "rad"), } - expected_metric = metric.metric_matrix(cxc.sph2, at=at) - assert isinstance(expected_metric, QuantityMatrix) - expected = expected_metric.diag() + # S2 in sph2 returns DiagonalMetric; diagonal IS the scale factors + expected_mm = mm_dispatch(cxm.HyperSphericalManifold(2), at, cxc.sph2) + assert isinstance(expected_mm, DiagonalMetric) + # Extract numeric diagonal values + diag = expected_mm.diagonal + expected_values = diag.value if isinstance(diag, QMatrix) else diag + result = cxm.scale_factors(metric, cxc.sph2, at=at) - assert isinstance(result, QuantityMatrix) - assert jnp.allclose(result.value, expected.value) - assert result.unit.to_string() == expected.unit.to_string() + assert isinstance(result, QMatrix) + assert jnp.allclose(result.value, expected_values, atol=1e-6) def test_jit(self): - metric = cxm.HyperSphericalMetric(ndim=2) + metric = cxm.RoundMetric(ndim=2) @jax.jit def compute(at): @@ -90,11 +95,11 @@ def compute(at): } result = compute(at) - assert isinstance(result, QuantityMatrix) + assert isinstance(result, QMatrix) assert jnp.allclose(result.value, jnp.array([1.0, 1.0]), atol=1e-6) def test_vmap_values(self): - metric = cxm.HyperSphericalMetric(ndim=2) + metric = cxm.RoundMetric(ndim=2) thetas = jnp.array([jnp.pi / 6, jnp.pi / 4, jnp.pi / 2]) def compute(theta): @@ -110,11 +115,11 @@ def compute(theta): def test_embedded_manifold_requires_induced_metric(self): M = cxm.EmbeddedManifold( - intrinsic=cxm.HyperSphericalManifold(), - ambient=cxm.EuclideanManifold(3), + intrinsic=cxm.S2, + ambient=cxm.R3, embed_map=cxm.TwoSphereIn3D(radius=u.Q(jnp.array(2.0), "m")), ) - assert isinstance(M.metric, cxm.InducedMetric) + assert isinstance(M.metric, cxm.PullbackMetric) at = { "theta": u.Angle(jnp.pi / 6, "rad"), @@ -133,7 +138,7 @@ def test_embedded_manifold_requires_induced_metric(self): # not just a coincidental [4, 4] value at the equator. expected = jnp.array([4.0, 1.0]) - assert isinstance(result, QuantityMatrix) + assert isinstance(result, QMatrix) assert result.shape == (2,) assert jnp.allclose(result.value, expected, atol=1e-6) assert result.unit[0] == u.unit("m2 / rad2") From 6644c78b108459c514520d6dc2f66edf603726ca Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:24:37 -0400 Subject: [PATCH 04/15] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(manifolds):?= =?UTF-8?q?=20update=20public=20API=20and=20re-exports=20for=20new=20metri?= =?UTF-8?q?c=20symbols?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The coordinax.api.manifolds dispatch API gains metric_matrix and metric_representation. The coordinax.manifolds re-export module is updated to expose the renamed metric classes (FlatMetric, RoundMetric, ProductMetric, PullbackMetric) and new matrix types (AbstractMetricMatrix, DiagonalMetric, DenseMetric). AbstractManifold is updated to use AbstractMetricField. The scale_factors dispatch file gains metric-aware overloads. Public internal.py and main.py re-exports are aligned with the new names. --- .../src/coordinax/api/manifolds.py | 90 +++++++++++- .../hypothesis/manifolds/_src/manifold.py | 18 ++- src/coordinax/_src/base/atlas.py | 4 +- src/coordinax/_src/base/manifold.py | 34 ++--- src/coordinax/_src/custom/manifold.py | 6 +- src/coordinax/_src/manifolds/angle_between.py | 30 ++-- src/coordinax/_src/manifolds/guess.py | 2 +- src/coordinax/_src/manifolds/scale_factors.py | 138 +++++++++++++++--- src/coordinax/internal.py | 14 +- src/coordinax/main.py | 4 +- src/coordinax/manifolds.py | 33 +++-- .../integration/manifolds/test_custom_jax.py | 8 +- .../manifolds/test_angle_between_dispatch.py | 8 +- tests/unit/manifolds/test_custom.py | 4 +- 14 files changed, 289 insertions(+), 104 deletions(-) diff --git a/packages/coordinax.api/src/coordinax/api/manifolds.py b/packages/coordinax.api/src/coordinax/api/manifolds.py index f747005b..257848c9 100644 --- a/packages/coordinax.api/src/coordinax/api/manifolds.py +++ b/packages/coordinax.api/src/coordinax/api/manifolds.py @@ -2,6 +2,8 @@ __all__ = ( "guess_manifold", + "metric_matrix", + "metric_representation", "scale_factors", "angle_between", "pt_embed", @@ -24,8 +26,6 @@ def guess_manifold(*args: Any, **kwargs: Any) -> "coordinax.manifolds.AbstractManifold": """Guess the manifold from arguments. - Examples - -------- >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm @@ -496,3 +496,89 @@ def pt_map(*args: Any, **kwargs: Any) -> CDict: """ del args, kwargs # Unused in abstract method raise NotImplementedError # pragma: no cover + + +@plum.dispatch.abstract +def metric_matrix(M: Any, point: Any, chart: Any, /) -> Any: + """Compute the coordinate metric matrix at ``point`` in ``chart``. + + Dispatches on the triple ``(type(M), type(point), type(chart))``. + Concrete implementations for each ``(M, chart)`` pair are registered + in the corresponding ``register_metric.py`` modules via + :func:`plum.dispatch`. + + Parameters + ---------- + M : AbstractManifold + The manifold carrying the metric field. + point : CDict + A component dictionary giving the coordinates in ``chart``. + chart : AbstractChart + The coordinate chart in which to express the metric. + + Returns + ------- + AbstractMetricMatrix + The metric matrix at ``point``. The concrete type + (`~coordinax._src.metric.matrix.DiagonalMetric` or + `~coordinax._src.metric.matrix.DenseMetric`) depends on the ``(M, + chart)`` pair and is declared by :func:`metric_representation`. + + Raises + ------ + NotImplementedError + When no specific dispatch rule is registered for the given types. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import coordinax.api.manifolds as cxmapi + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + + >>> at = {"x": jnp.array(1.0), "y": jnp.array(2.0), "z": jnp.array(3.0)} + >>> cxmapi.metric_matrix(cxm.R3, at, cxc.cart3d) + DiagonalMetric(diagonal=f64[3]) + + """ + del M, point, chart + raise NotImplementedError # pragma: no cover + + +@plum.dispatch.abstract +def metric_representation(M: Any, chart: Any, /) -> Any: + """Return the `AbstractMetricMatrix` subtype for ``(manifold, chart)``. + + A lightweight, allocation-free query that reports which concrete + `~coordinax._src.metric.matrix.AbstractMetricMatrix` subclass + :func:`metric_matrix` will return for a given ``(manifold, chart)`` pair, + without actually computing the metric values. + + Dispatches on ``(type(manifold), type(chart))``. Concrete rules are + registered in the relevant ``register_metric.py`` modules. The default + fallback returns `~coordinax._src.metric.matrix.DenseMetric`. + + Parameters + ---------- + M : AbstractManifold + The manifold. + chart : AbstractChart + The coordinate chart. + + Returns + ------- + type[AbstractMetricMatrix] + The metric matrix type guaranteed for this ``(manifold, chart)`` pair. + + Examples + -------- + >>> import coordinax.api.manifolds as cxmapi + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + + >>> cxmapi.metric_representation(cxm.R3, cxc.cart3d) is cxm.DiagonalMetric + True + + """ + del M, chart + raise NotImplementedError # pragma: no cover diff --git a/packages/coordinax.hypothesis/src/coordinax/hypothesis/manifolds/_src/manifold.py b/packages/coordinax.hypothesis/src/coordinax/hypothesis/manifolds/_src/manifold.py index a39ce3c9..919b9b96 100644 --- a/packages/coordinax.hypothesis/src/coordinax/hypothesis/manifolds/_src/manifold.py +++ b/packages/coordinax.hypothesis/src/coordinax/hypothesis/manifolds/_src/manifold.py @@ -281,7 +281,7 @@ def manifolds( # noqa: F811 @st.composite def manifolds( # noqa: F811 draw: st.DrawFn, - manifold_cls: type[cxm.EuclideanManifold], + M_cls: type[cxm.EuclideanManifold], /, *, filter: type | tuple[type, ...] | st.SearchStrategy = (), @@ -305,7 +305,7 @@ def manifolds( # noqa: F811 if target_ndim is None else target_ndim ) - return cxm.EuclideanManifold(dim) + return M_cls(dim) @plum.dispatch @@ -313,7 +313,7 @@ def manifolds( # noqa: F811 @st.composite def manifolds( # noqa: F811 draw: st.DrawFn, - manifold_cls: type[cxm.HyperSphericalManifold], + M_cls: type[cxm.HyperSphericalManifold], /, *, filter: type | tuple[type, ...] | st.SearchStrategy = (), @@ -332,10 +332,11 @@ def manifolds( # noqa: F811 >>> sphere = cxmst.manifolds(cxm.HyperSphericalManifold) """ + del M_cls target_ndim = draw_if_strategy(draw, ndim) if target_ndim is not None and target_ndim != 2: assume(False) - return cxm.HyperSphericalManifold() + return cxm.S2 @plum.dispatch @@ -343,7 +344,7 @@ def manifolds( # noqa: F811 @st.composite def manifolds( # noqa: F811 draw: st.DrawFn, - manifold_cls: type[cxm.EmbeddedManifold], + M_cls: type[cxm.EmbeddedManifold], /, *, filter: type | tuple[type, ...] | st.SearchStrategy = (), @@ -355,8 +356,8 @@ def manifolds( # noqa: F811 Currently this strategy generates an embedded two-sphere by constructing ``EmbeddedManifold`` directly with: - - ``intrinsic=HyperSphericalManifold()`` - - ``ambient=EuclideanManifold(3)`` + - ``intrinsic=S2`` + - ``ambient=R3`` - ``embed_map=TwoSphereIn3D(radius=...)`` Examples with ``ndim != 2`` are discarded via ``hypothesis.assume``. @@ -369,6 +370,7 @@ def manifolds( # noqa: F811 >>> embedded = cxmst.manifolds(cxm.EmbeddedManifold) """ + del M_cls target_ndim = draw_if_strategy(draw, ndim) if target_ndim is not None and target_ndim != 2: assume(False) @@ -419,7 +421,7 @@ def manifolds( # noqa: F811 required_chart_classes=required_chart_classes, ) ) - metric = cxm.EuclideanMetric(atlas.ndim) + metric = cxm.FlatMetric(atlas.ndim) return cxm.CustomManifold(atlas=atlas, metric=metric) diff --git a/src/coordinax/_src/base/atlas.py b/src/coordinax/_src/base/atlas.py index ab42ee5f..d86121e6 100644 --- a/src/coordinax/_src/base/atlas.py +++ b/src/coordinax/_src/base/atlas.py @@ -125,11 +125,11 @@ def default_chart_for( checks that the manifold's atlas matches this atlas. >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(2) + >>> M = cxm.R2 >>> M.atlas.default_chart_for(M) Cart2D(M=Rn(2)) - >>> try: M.atlas.default_chart_for(cxm.EuclideanManifold(3)) + >>> try: M.atlas.default_chart_for(cxm.R3) ... except ValueError as e: print(e) Atlas EuclideanAtlas(ndim=2) does not match manifold atlas EuclideanAtlas(ndim=3). diff --git a/src/coordinax/_src/base/manifold.py b/src/coordinax/_src/base/manifold.py index 617a63d6..e94ac5b9 100644 --- a/src/coordinax/_src/base/manifold.py +++ b/src/coordinax/_src/base/manifold.py @@ -14,9 +14,8 @@ import coordinax.angles as cxa import coordinax.api.manifolds as cxmapi from .atlas import AbstractAtlas -from .metric import AbstractMetric +from .metric import AbstractMetricField from coordinax._src.custom_types import CDict, OptUSys -from coordinax.internal import QuantityMatrix if TYPE_CHECKING: import coordinax.charts # noqa: ICN001 @@ -81,7 +80,7 @@ class AbstractManifold(metaclass=abc.ABCMeta): The Euclidean 3-manifold $\mathbb{R}^3$ with its standard atlas: >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(3) + >>> M = cxm.Rn(3) >>> M Rn(3) @@ -131,7 +130,7 @@ class AbstractManifold(metaclass=abc.ABCMeta): The two-sphere $S^2$ is a 2-dimensional manifold that is *not* a subspace of any Euclidean atlas. Its atlas admits only angular charts: - >>> S2 = cxm.HyperSphericalManifold() + >>> S2 = cxm.HyperSphericalManifold(2) >>> S2.ndim 2 @@ -149,7 +148,7 @@ class AbstractManifold(metaclass=abc.ABCMeta): atlas: AbstractAtlas """Charts compatible with this manifold. This defines the smooth structure.""" - metric: AbstractMetric + metric: AbstractMetricField """The manifold's metric. This defines the geometric structure.""" def __post_init__(self) -> None: @@ -173,7 +172,7 @@ def ndim(self) -> int: therefore determines its dimension. >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(3) + >>> M = cxm.Rn(3) >>> M.ndim 3 @@ -186,7 +185,7 @@ def default_chart(self) -> "coordinax.charts.AbstractChart[Any, Any]": This is a convenience property that proxies to the atlas default chart. >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(2) + >>> M = cxm.Rn(2) >>> M.default_chart() Cart2D(M=Rn(2)) @@ -197,7 +196,7 @@ def has_chart(self, chart: "coordinax.charts.AbstractChart[Any, Any]", /) -> boo """Return whether ``chart`` belongs to this manifold atlas. >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(2) + >>> M = cxm.Rn(2) >>> M.has_chart(cxc.cart2d) True >>> M.has_chart(cxc.cart3d) @@ -210,7 +209,7 @@ def check_chart(self, chart: "coordinax.charts.AbstractChart[Any, Any]", /) -> N """Check that ``chart`` belongs to this manifold atlas. >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(2) + >>> M = cxm.Rn(2) >>> M.check_chart(cxc.cart2d) # does not raise """ @@ -220,21 +219,6 @@ def check_chart(self, chart: "coordinax.charts.AbstractChart[Any, Any]", /) -> N # ===================================================== - def scale_factors( - self, - chart: "coordinax.charts.AbstractChart[Any, Any]", - /, - *, - at: CDict, - usys: OptUSys = None, - ) -> QuantityMatrix: - r"""Return the diagonal entries of the manifold metric in ``chart`` at ``at``. - - This is a thin convenience wrapper over - ``cxmapi.scale_factors(self.metric, chart, at=at, usys=usys)``. - """ - return cxmapi.scale_factors(self.metric, chart, at=at, usys=usys) # ty: ignore[invalid-return-type] - def angle_between( self, chart: "coordinax.charts.AbstractChart[Any, Any]", @@ -259,7 +243,7 @@ def __pdoc__(self, **kw: Any) -> wl.AbstractDoc: >>> import wadler_lindig as wl >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(3) + >>> M = cxm.Rn(3) >>> wl.pprint(M) Rn(3) diff --git a/src/coordinax/_src/custom/manifold.py b/src/coordinax/_src/custom/manifold.py index ddd30c70..bfd1f026 100644 --- a/src/coordinax/_src/custom/manifold.py +++ b/src/coordinax/_src/custom/manifold.py @@ -9,7 +9,7 @@ import jax -from coordinax._src.base import AbstractAtlas, AbstractManifold, AbstractMetric +from coordinax._src.base import AbstractAtlas, AbstractManifold, AbstractMetricField @jax.tree_util.register_static @@ -31,7 +31,7 @@ class CustomManifold(AbstractManifold): ... charts=(cxc.Cart2D, cxc.Polar2D), ... chart_default=cxc.cart2d, ... ) - >>> M = cxm.CustomManifold(atlas=atlas, metric=cxm.EuclideanMetric(2)) + >>> M = cxm.CustomManifold(atlas=atlas, metric=cxm.FlatMetric(2)) >>> M.ndim 2 >>> M.default_chart() @@ -44,7 +44,7 @@ class CustomManifold(AbstractManifold): atlas: AbstractAtlas """Atlas defining chart compatibility for this manifold.""" - metric: AbstractMetric + metric: AbstractMetricField """Riemannian metric for this manifold, used for norm and distance computations.""" def __post_init__(self) -> None: diff --git a/src/coordinax/_src/manifolds/angle_between.py b/src/coordinax/_src/manifolds/angle_between.py index d814c2b2..6a434169 100644 --- a/src/coordinax/_src/manifolds/angle_between.py +++ b/src/coordinax/_src/manifolds/angle_between.py @@ -13,9 +13,10 @@ import coordinax.angles as cxa import coordinax.api.manifolds as cxmapi -from coordinax._src.base import AbstractChart, AbstractMetric +from coordinax._src.base import AbstractChart, AbstractMetricField from coordinax._src.custom_types import CDict, OptUSys -from coordinax.internal import QuantityMatrix, UnitsMatrix, pack_to_qmatrix +from coordinax._src.metric.matrix import DenseMetric +from coordinax.internal import QMatrix, UnitsMatrix, pack_to_qmatrix @plum.dispatch @@ -30,8 +31,6 @@ def angle_between( ) -> cxa.AbstractAngle: """Manifold-level dispatch: delegate to the attached metric. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc @@ -49,7 +48,7 @@ def angle_between( @plum.dispatch def angle_between( - metric: AbstractMetric, + metric: AbstractMetricField, chart: AbstractChart, uvec: CDict, vvec: CDict, @@ -61,17 +60,15 @@ def angle_between( """Return the metric angle between two tangent vectors. The input component dictionaries are interpreted as tangent-vector - components in the coordinate basis of ``chart``. The metric is evaluated - at the base point ``at``. + components in the coordinate basis of ``chart``. The metric is evaluated at + the base point ``at``. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> metric = cxm.EuclideanMetric(3) + >>> metric = cxm.FlatMetric(3) >>> at = { ... "r": u.Q(2.0, "m"), ... "theta": u.Angle(jnp.pi / 2, "rad"), @@ -96,7 +93,10 @@ def angle_between( chart.check_data(uvec, keys=True, values=False) chart.check_data(vvec, keys=True, values=False) - g = _as_quantity_matrix(metric.metric_matrix(chart, at=at, usys=usys)) + mm = cxmapi.metric_matrix(chart.M, at, chart) + g = _as_quantity_matrix( + mm.matrix if isinstance(mm, DenseMetric) else mm.to_dense().matrix # ty: ignore[unresolved-attribute] + ) u_qm = pack_to_qmatrix(uvec, keys=chart.components) v_qm = pack_to_qmatrix(vvec, keys=chart.components) @@ -111,14 +111,14 @@ def angle_between( return cxa.Angle(qnp.arccos(cosine_value), "rad") -def _as_quantity_matrix(x: QuantityMatrix | Array) -> QuantityMatrix: - """Convert a numeric matrix into a dimensionless QuantityMatrix.""" - if isinstance(x, QuantityMatrix): +def _as_quantity_matrix(x: QMatrix | Array) -> QMatrix: + """Convert a numeric matrix into a dimensionless QMatrix.""" + if isinstance(x, QMatrix): return x n_rows, n_cols = x.shape[-2:] units = UnitsMatrix(np.full((n_rows, n_cols), u.unit(""))) - return QuantityMatrix(value=x, unit=units) + return QMatrix(value=x, unit=units) def _check_nonzero_norm(*norms: u.AbstractQuantity) -> None: diff --git a/src/coordinax/_src/manifolds/guess.py b/src/coordinax/_src/manifolds/guess.py index 952a12cc..5dc9ccb4 100644 --- a/src/coordinax/_src/manifolds/guess.py +++ b/src/coordinax/_src/manifolds/guess.py @@ -36,7 +36,7 @@ def guess_manifold(obj: AbstractManifold, /) -> AbstractManifold: """Return the manifold of a manifold. >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(3) + >>> M = cxm.Rn(3) >>> cxm.guess_manifold(M) is M True diff --git a/src/coordinax/_src/manifolds/scale_factors.py b/src/coordinax/_src/manifolds/scale_factors.py index e1a8714f..e3a7af92 100644 --- a/src/coordinax/_src/manifolds/scale_factors.py +++ b/src/coordinax/_src/manifolds/scale_factors.py @@ -4,25 +4,37 @@ from jaxtyping import Array +import jax +import jax.numpy as jnp import numpy as np import plum +import quaxed.numpy as qnp import unxt as u +import coordinax.api.charts as cxcapi import coordinax.api.manifolds as cxmapi -from coordinax._src.base import AbstractChart, AbstractMetric +from coordinax._src.base import AbstractChart, AbstractMetricField from coordinax._src.custom_types import CDict, OptUSys -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax._src.embedded.metric import PullbackMetric +from coordinax._src.euclidean.scale_factors import _column_squared_norms as _csn +from coordinax._src.metric.matrix import DiagonalMetric +from coordinax.internal import ( + QMatrix, + UnitsMatrix, + cdict_units, + pack_nonuniform_unit, +) + +DMLS = u.unit("") @plum.dispatch def scale_factors( chart: AbstractChart, /, *, at: CDict, usys: OptUSys = None -) -> QuantityMatrix: +) -> QMatrix: """Manifold-level dispatch: delegate to the attached metric. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc @@ -34,7 +46,7 @@ def scale_factors( ... "phi": u.Angle(jnp.array(0.0), "rad"), ... } >>> cxm.scale_factors(cxc.sph3d, at=at) - QuantityMatrix([1., 4., 4.], '(, km2 / rad2, km2 / rad2)') + QMatrix([1., 4., 4.], '(, km2 / rad2, km2 / rad2)') """ return cxmapi.scale_factors(chart.M.metric, chart, at=at, usys=usys) # ty: ignore[invalid-return-type] @@ -42,35 +54,119 @@ def scale_factors( @plum.dispatch def scale_factors( - metric: AbstractMetric, + metric: AbstractMetricField, chart: AbstractChart, /, *, at: CDict, usys: OptUSys = None, -) -> QuantityMatrix: - """Return the diagonal entries of ``metric.metric_matrix(...)`` as a vector. +) -> QMatrix: + """Return the diagonal entries of the metric at ``at`` in ``chart``. + + Uses the ``metric_matrix`` dispatch API to compute the metric, then + extracts the diagonal entries. - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> metric = cxm.HyperSphericalMetric(2) + >>> metric = cxm.RoundMetric(2) >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} >>> cxm.scale_factors(metric, cxc.sph2, at=at) - QuantityMatrix([1., 1.], '(, )') + QMatrix([1., 1.], '(, )') """ - return _as_quantity_matrix(metric.metric_matrix(chart, at=at, usys=usys)).diag() - - -def _as_quantity_matrix(x: QuantityMatrix | Array) -> QuantityMatrix: - """Convert a numeric matrix into a dimensionless QuantityMatrix.""" - if isinstance(x, QuantityMatrix): + mm = cxmapi.metric_matrix(chart.M, at, chart) + if isinstance(mm, DiagonalMetric): + diag = mm.diagonal + if isinstance(diag, QMatrix): + return diag + units = UnitsMatrix(tuple(DMLS for _ in range(diag.shape[-1]))) + return QMatrix(diag, unit=units) + return _as_quantity_matrix(mm.matrix).diag() # ty: ignore[unresolved-attribute] + + +def _as_quantity_matrix(x: QMatrix | Array) -> QMatrix: + """Convert a numeric matrix into a dimensionless QMatrix.""" + if isinstance(x, QMatrix): return x n_rows, n_cols = x.shape[-2:] - units = UnitsMatrix(np.full((n_rows, n_cols), u.unit(""))) - return QuantityMatrix(value=x, unit=units) + units = UnitsMatrix(np.full((n_rows, n_cols), DMLS)) + return QMatrix(value=x, unit=units) + + +@plum.dispatch +def scale_factors( + metric: PullbackMetric, + chart: AbstractChart, + /, + *, + at: CDict, + usys: OptUSys = None, +) -> QMatrix: + """Return scale factors for a pullback (induced) metric via Jacobian pullback. + + Computes the Jacobian of the composed embedding ``intrinsic → + Cartesian ambient`` to obtain a unit-consistent Jacobian where every + entry has the same unit (``ambient_cart_unit / intrinsic_unit``). + The squared column norms then give the scale factors with correct units. + + >>> import jax.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + + >>> M = cxm.EmbeddedManifold( + ... intrinsic=cxm.S2, ambient=cxm.R3, + ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(2.0, "m")), + ... ) + >>> at = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} + >>> cxm.scale_factors(M.metric, cxc.sph2, at=at) + QMatrix([4., 4.], '(m2 / rad2, m2 / rad2)') + + """ + embed_map = metric.embed_map + ambient_chart = embed_map.ambient + intrinsic_keys = embed_map.intrinsic.components + + # Use Cartesian ambient chart for a unit-consistent Jacobian. + # Every column of J_cart has the same per-column unit (cart_unit / intrinsic_unit), + # which makes _column_squared_norms well-defined with correct units. + cart_chart = ambient_chart.cartesian + cart_keys = cart_chart.components + + xat, ufrom = pack_nonuniform_unit(at, intrinsic_keys) + ufrom_ = tuple(uf if uf is not None else DMLS for uf in ufrom) + + # Evaluate once to determine Cartesian output units + at_ambient = embed_map.embed(at, usys=usys) + at_cart = cxcapi.pt_map(at_ambient, ambient_chart, cart_chart) + uto_ = cdict_units(at_cart, cart_keys) + uto_ = tuple(ut if ut is not None else DMLS for ut in uto_) + + # Build the unit matrix: J_cart.unit[k][i] = cart_unit_k / intrinsic_unit_i + unit_matrix = UnitsMatrix( + tuple(tuple(tj / fi for fi in ufrom_) for tj in uto_) # ty: ignore[unsupported-operator] + ) + + def _embed_cart(x_arr: jnp.ndarray) -> jnp.ndarray: + q = {k: u.Q(x_arr[i], ufrom_[i]) for i, k in enumerate(intrinsic_keys)} + q_ambient = embed_map.embed(q, usys=usys) + q_cart = cxcapi.pt_map(q_ambient, ambient_chart, cart_chart) + vals = [ + u.ustrip(uto_[j], q_cart[k]) # ty: ignore[not-subscriptable] + if isinstance(q_cart[k], u.AbstractQuantity) # ty: ignore[not-subscriptable] + else qnp.asarray(q_cart[k]) # ty: ignore[not-subscriptable] + for j, k in enumerate(cart_keys) + ] + return qnp.stack(vals) + + J_arr = jax.jacfwd(_embed_cart)(xat) # (n_cart, n_intrinsic) + J_cart = QMatrix(J_arr, unit=unit_matrix) + return _column_squared_norms(J_cart) + + +def _column_squared_norms(J: QMatrix | Array) -> QMatrix: + """Return the squared column norms of a Jacobian matrix as a QMatrix.""" + return _csn(J) diff --git a/src/coordinax/internal.py b/src/coordinax/internal.py index db468066..3dbb8b7a 100644 --- a/src/coordinax/internal.py +++ b/src/coordinax/internal.py @@ -11,7 +11,7 @@ Contents: -- ``QuantityMatrix`` +- ``QMatrix`` An N-D quantity matrix/vector where every element carries its own unit. Supports both 1-D (vector) and 2-D (matrix) cases. Useful for Jacobians and metric tensors whose entries have @@ -37,7 +37,7 @@ """ __all__ = ( - "QuantityMatrix", + "QMatrix", "UnitsMatrix", "tree_cast_int_bool_to_float", "pack_uniform_unit", @@ -47,6 +47,10 @@ "pack_to_qmatrix", "pos_named_objs", "jax_scalar_handler", + "det", + "det_p", + "inv", + "inv_p", # Types "CDict", "OptUSys", @@ -57,9 +61,13 @@ with install_import_hook("coordinax.internal"): from coordinax._src.custom_types import CDict, OptUSys from coordinax._src.internal import ( - QuantityMatrix, + QMatrix, UnitsMatrix, cdict_units, + det, + det_p, + inv, + inv_p, jax_scalar_handler, pack_nonuniform_unit, pack_to_qmatrix, diff --git a/src/coordinax/main.py b/src/coordinax/main.py index a27fcf95..8f2195c5 100644 --- a/src/coordinax/main.py +++ b/src/coordinax/main.py @@ -27,7 +27,7 @@ # manifolds and atlases "EuclideanManifold", "Rn", - "EuclideanMetric", + "FlatMetric", "R3", "S2", "embedded_twosphere", @@ -122,7 +122,7 @@ EmbeddedChart, EmbeddedManifold, EuclideanManifold, - EuclideanMetric, + FlatMetric, Rn, embedded_twosphere, ) diff --git a/src/coordinax/manifolds.py b/src/coordinax/manifolds.py index d66ce79d..5b6d57d4 100644 --- a/src/coordinax/manifolds.py +++ b/src/coordinax/manifolds.py @@ -8,11 +8,17 @@ "pt_map", "scale_factors", "angle_between", + "metric_matrix", + "metric_representation", # Abstract Manifold/Atlas/Metric "AbstractAtlas", - "AbstractMetric", + "AbstractMetricField", "AbstractManifold", - "AbstractDiagonalMetric", + "AbstractDiagonalMetricField", + # Metric matrix classes + "AbstractMetricMatrix", + "DiagonalMetric", + "DenseMetric", # Null "NoManifold", "no_manifold", @@ -22,7 +28,7 @@ "no_atlas", # Euclidean "EuclideanAtlas", - "EuclideanMetric", + "FlatMetric", "EuclideanManifold", "Rn", "R0", @@ -32,7 +38,7 @@ "RN", # HyperSpherical "HyperSphericalAtlas", - "HyperSphericalMetric", + "RoundMetric", "HyperSphericalManifold", "Sn", "S1", @@ -44,7 +50,7 @@ "minkowski4d", # Product "CartesianProductAtlas", - "CartesianProductMetric", + "ProductMetric", "CartesianProductManifold", # Embeddings "EmbeddedManifold", @@ -54,7 +60,7 @@ "embedded_twosphere", "EmbeddedManifold", "EmbeddedChart", - "InducedMetric", + "PullbackMetric", # Custom "CustomAtlas", "CustomMetric", @@ -68,9 +74,9 @@ with install_import_hook("coordinax.manifolds"): from ._src.base import ( AbstractAtlas, - AbstractDiagonalMetric, + AbstractDiagonalMetricField, AbstractManifold, - AbstractMetric, + AbstractMetricField, ) from ._src.custom import CustomAtlas, CustomManifold, CustomMetric from ._src.embedded import ( @@ -78,7 +84,7 @@ CustomEmbeddingMap, EmbeddedChart, EmbeddedManifold, - InducedMetric, + PullbackMetric, ) from ._src.euclidean import ( R0, @@ -88,10 +94,11 @@ RN, EuclideanAtlas, EuclideanManifold, - EuclideanMetric, + FlatMetric, Rn, ) from ._src.manifolds import * # noqa: F403 + from ._src.metric import AbstractMetricMatrix, DenseMetric, DiagonalMetric from ._src.minkowski import ( MinkowskiAtlas, MinkowskiManifold, @@ -109,7 +116,7 @@ from ._src.product import ( CartesianProductAtlas, CartesianProductManifold, - CartesianProductMetric, + ProductMetric, ) from ._src.product.galilean_ct import galilean_spacetime from ._src.spherical import ( @@ -117,7 +124,7 @@ S2, HyperSphericalAtlas, HyperSphericalManifold, - HyperSphericalMetric, + RoundMetric, Sn, TwoSphereIn3D, embedded_twosphere, @@ -126,6 +133,8 @@ from coordinax.api.manifolds import ( angle_between, guess_manifold, + metric_matrix, + metric_representation, pt_embed, pt_project, scale_factors, diff --git a/tests/integration/manifolds/test_custom_jax.py b/tests/integration/manifolds/test_custom_jax.py index e848d9ee..e37fbd9b 100644 --- a/tests/integration/manifolds/test_custom_jax.py +++ b/tests/integration/manifolds/test_custom_jax.py @@ -22,8 +22,8 @@ class TestCustomManifoldJAX: def test_jit_matches_eager(self) -> None: """Jit over manifold transition map matches eager execution.""" - x = u.Q(3.0, "m") - y = u.Q(4.0, "m") + x = u.Q(3, "m") + y = u.Q(4, "m") r_eager, theta_eager = _cart2d_to_polar_with_custom(x, y) r_jit, theta_jit = jax.jit(_cart2d_to_polar_with_custom)(x, y) @@ -35,8 +35,8 @@ def test_jit_matches_eager(self) -> None: def test_vmap_batch_radius(self) -> None: """Vmap over manifold transition map yields expected radii.""" - xs = u.Q(jnp.array([1.0, 0.0, 3.0]), "m") - ys = u.Q(jnp.array([0.0, 2.0, 4.0]), "m") + xs = u.Q(jnp.array([1, 0, 3]), "m") + ys = u.Q(jnp.array([0, 2, 4]), "m") r_batch, _theta_batch = jax.vmap(_cart2d_to_polar_with_custom)(xs, ys) diff --git a/tests/unit/manifolds/test_angle_between_dispatch.py b/tests/unit/manifolds/test_angle_between_dispatch.py index 3314f97d..90a98d70 100644 --- a/tests/unit/manifolds/test_angle_between_dispatch.py +++ b/tests/unit/manifolds/test_angle_between_dispatch.py @@ -15,7 +15,7 @@ class TestAngleBetweenEuclidean: """Tests for angle_between on Euclidean metrics and manifolds.""" def test_cartesian_right_angle_returns_angle(self): - metric = cxm.EuclideanMetric(2) + metric = cxm.FlatMetric(2) at = {"x": u.Q(jnp.array(0.0), "m"), "y": u.Q(jnp.array(0.0), "m")} uvec = {"x": u.Q(jnp.array(1.0), "m"), "y": u.Q(jnp.array(0.0), "m")} vvec = {"x": u.Q(jnp.array(0.0), "m"), "y": u.Q(jnp.array(2.0), "m")} @@ -30,7 +30,7 @@ class TestAngleBetweenFailureModes: """Tests for invalid inputs and unsupported metrics.""" def test_zero_norm_vector_raises_value_error(self): - metric = cxm.EuclideanMetric(2) + metric = cxm.FlatMetric(2) at = {"x": jnp.array(0.0), "y": jnp.array(0.0)} zero = {"x": jnp.array(0.0), "y": jnp.array(0.0)} other = {"x": jnp.array(1.0), "y": jnp.array(0.0)} @@ -67,7 +67,7 @@ class TestAngleBetweenJAX: """Tests for JAX compatibility of angle_between.""" def test_jit(self): - metric = cxm.HyperSphericalMetric(ndim=2) + metric = cxm.RoundMetric(ndim=2) @jax.jit def compute(theta): @@ -82,7 +82,7 @@ def compute(theta): assert jnp.allclose(got, jnp.pi / 4, atol=1e-6) def test_vmap_values(self): - metric = cxm.HyperSphericalMetric(ndim=2) + metric = cxm.RoundMetric(ndim=2) thetas = jnp.array([jnp.pi / 6, jnp.pi / 4, jnp.pi / 2]) def compute(theta): diff --git a/tests/unit/manifolds/test_custom.py b/tests/unit/manifolds/test_custom.py index d1bf6286..6966045f 100644 --- a/tests/unit/manifolds/test_custom.py +++ b/tests/unit/manifolds/test_custom.py @@ -60,7 +60,7 @@ def test_forwards_ndim_and_default_chart(self) -> None: charts=(cxc.Cart2D, cxc.Polar2D), chart_default=cxc.cart2d, ) - manifold = cxm.CustomManifold(atlas, metric=cxm.EuclideanMetric(2)) + manifold = cxm.CustomManifold(atlas, metric=cxm.FlatMetric(2)) assert manifold.ndim == 2 assert manifold.default_chart() == cxc.cart2d @@ -71,7 +71,7 @@ def test_has_chart_and_check_chart(self) -> None: charts=(cxc.Cart2D, cxc.Polar2D), chart_default=cxc.cart2d, ) - manifold = cxm.CustomManifold(atlas, metric=cxm.EuclideanMetric(2)) + manifold = cxm.CustomManifold(atlas, metric=cxm.FlatMetric(2)) assert manifold.has_chart(cxc.cart2d) assert manifold.has_chart(cxc.polar2d) From 260cb25ae1d008f986d461d93e0e09d17b2081f5 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:25:48 -0400 Subject: [PATCH 05/15] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor:=20propagat?= =?UTF-8?q?e=20QMatrix=20and=20metric=20class=20renames=20across=20charts,?= =?UTF-8?q?=20representations,=20transforms,=20vectors,=20distances,=20and?= =?UTF-8?q?=20frames?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All references to QuantityMatrix (now QMatrix), EuclideanMetric (now FlatMetric), HyperSphericalMetric (now RoundMetric), CartesianProductMetric (now ProductMetric), InducedMetric (now PullbackMetric), AbstractMetric (now AbstractMetricField), and AbstractDiagonalMetric (now AbstractDiagonalMetricField) are updated throughout the core modules and their tests. Docstring Examples section headers that merely introduce the code block are also removed for cleaner inline doctests. --- src/coordinax/_src/charts/checks.py | 8 - src/coordinax/_src/charts/d0.py | 4 - src/coordinax/_src/charts/jacobian.py | 44 +-- src/coordinax/_src/charts/register_cdict.py | 80 +---- src/coordinax/_src/charts/register_guess.py | 12 - src/coordinax/_src/charts/register_ptmap.py | 16 +- src/coordinax/_src/charts/scale_factors.py | 16 +- src/coordinax/_src/internal/dtype_utils.py | 2 - src/coordinax/_src/internal/pack_utils.py | 16 +- src/coordinax/_src/utils.py | 2 - src/coordinax/distances/_src/measures.py | 12 - .../distances/_src/register_converters.py | 2 - .../distances/_src/register_primitives.py | 16 - src/coordinax/distances/_src/register_unxt.py | 3 - src/coordinax/frames/_src/base.py | 2 - src/coordinax/frames/_src/example.py | 10 - src/coordinax/frames/_src/register_pfxm.py | 9 - src/coordinax/frames/_src/xfm.py | 6 - .../representations/_src/basis_change.py | 326 +++++++----------- src/coordinax/representations/_src/core.py | 21 +- src/coordinax/representations/_src/guess.py | 54 +-- .../representations/_src/register_cx.py | 6 - .../representations/_src/tangent_map.py | 30 +- src/coordinax/transforms/_src/actions/add.py | 10 - src/coordinax/transforms/_src/actions/base.py | 5 - .../transforms/_src/actions/composed.py | 14 - .../transforms/_src/actions/identity.py | 13 +- .../transforms/_src/actions/reflect.py | 18 +- .../transforms/_src/actions/register_apply.py | 68 ++-- .../transforms/_src/actions/rotate.py | 19 +- .../transforms/_src/actions/translate.py | 6 - src/coordinax/vectors/_src/bundle.py | 6 - src/coordinax/vectors/_src/point.py | 18 - src/coordinax/vectors/_src/register_cx.py | 81 ++--- .../vectors/_src/register_dataclassish.py | 2 - .../vectors/_src/register_manifolds.py | 3 +- src/coordinax/vectors/_src/register_quax.py | 16 +- src/coordinax/vectors/_src/register_unxt.py | 40 +-- tests/integration/angles/test_plum.py | 2 +- tests/integration/charts/test_jax.py | 44 +-- .../distances/test_distances_plum.py | 2 +- tests/integration/distances/test_quax.py | 90 ++--- tests/unit/angles/test_angle.py | 26 +- tests/unit/charts/test_base.py | 19 +- tests/unit/charts/test_cdict.py | 2 +- tests/unit/charts/test_checks.py | 55 ++- tests/unit/charts/test_galilean_charts.py | 16 +- tests/unit/charts/test_guess_chart.py | 6 +- tests/unit/charts/test_jacobian_pt_map.py | 24 +- tests/unit/charts/test_minkowski_charts.py | 2 +- tests/unit/charts/test_predef_charts.py | 12 +- tests/unit/charts/test_product.py | 29 +- tests/unit/charts/test_register_realize.py | 18 +- tests/unit/charts/test_utils.py | 10 +- tests/unit/distances/test_distance.py | 30 +- .../unit/representations/test_change_basis.py | 170 ++++++++- .../unit/representations/test_tangent_map.py | 25 +- tests/unit/transforms/conftest.py | 20 +- tests/unit/transforms/test_act.py | 71 ++-- tests/unit/vectors/test_tangent.py | 10 +- tests/usage/charts/test_jacobian.py | 18 +- tests/usage/frames/test_act_usage.py | 139 ++++---- 62 files changed, 735 insertions(+), 1121 deletions(-) diff --git a/src/coordinax/_src/charts/checks.py b/src/coordinax/_src/charts/checks.py index 4e96bb38..ccfa0c9e 100644 --- a/src/coordinax/_src/charts/checks.py +++ b/src/coordinax/_src/charts/checks.py @@ -16,8 +16,6 @@ def polar_range(polar: AbcQ, _l: AbcQ = _0d, _u: AbcQ = _pid, /) -> AbcQ: """Check that the polar angle is in the range. - Examples - -------- >>> import unxt as u Pass through the input if it's in the range. @@ -55,8 +53,6 @@ def strictly_positive( ) -> u.AbstractQuantity: """Check that the input is non-negative and non-zero. - Examples - -------- >>> import unxt as u Pass through the input if the value is non-negative. @@ -99,8 +95,6 @@ def leq( ) -> u.AbstractQuantity: """Check that the input value is less than or equal to the input maximum value. - Examples - -------- >>> import unxt as u Pass through the input if the value is less than or equal to the max value: @@ -130,8 +124,6 @@ def geq( ) -> u.AbstractQuantity: """Check that the input value is greater than or equal to the input minimum value. - Examples - -------- >>> import unxt as u Pass through the input if the value is greater than or equal to the min value: diff --git a/src/coordinax/_src/charts/d0.py b/src/coordinax/_src/charts/d0.py index 4b69b9fe..e10c4e69 100644 --- a/src/coordinax/_src/charts/d0.py +++ b/src/coordinax/_src/charts/d0.py @@ -60,15 +60,11 @@ class Cart0D(AbstractFixedComponentsChart[ZeroDKeys, ZeroDDims], Abstract0D): This chart has no coordinate components and no coordinate dimensions. It is the canonical Cartesian chart for 0D representations. - Examples - -------- >>> import coordinax.charts as cxc >>> cxc.cart0d.components () - >>> cxc.cart0d.coord_dimensions () - >>> isinstance(cxc.cartesian_chart(cxc.cart0d), cxc.Cart0D) True diff --git a/src/coordinax/_src/charts/jacobian.py b/src/coordinax/_src/charts/jacobian.py index a3c434f4..4ef6a5f7 100644 --- a/src/coordinax/_src/charts/jacobian.py +++ b/src/coordinax/_src/charts/jacobian.py @@ -12,7 +12,7 @@ $$ J^j{}_i(p) = \frac{\\partial \tau^j}{\partial q^i}\bigg|_p $$ where $q^i$ are the $C_1$ coordinates and $\tau^j$ are the $C_2$ coordinates. -The result is a 2-D {class}`~coordinax.internal.QuantityMatrix` of shape +The result is a 2-D {class}`~coordinax.internal.QMatrix` of shape $(n_\\mathrm{out},\\, n_\\mathrm{in})$ whose $(j, i)$ element carries units $$ \mathrm{unit}(J^j{}_i) = \frac{\mathrm{unit}(\tau^j)}{\mathrm{unit}(q^i)} $$ @@ -56,7 +56,7 @@ from coordinax._src.base import AbstractChart from coordinax._src.custom_types import CDict, OptUSys from coordinax.internal import ( - QuantityMatrix, + QMatrix, UnitsMatrix, pack_to_qmatrix, tree_cast_int_bool_to_float, @@ -78,15 +78,13 @@ def jac_pt_map(at: None, /, *fixed_args: Any, **fixed_kw: Any) -> Any: """Higher-order function for fixed-arg Jacobian point map. - Examples - -------- >>> import coordinax.charts as cxc >>> import unxt as u >>> map = cxc.jac_pt_map(None, cxc.cart3d, cxc.sph3d, usys=u.unitsystems.si) >>> at = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} >>> map(at) - QuantityMatrix( + QMatrix( [[ 1., 0., 0.], [-0., -0., -1.], [ 0., 1., 0.]], @@ -110,15 +108,13 @@ def jac_pt_map( ) -> Callable[[object], Any]: """Higher-order function for fixed-arg Jacobian point map. - Examples - -------- >>> import coordinax.charts as cxc >>> import unxt as u >>> map = cxc.jac_pt_map(cxc.cart3d, cxc.sph3d, usys=u.unitsystems.si) >>> at = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} >>> map(at) - QuantityMatrix( + QMatrix( [[ 1., 0., 0.], [-0., -0., -1.], [ 0., 1., 0.]], @@ -153,8 +149,6 @@ def jac_pt_map( coordinates expressed. Returns the JAX array Jacobian $J^j{}_i = \partial \tau^j / \partial q^i$, without unit annotation. - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.charts as cxc >>> import unxt as u @@ -184,13 +178,13 @@ def jac_pt_map( return jac_pt_map_fn(at) # Compute Jacobian as array -def _repack_q_from_jac(jac_qq: QuantityMatrix, /) -> QuantityMatrix: - r"""Rebuild a 2-D ``QuantityMatrix`` Jacobian from the raw ``jax.jacfwd`` output. +def _repack_q_from_jac(jac_qq: QMatrix, /) -> QMatrix: + r"""Rebuild a 2-D ``QMatrix`` Jacobian from the raw ``jax.jacfwd`` output. When ``jax.jacfwd`` differentiates a function that maps a 1-D - ``QuantityMatrix`` of shape ``(n_in,)`` to a 1-D ``QuantityMatrix`` of - shape ``(n_out,)``, the result is a 2-D ``QuantityMatrix`` of shape - ``(n_out, n_in)`` whose ``.value`` is *itself* a 1-D ``QuantityMatrix`` + ``QMatrix`` of shape ``(n_in,)`` to a 1-D ``QMatrix`` of + shape ``(n_out,)``, the result is a 2-D ``QMatrix`` of shape + ``(n_out, n_in)`` whose ``.value`` is *itself* a 1-D ``QMatrix`` carrying the input units (one per column), and whose ``.unit`` is a 1-D ``UnitsMatrix`` carrying the output units (one per row). @@ -200,7 +194,7 @@ def _repack_q_from_jac(jac_qq: QuantityMatrix, /) -> QuantityMatrix: """ ufrom_, uto_ = jac_qq.value.unit, jac_qq.unit # ty: ignore[unresolved-attribute] units = UnitsMatrix(np.divide(uto_._units[:, None], ufrom_._units[None, :])) - return QuantityMatrix(jac_qq.value.value, units) # ty: ignore[unresolved-attribute] + return QMatrix(jac_qq.value.value, units) # ty: ignore[unresolved-attribute] @plum.dispatch @@ -211,7 +205,7 @@ def jac_pt_map( /, *, usys: OptUSys = None, -) -> Array | QuantityMatrix: +) -> Array | QMatrix: r"""Compute the Jacobian at a coordinate-dictionary base point. The primary dict-input dispatch. Branches on whether the values of *at* @@ -224,21 +218,21 @@ def jac_pt_map( ``Array`` dispatch this means *usys* must be provided. **Quantity-valued branch** (at least one value carries a unit) - Packs *at* into a 1-D ``QuantityMatrix`` via + Packs *at* into a 1-D ``QMatrix`` via ``pack_to_qmatrix(at, keys=from_chart.components)``, promotes any integer or boolean leaves to the default floating-point dtype (other dtypes, including complex, are left unchanged and will raise a ``TypeError`` from ``jax.jacfwd`` if passed), then computes ``J_qq = jax.jacfwd(pt_map_fn)(at_in)``. - Because ``jacfwd`` applied to a ``QuantityMatrix``-in / - ``QuantityMatrix``-out function yields a nested ``QuantityMatrix``, + Because ``jacfwd`` applied to a ``QMatrix``-in / + ``QMatrix``-out function yields a nested ``QMatrix``, ``_repack_q_from_jac`` is called to extract the correct 2-D unit structure. Returns ------- - Array | QuantityMatrix - Plain array when *at* is array-valued; ``QuantityMatrix`` of shape + Array | QMatrix + Plain array when *at* is array-valued; ``QMatrix`` of shape ``(n_out, n_in)`` with per-element units otherwise. Raises @@ -335,7 +329,7 @@ def jac_pt_map( /, *, usys: OptUSys = None, -) -> QuantityMatrix: +) -> QMatrix: r"""Compute the Jacobian of the transition function between two charts. $$ @@ -349,7 +343,7 @@ def jac_pt_map( >>> x = u.Q(jnp.array([1.0, 1.0]), "m") >>> cxc.jac_pt_map(cxc.cart2d, cxc.polar2d, usys=u.unitsystems.si)(x) - QuantityMatrix([[ 0.70710678, 0.70710678], + QMatrix([[ 0.70710678, 0.70710678], [-0.5 , 0.5 ]], '((, ), (rad / m, rad / m))') """ @@ -361,7 +355,7 @@ def jac_pt_map( # Astropy treats rad as dimensionless, so x2.unit == 1/m rather than # the correct rad/m. Force the right unit explicitly. rad_per_len = u.unit("rad") / x.unit - return QuantityMatrix( + return QMatrix( jnp.array([[x0.value, x1.value], [x2.value, x3.value]]), unit=((x0.unit, x1.unit), (rad_per_len, rad_per_len)), ) diff --git a/src/coordinax/_src/charts/register_cdict.py b/src/coordinax/_src/charts/register_cdict.py index 81650944..ab414b8a 100644 --- a/src/coordinax/_src/charts/register_cdict.py +++ b/src/coordinax/_src/charts/register_cdict.py @@ -14,7 +14,7 @@ import coordinax.api.charts as cxcapi from coordinax._src.base import AbstractChart from coordinax._src.custom_types import CDict -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax.internal import QMatrix, UnitsMatrix # =================================================================== # CDict @@ -24,8 +24,6 @@ def cdict(obj: CDict, /) -> CDict: """Return a dictionary as-is. - Examples - -------- >>> import coordinax.main as cx >>> d = {"x": 1.0, "y": 2.0} >>> cx.cdict(d) @@ -39,8 +37,6 @@ def cdict(obj: CDict, /) -> CDict: def cdict(obj: CDict, chart: AbstractChart, /) -> CDict: """Return a dictionary as-is. - Examples - -------- >>> import coordinax.charts as cxc >>> d = {"x": 1.0, "y": 2.0} >>> cxc.cdict(d, cxc.cart2d) @@ -62,14 +58,6 @@ def cdict(obj: u.AbstractQuantity, /) -> CDict: dimension. The appropriate Cartesian chart is determined from the last dimension of the quantity. - Raises - ------ - ValueError - If the last dimension of the quantity doesn't match a known Cartesian - chart (0D, 1D, 2D, or 3D). - - Examples - -------- >>> import coordinax.main as cx >>> import unxt as u >>> q = u.Q([1.0, 2.0, 3.0], "m") @@ -120,12 +108,12 @@ def cdict(obj: u.AbstractQuantity, keys: tuple[str, ...], /) -> CDict: @plum.dispatch -def cdict(obj: QuantityMatrix, keys: tuple[str, ...], /) -> CDict: - """Extract component dictionary from a 1D ``QuantityMatrix``. +def cdict(obj: QMatrix, keys: tuple[str, ...], /) -> CDict: + """Extract component dictionary from a 1D ``QMatrix``. This overload supports heterogeneous per-component units by constructing one quantity per chart component from the corresponding numeric slice and - unit in the ``QuantityMatrix``. + unit in the ``QMatrix``. Raises ------ @@ -137,21 +125,21 @@ def cdict(obj: QuantityMatrix, keys: tuple[str, ...], /) -> CDict: -------- >>> import jax.numpy as jnp >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix - >>> q = QuantityMatrix(jnp.array([1.0, 2.0, 3.0]), + >>> q = QMatrix(jnp.array([1.0, 2.0, 3.0]), ... unit=("m", "km/s", "rad")) >>> cxc.cdict(q, ('x', 'y', 'z')) {'x': Q(1., 'm'), 'y': Q(2., 'km / s'), 'z': Q(3., 'rad')} """ if obj.unit.ndim != 1: - msg = f"QuantityMatrix must be 1D for cdict, got ndim={obj.ndim}." + msg = f"QMatrix must be 1D for cdict, got ndim={obj.ndim}." raise ValueError(msg) if obj.shape[-1] != len(keys): msg = ( - f"QuantityMatrix last dimension {obj.shape[-1]} does not match " + f"QMatrix last dimension {obj.shape[-1]} does not match " f"provided keys {len(keys)}." ) raise ValueError(msg) @@ -171,14 +159,6 @@ def cdict(obj: u.AbstractQuantity, chart: AbstractChart, /) -> CDict: 2. The chart has homogeneous coordinate dimensions (all components have the same physical dimension, like Cartesian charts) - Raises - ------ - ValueError - If the last dimension of the quantity doesn't match the chart's - component count, or if dimensions don't match. - - Examples - -------- >>> import coordinax.charts as cxc >>> import unxt as u @@ -191,27 +171,19 @@ def cdict(obj: u.AbstractQuantity, chart: AbstractChart, /) -> CDict: @plum.dispatch -def cdict(obj: QuantityMatrix, chart: AbstractChart, /) -> CDict: - """Extract component dictionary from a 1D ``QuantityMatrix``. +def cdict(obj: QMatrix, chart: AbstractChart, /) -> CDict: + """Extract component dictionary from a 1D ``QMatrix``. This overload supports heterogeneous per-component units by constructing one quantity per chart component from the corresponding numeric slice and - unit in the ``QuantityMatrix``. + unit in the ``QMatrix``. - Raises - ------ - ValueError - If ``obj`` is not 1D, or if the last dimension does not match the - chart component count. - - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.charts as cxc >>> import unxt as u - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix - >>> q = QuantityMatrix( + >>> q = QMatrix( ... jnp.array([1.0, 2.0, 3.0]), ... unit=(u.unit("m"), u.unit("km/s"), u.unit("rad")), ... ) @@ -260,14 +232,6 @@ def cdict(obj: ArrayLike, keys: tuple[str, ...], /) -> CDict: def cdict(obj: ArrayLike, chart: AbstractChart, /) -> CDict: """Extract component dictionary from an array. - Raises - ------ - ValueError - If the last dimension of the quantity doesn't match a known Cartesian - chart (0D, 1D, 2D, or 3D). - - Examples - -------- >>> import coordinax.main as cx >>> import jax.numpy as jnp >>> arr = jnp.array([1.0, 2.0, 3.0]) @@ -288,14 +252,6 @@ def cdict( ) -> CDict: """Extract component dictionary from an array. - Raises - ------ - ValueError - If the last dimension of the quantity doesn't match a known Cartesian - chart (0D, 1D, 2D, or 3D). - - Examples - -------- >>> import coordinax.main as cx >>> import jax.numpy as jnp >>> arr = jnp.array([1.0, 2.0, 3.0]) @@ -317,14 +273,6 @@ def cdict( ) -> CDict: """Extract component dictionary from an array. - Raises - ------ - ValueError - If the last dimension of the quantity doesn't match a known Cartesian - chart (0D, 1D, 2D, or 3D). - - Examples - -------- >>> import coordinax.main as cx >>> import jax.numpy as jnp >>> arr = jnp.array([1.0, 2.0, 3.0]) @@ -343,8 +291,6 @@ def cdict(obj: ArrayLike, unit: u.AbstractUnit | str | UnitsMatrix, /) -> CDict: dimension. The appropriate Cartesian chart is determined from the last dimension of the quantity. - Examples - -------- >>> import coordinax.charts as cxc >>> import jax.numpy as jnp >>> arr = jnp.array([1.0, 2.0, 3.0]) diff --git a/src/coordinax/_src/charts/register_guess.py b/src/coordinax/_src/charts/register_guess.py index 0495fae0..33316218 100644 --- a/src/coordinax/_src/charts/register_guess.py +++ b/src/coordinax/_src/charts/register_guess.py @@ -59,8 +59,6 @@ def guess_chart(obj: frozenset[str], /) -> AbstractChart: the first matching chart it finds. Since the function is cached, the result will be consistent across calls. - Examples - -------- >>> import coordinax.charts as cxc >>> d = {"x": 1.0, "y": 2.0, "z": 3.0} >>> chart = cxc.guess_chart(d) @@ -96,8 +94,6 @@ def guess_chart(obj: CDict, /) -> AbstractChart: the first matching chart it finds. Since the function is cached, the result will be consistent across calls. - Examples - -------- >>> import coordinax.charts as cxc >>> d = {"x": 1.0, "y": 2.0, "z": 3.0} >>> chart = cxc.guess_chart(d) @@ -116,8 +112,6 @@ def guess_chart( ) -> AbstractChart: """Infer a 1D Cartesian chart from last dimension of a value / quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cx >>> q = u.Q([1.0], "m") @@ -134,8 +128,6 @@ def guess_chart( ) -> AbstractChart: """Infer a 2D Cartesian chart from last dimension of a value / quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cx >>> q = u.Q([1.0, 2.0], "m") @@ -152,8 +144,6 @@ def guess_chart( ) -> AbstractChart: """Infer a 3D Cartesian chart from last dimension of a value / quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cx >>> q = u.Q([1.0, 2.0, 3.0], "m") @@ -170,8 +160,6 @@ def guess_chart( ) -> AbstractChart: """Infer a N-dimensional Cartesian chart from last dimension of a value / quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cx >>> q = u.Q([1.0, 2.0, 3.0, 4.0], "m") diff --git a/src/coordinax/_src/charts/register_ptmap.py b/src/coordinax/_src/charts/register_ptmap.py index 70f1476a..8731488c 100644 --- a/src/coordinax/_src/charts/register_ptmap.py +++ b/src/coordinax/_src/charts/register_ptmap.py @@ -35,7 +35,7 @@ from coordinax._src.custom_types import CDict, OptUSys from coordinax._src.euclidean import RN, EuclideanManifold, Rn from coordinax._src.utils import uconvert_to_rad -from coordinax.internal import QuantityMatrix, UnitsMatrix, cdict_units +from coordinax.internal import QMatrix, UnitsMatrix, cdict_units @final @@ -83,7 +83,7 @@ def pt_map(q: None, /, *fixed_args: Any, **fixed_kw: Any) -> Callable[..., Any]: >>> p = u.Q([1.0, 0.0, 0.0], "m") >>> map = cxc.pt_map(None, cxc.cart3d, cxc.sph3d) >>> map(p) - QuantityMatrix([1. , 1.57079633, 0. ], '(m, rad, rad)') + QMatrix([1. , 1.57079633, 0. ], '(m, rad, rad)') Array-Like inputs are interpreted as Cartesian coordinates with units from the required `unxt.AbstractUnitSystem`. @@ -132,7 +132,7 @@ def pt_map( >>> p = u.Q([1.0, 0.0, 0.0], "m") >>> map = cxc.pt_map(cxc.cart3d, cxc.sph3d) >>> map(p) - QuantityMatrix([1. , 1.57079633, 0. ], '(m, rad, rad)') + QMatrix([1. , 1.57079633, 0. ], '(m, rad, rad)') Array-Like inputs are interpreted as Cartesian coordinates with units from the required `unxt.AbstractUnitSystem`. @@ -1577,10 +1577,10 @@ def pt_map( /, *, usys: OptUSys = None, -) -> QuantityMatrix: - """Transform a QuantityMatrix between charts. +) -> QMatrix: + """Transform a QMatrix between charts. - Converts the components of a QuantityMatrix from one chart to another, + Converts the components of a QMatrix from one chart to another, preserving the matrix structure with potentially different units per component. >>> import coordinax.charts as cxc @@ -1617,8 +1617,8 @@ def pt_map( p_to = cxcapi.pt_map(p_dict, from_M, from_chart, to_M, to_chart, usys=usys) p_to = cast("dict[str, u.AbstractQuantity]", p_to) - # Stack the transformed components into an QuantityMatrix - p_out = QuantityMatrix( + # Stack the transformed components into an QMatrix + p_out = QMatrix( jnp.stack([u.ustrip(p_to[k]) for k in to_chart.components], axis=-1), unit=UnitsMatrix(cdict_units(p_to, to_chart.components)), ) diff --git a/src/coordinax/_src/charts/scale_factors.py b/src/coordinax/_src/charts/scale_factors.py index ba182f8a..6571408d 100644 --- a/src/coordinax/_src/charts/scale_factors.py +++ b/src/coordinax/_src/charts/scale_factors.py @@ -15,38 +15,36 @@ from .dn import CartND from coordinax._src.base import AbstractDimensionalFlag from coordinax._src.custom_types import CDict, OptUSys -from coordinax._src.euclidean import EuclideanMetric -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax._src.euclidean import FlatMetric +from coordinax.internal import QMatrix, UnitsMatrix DMLS = u.unit("") @plum.dispatch def scale_factors( - metric: EuclideanMetric, + metric: FlatMetric, chart: Cart0D | Cart1D | Cart2D | Cart3D | CartND, /, *, at: CDict, usys: OptUSys = None, -) -> QuantityMatrix: +) -> QMatrix: """Fast path for Euclidean metrics in Cartesian charts. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> metric = cxm.EuclideanMetric(3) + >>> metric = cxm.FlatMetric(3) >>> at = { ... "x": u.Q(jnp.array(1.0), "m"), ... "y": u.Q(jnp.array(2.0), "m"), ... "z": u.Q(jnp.array(3.0), "m"), ... } >>> cxm.scale_factors(metric, cxc.cart3d, at=at) - QuantityMatrix([1., 1., 1.], '(, , )') + QMatrix([1., 1., 1.], '(, , )') """ del metric, at, usys @@ -55,6 +53,6 @@ def scale_factors( if isinstance(chart, AbstractDimensionalFlag) else len(chart.components) ) - return QuantityMatrix( + return QMatrix( jnp.ones((n,)), unit=UnitsMatrix(tuple(u.unit("") for _ in range(n))) ) diff --git a/src/coordinax/_src/internal/dtype_utils.py b/src/coordinax/_src/internal/dtype_utils.py index 1fc4a1e1..3e98ad6f 100644 --- a/src/coordinax/_src/internal/dtype_utils.py +++ b/src/coordinax/_src/internal/dtype_utils.py @@ -34,8 +34,6 @@ def tree_cast_int_bool_to_float(tree: PyTree[NumericLeaf], /) -> PyTree[InexactL This intentionally does not cast complex leaves, which prevents silent imaginary-part loss. - Examples - -------- >>> import jax.numpy as jnp >>> from coordinax.internal import tree_cast_int_bool_to_float diff --git a/src/coordinax/_src/internal/pack_utils.py b/src/coordinax/_src/internal/pack_utils.py index b30187f6..43199ea3 100644 --- a/src/coordinax/_src/internal/pack_utils.py +++ b/src/coordinax/_src/internal/pack_utils.py @@ -25,7 +25,7 @@ from unxt.quantity import AllowValue from .custom_types import CDict, CKey -from .quantity_matrix import QuantityMatrix +from .quantity_matrix import QMatrix DMLS: Final = u.unit("") @@ -119,12 +119,12 @@ def pack_with_usys( def pack_to_qmatrix( p: CDict, /, keys: tuple[CKey, ...] | None = None -) -> Array | QuantityMatrix: - """Pack a component dictionary into a QuantityMatrix or plain Array. +) -> Array | QMatrix: + """Pack a component dictionary into a QMatrix or plain Array. Components are ordered according to ``keys``. If the values are {class}`~unxt.AbstractQuantity`, a 1-D - {class}`~coordinax.internal.QuantityMatrix` is returned with per-component + {class}`~coordinax.internal.QMatrix` is returned with per-component units. If the values are plain arrays, a stacked JAX array is returned. Parameters @@ -136,7 +136,7 @@ def pack_to_qmatrix( Returns ------- - Array | QuantityMatrix + Array | QMatrix Packed representation of the component dictionary. Examples @@ -147,7 +147,7 @@ def pack_to_qmatrix( >>> p = {"x": u.Q(1.0, "km"), "y": u.Q(2.0, "km"), "z": u.Q(3.0, "km")} >>> pack_to_qmatrix(p, ("x", "y", "z")) - QuantityMatrix([1., 2., 3.], '(km, km, km)') + QMatrix([1., 2., 3.], '(km, km, km)') """ # Dict sorter @@ -160,5 +160,5 @@ def pack_to_qmatrix( vals = [ u.ustrip(AllowValue, unit, p[k]) for k, unit in zip(keys, units, strict=True) ] - # Return as QuantityMatrix - return QuantityMatrix(jnp.stack(vals, axis=-1), unit=units) + # Return as QMatrix + return QMatrix(jnp.stack(vals, axis=-1), unit=units) diff --git a/src/coordinax/_src/utils.py b/src/coordinax/_src/utils.py index 58c88b52..5962134e 100644 --- a/src/coordinax/_src/utils.py +++ b/src/coordinax/_src/utils.py @@ -24,8 +24,6 @@ def uconvert_to_rad(value: u.AbstractQuantity, usys: OptUSys, /) -> BareQuantity def uconvert_to_rad(value: Any, usys: OptUSys, /) -> Any: """Convert an angle value to radians, handling no-usys case. - Examples - -------- Angular quantities are converted from their own unit: >>> import quaxed.numpy as jnp diff --git a/src/coordinax/distances/_src/measures.py b/src/coordinax/distances/_src/measures.py index fc83d1f8..905ceeec 100644 --- a/src/coordinax/distances/_src/measures.py +++ b/src/coordinax/distances/_src/measures.py @@ -69,11 +69,8 @@ def __check_init__(self) -> None: def from_(cls: type[Distance], value: ArrayLike, unit: Any, /, **kw: Any) -> Distance: """Construct a distance. - Examples - -------- >>> import unxt as u >>> import coordinax.distances as cxd - >>> cxd.Distance.from_(1, "kpc") Distance(1, 'kpc') @@ -85,8 +82,6 @@ def from_(cls: type[Distance], value: ArrayLike, unit: Any, /, **kw: Any) -> Dis def from_(cls: type[Distance], d: Distance, /, **kw: Any) -> Distance: """Compute distance from distance. - Examples - -------- >>> import unxt as u >>> import coordinax.distances as cxd @@ -107,11 +102,8 @@ def from_(cls: type[Distance], d: Distance, /, **kw: Any) -> Distance: def from_(cls: type[Distance], d: u.Q["length"], /, **kw: Any) -> Distance: """Compute distance from distance. - Examples - -------- >>> import unxt as u >>> import coordinax.distances as cxd - >>> q = u.Q(1, "kpc") >>> cxd.Distance.from_(q, dtype=float) Distance(1., 'kpc') @@ -125,8 +117,6 @@ def from_(cls: type[Distance], d: u.Q["length"], /, **kw: Any) -> Distance: def from_(cls: type[Distance], p: u.Q["angle"], /, **kw: Any) -> Distance: """Compute distance from parallax. - Examples - -------- >>> import unxt as u >>> import coordinax.distances as cxd @@ -144,8 +134,6 @@ def from_(cls: type[Distance], p: u.Q["angle"], /, **kw: Any) -> Distance: def from_(cls: type[Distance], dm: u.Q["mag"], /, **kw: Any) -> Distance: """Compute distance from distance modulus. - Examples - -------- >>> import unxt as u >>> import coordinax.distances as cxd diff --git a/src/coordinax/distances/_src/register_converters.py b/src/coordinax/distances/_src/register_converters.py index 658d72a8..1ed42945 100644 --- a/src/coordinax/distances/_src/register_converters.py +++ b/src/coordinax/distances/_src/register_converters.py @@ -15,8 +15,6 @@ def convert_quantity_to_distance(q: u.AbstractQuantity, /) -> Distance: """Convert any quantity to a Distance. - Examples - -------- >>> from plum import convert >>> from unxt.quantity import BareQuantity >>> from coordinax.distances import Distance diff --git a/src/coordinax/distances/_src/register_primitives.py b/src/coordinax/distances/_src/register_primitives.py index c6d7dbed..4415d43e 100644 --- a/src/coordinax/distances/_src/register_primitives.py +++ b/src/coordinax/distances/_src/register_primitives.py @@ -25,8 +25,6 @@ def atan2_p_abstractdistances(x: AbstractDistance, y: AbstractDistance, /) -> u.Q: """Arctangent2 of two distances degrades to a quantity. - Examples - -------- >>> import quaxed.numpy as jnp >>> from coordinax.distances import Distance @@ -49,8 +47,6 @@ def atan2_p_abstractdistances(x: AbstractDistance, y: AbstractDistance, /) -> u. def cbrt_p_abstractdistance(x: AbstractDistance, /, *, accuracy: Any) -> BareQuantity: """Cube root of a distance. - Examples - -------- >>> import quaxed.numpy as jnp >>> from coordinax.distances import Distance >>> d = Distance(8, "m") @@ -69,8 +65,6 @@ def cbrt_p_abstractdistance(x: AbstractDistance, /, *, accuracy: Any) -> BareQua def div_p_abstractdistances(x: AbstractDistance, y: AbstractDistance, /) -> u.Q: """Division of two Distances. - Examples - -------- >>> import quaxed.numpy as jnp >>> from coordinax.distances import Distance @@ -92,8 +86,6 @@ def dot_general_p_abstractdistances( ) -> BareQuantity: """Dot product of two Distances. - Examples - -------- This is a dot product of two Distances. >>> import quaxed.numpy as jnp @@ -131,8 +123,6 @@ def dot_general_p_abstractdistances( def integer_pow_p_abstractdistance(x: AbstractDistance, /, *, y: Any) -> BareQuantity: """Integer power of a Distance. - Examples - -------- >>> from coordinax.distances import Distance >>> q = Distance(2, "m") >>> q ** 3 @@ -149,8 +139,6 @@ def integer_pow_p_abstractdistance(x: AbstractDistance, /, *, y: Any) -> BareQua def neg_p_distance(x: Distance, /) -> u.Q: """Negation of a Distance degrades to a Quantity. - Examples - -------- >>> from coordinax.distances import Distance >>> q = Distance(10, "m") >>> -q @@ -169,8 +157,6 @@ def pow_p_abstractdistance_arraylike( ) -> BareQuantity: """Power of a Distance by redispatching to Quantity. - Examples - -------- >>> import math >>> from coordinax.distances import Distance @@ -191,8 +177,6 @@ def pow_p_abstractdistance_arraylike( def sqrt_p_abstractdistance(x: AbstractDistance, /, *, accuracy: Any) -> BareQuantity: """Square root of a quantity. - Examples - -------- >>> import quaxed.numpy as jnp >>> from coordinax.distances import Distance diff --git a/src/coordinax/distances/_src/register_unxt.py b/src/coordinax/distances/_src/register_unxt.py index 21a58564..7115343b 100644 --- a/src/coordinax/distances/_src/register_unxt.py +++ b/src/coordinax/distances/_src/register_unxt.py @@ -16,11 +16,8 @@ def dimension_of(obj: type[AbstractDistance], /) -> u.AbstractDimension: """Get the dimension of an angle. - Examples - -------- >>> import unxt as u >>> import coordinax.distances as cxd - >>> u.dimension_of(cxd.AbstractDistance) PhysicalType('length') diff --git a/src/coordinax/frames/_src/base.py b/src/coordinax/frames/_src/base.py index 0cb8f077..d45359c2 100644 --- a/src/coordinax/frames/_src/base.py +++ b/src/coordinax/frames/_src/base.py @@ -110,8 +110,6 @@ def from_( ) -> AbstractReferenceFrame: """Construct a reference frame from a mapping. - Examples - -------- >>> import coordinax.frames as cxf >>> alice = cxf.Alice.from_({}) diff --git a/src/coordinax/frames/_src/example.py b/src/coordinax/frames/_src/example.py index e76351c6..98f378e6 100644 --- a/src/coordinax/frames/_src/example.py +++ b/src/coordinax/frames/_src/example.py @@ -101,13 +101,9 @@ def frame_transition( ) -> cxfm.Identity: """Return an identity operator for frames that are the same. - Examples - -------- >>> import coordinax.frames as cxf - >>> cxf.frame_transition(cxf.alice, cxf.alice) Identity() - >>> cxf.frame_transition(cxf.alex, cxf.alex) Identity() @@ -119,11 +115,8 @@ def frame_transition( def frame_transition(from_frame: Alice, to_frame: Alex, /) -> cxfm.Composed: """Transform from Alice's frame to Alex's frame. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx - >>> op = cxf.frame_transition(cxf.alice, cxf.alex) >>> print(op) Composed(( Translate(...), Rotate(...) )) @@ -140,11 +133,8 @@ def frame_transition( ) -> cxfm.Composed: """Transform back. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx - >>> op = cxf.frame_transition(cxf.alex, cxf.alice) >>> print(op) Composed(( Rotate(...), Translate(...) )) diff --git a/src/coordinax/frames/_src/register_pfxm.py b/src/coordinax/frames/_src/register_pfxm.py index 975b4eb8..5356e1a0 100644 --- a/src/coordinax/frames/_src/register_pfxm.py +++ b/src/coordinax/frames/_src/register_pfxm.py @@ -21,11 +21,8 @@ def frame_transition(from_frame: NoFrame, to_frame: NoFrame, /) -> cxfm.Identity frame-agnostic) there is nothing to transform, so the result is the identity operation. - Examples - -------- >>> import coordinax.frames as cxf >>> import coordinax.transforms as cxfm - >>> op = cxf.frame_transition(cxf.noframe, cxf.noframe) >>> isinstance(op, cxfm.Identity) True @@ -40,10 +37,7 @@ def frame_transition( ) -> NoReturn: """Cannot transform from the null frame. - Examples - -------- >>> import coordinax.frames as cxf - >>> try: ... cxf.frame_transition(cxf.noframe, cxf.alice) ... except cxf.FrameTransformError as e: @@ -61,10 +55,7 @@ def frame_transition( ) -> NoReturn: """Cannot transform to the null frame. - Examples - -------- >>> import coordinax.frames as cxf - >>> try: ... cxf.frame_transition(cxf.alice, cxf.noframe) ... except cxf.FrameTransformError as e: diff --git a/src/coordinax/frames/_src/xfm.py b/src/coordinax/frames/_src/xfm.py index f34174aa..53980e83 100644 --- a/src/coordinax/frames/_src/xfm.py +++ b/src/coordinax/frames/_src/xfm.py @@ -135,8 +135,6 @@ def frame_transition( ) -> AbstractTransform: """Return a frame transform operator to a transformed frame. - Examples - -------- >>> import quaxed.numpy as jnp >>> import coordinax.vectors as cxv >>> import coordinax.frames as cxf @@ -168,8 +166,6 @@ def frame_transition( ) -> AbstractTransform: """Return a frame transform operator from a transformed frame. - Examples - -------- >>> import quaxed.numpy as jnp >>> import coordinax.vectors as cxv >>> import coordinax.frames as cxf @@ -206,8 +202,6 @@ def frame_transition( When ``from_frame`` and ``to_frame`` are the same object the result is the identity transform. - Examples - -------- >>> import quaxed.numpy as jnp >>> import coordinax.vectors as cxv >>> import coordinax.frames as cxf diff --git a/src/coordinax/representations/_src/basis_change.py b/src/coordinax/representations/_src/basis_change.py index 1b935005..93b9ba19 100644 --- a/src/coordinax/representations/_src/basis_change.py +++ b/src/coordinax/representations/_src/basis_change.py @@ -13,6 +13,7 @@ import unxt as u from unxt.quantity import BareQuantity, is_any_quantity +import coordinax.api.manifolds as cxmapi import coordinax.api.representations as cxrapi import coordinax.charts as cxc import coordinax.manifolds as cxm @@ -26,9 +27,10 @@ from .custom_types import CDict, OptUSys from .geom import TangentGeometry from .rep import Representation -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax._src.metric.matrix import DenseMetric, DiagonalMetric +from coordinax.internal import QMatrix, UnitsMatrix -T = TypeVar("T", bound=u.Quantity) +T = TypeVar("T", bound=u.Q) _RAD = u.unit("rad") @@ -43,7 +45,7 @@ def _add_rad_unit(q: ArrayLike | u.AbstractQuantity) -> ArrayLike | u.AbstractQu return BareQuantity(q.value, unit=q.unit * _RAD) if is_any_quantity(q) else q -def _qm_triangular_solve(E: QuantityMatrix, b: QuantityMatrix) -> QuantityMatrix: +def _qm_triangular_solve(E: QMatrix, b: QMatrix) -> QMatrix: """Solve upper-triangular system E @ x = b for x, respecting units. Uses the fact that E is upper-triangular (vielbein = L^T from Cholesky). @@ -67,18 +69,18 @@ def _qm_triangular_solve(E: QuantityMatrix, b: QuantityMatrix) -> QuantityMatrix x_vals = jax.scipy.linalg.solve_triangular( E_norm, b_norm[..., None], lower=False ).squeeze(-1) - return QuantityMatrix(x_vals, unit=x_units) + return QMatrix(x_vals, unit=x_units) ############################################################################## -# With a metric +# With a manifold @plum.dispatch def change_basis( v: CDict, chart: cxc.AbstractChart, - metric: cxm.AbstractMetric, + M: cxm.AbstractManifold, from_basis: CoordinateBasis, to_basis: PhysicalBasis, /, @@ -86,59 +88,66 @@ def change_basis( at: CDict, usys: OptUSys = None, ) -> CDict: - r"""Change from coordinate basis to physical basis using a general metric. + r"""Change from coordinate basis to physical basis using a manifold. - This overload handles any metric that is **not** an - `~coordinax.manifolds.AbstractDiagonalMetric` (e.g. - `~coordinax.manifolds.InducedMetric`). The algorithm uses the Cholesky - vielbein $E = L^\top$. + Retrieves the manifold's metric and applies the appropriate transformation. + For diagonal metrics (e.g. `coordinax.manifolds.FlatMetric` in orthogonal + charts) the fast scale-factor path is taken; for general metrics (e.g. + `coordinax.manifolds.PullbackMetric`) the Cholesky vielbein $E = L^\top$ is + used. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm >>> import coordinax.representations as cxr - ``InducedMetric`` is an ``AbstractMetric`` but not an - ``AbstractDiagonalMetric``, so this dispatch is selected: - - >>> embed_map = cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")) - >>> metric = cxm.InducedMetric(embed_map, cxm.EuclideanMetric(3)) - - >>> v = {"theta": u.Q(1.0, "rad/s"), "phi": u.Q(2.0, "rad/s")} - >>> at = {"theta": u.Q(jnp.pi / 3, "rad"), "phi": u.Q(0.0, "rad")} - >>> cxr.change_basis(v, cxc.sph2, metric, cxr.coord_basis, cxr.phys_basis, at=at) - {'theta': Q(..., 'km / s'), 'phi': Q(..., 'km / s')} + Euclidean 3-D manifold in spherical coordinates (diagonal metric): - >>> v = {"theta": 1.0, "phi": 2.0} - >>> at = {"theta": jnp.pi / 3, "phi": 0.0} - >>> usys = u.unitsystems.si - >>> cxr.change_basis(v, cxc.sph2, metric, cxr.coord_basis, cxr.phys_basis, - ... at=at, usys=usys) - {'theta': Q(1., 'km'), 'phi': Q(1.73205081, 'km')} + >>> M3 = cxm.R3 + >>> v = {"r": u.Q(5, "m/s"), "theta": u.Q(1, "rad/s"), "phi": u.Q(2, "rad/s")} + >>> at = {"r": u.Q(3, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0, "rad")} + >>> cxr.change_basis(v, cxc.sph3d, M3, cxr.coord_basis, cxr.phys_basis, at=at) + {'r': Q(5, 'm / s'), 'theta': Q(3, 'm / s'), 'phi': Q(2.87655323, 'm / s')} - Use an embedded two-sphere manifold, whose induced metric is non-diagonal in - the general case: + Embedded two-sphere manifold — non-diagonal + :class:`~coordinax.manifolds.PullbackMetric`: >>> M = cxm.EmbeddedManifold( - ... intrinsic=cxm.HyperSphericalManifold(), - ... ambient=cxm.EuclideanManifold(3), + ... intrinsic=cxm.S2, ambient=cxm.R3, ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), ... ) >>> v = {"theta": u.Q(1.0, "rad/s"), "phi": u.Q(2.0, "rad/s")} >>> at = {"theta": u.Q(jnp.pi / 3, "rad"), "phi": u.Q(0.0, "rad")} - >>> cxr.change_basis(v, cxc.sph2, M.metric, cxr.coord_basis, cxr.phys_basis, at=at) + >>> cxr.change_basis(v, cxc.sph2, M, cxr.coord_basis, cxr.phys_basis, at=at) {'theta': Q(..., 'km / s'), 'phi': Q(..., 'km / s')} """ + del from_basis, to_basis # only used for dispatch + at = chart.check_data(at, keys=True) keys = chart.components - # Cholesky vielbein E = L^T, hat_v = E @ v - L = metric.cholesky(chart, at=at, usys=usys) + mm = cxmapi.metric_matrix(M, at, chart) + if isinstance(mm, DiagonalMetric): + h = jnp.sqrt(mm.diagonal) + return {k: h[i] * v[k] for i, k in enumerate(keys)} + # General case: Cholesky vielbein E = L^T, hat_v = E @ v + assert isinstance(mm, DenseMetric) # noqa: S101 + mat = mm.matrix + if isinstance(mat, QMatrix): + L_val = jnp.linalg.cholesky(mat.value) + L_units = UnitsMatrix(mat.unit._units**0.5) + L = QMatrix(L_val, unit=L_units) + else: + L_raw = jnp.linalg.cholesky(mat) + n = mat.shape[-1] + _dmls = u.unit("") + L = QMatrix( + L_raw, + unit=UnitsMatrix(tuple(tuple(_dmls for _ in range(n)) for _ in range(n))), + ) E = jnp.transpose(L, axes=(-2, -1)) # E = L^T, upper-triangular vielbein - v_vec = QuantityMatrix.from_cdict(v, keys) + v_vec = QMatrix.from_cdict(v, keys) hat_v_vec = jnp.matmul(E, v_vec) return cxc.cdict(hat_v_vec, keys) # ty: ignore[invalid-return-type] @@ -147,7 +156,7 @@ def change_basis( def change_basis( v: CDict, chart: cxc.AbstractChart, - metric: cxm.AbstractMetric, + M: cxm.AbstractManifold, from_basis: PhysicalBasis, to_basis: CoordinateBasis, /, @@ -155,38 +164,65 @@ def change_basis( at: CDict, usys: OptUSys = None, ) -> CDict: - r"""Change from physical basis to coordinate basis using a general metric. + r"""Change from physical basis to coordinate basis using a manifold. - This overload handles any metric that is **not** an - `~coordinax.manifolds.AbstractDiagonalMetric` (e.g. - `~coordinax.manifolds.InducedMetric`). The algorithm uses the Cholesky - vielbein $E = L^\top$, solved as a triangular system $v = E^{-1}\hat{v}$. + Retrieves the manifold's metric and applies the inverse transformation. For + diagonal metrics the fast scale-factor path is taken; for general metrics + the Cholesky vielbein $E = L^\top$ is solved as a triangular system $v = + E^{-1}\hat{v}$. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm >>> import coordinax.representations as cxr - ``InducedMetric`` is an ``AbstractMetric`` but not an - ``AbstractDiagonalMetric``, so this dispatch is selected: + Euclidean 3-D manifold in spherical coordinates (diagonal metric): + + >>> M3 = cxm.R3 + >>> v = {"r": u.Q(5, "m/s"), "theta": u.Q(3, "m/s"), "phi": u.Q(2.876553, "m/s")} + >>> at = {"r": u.Q(3, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0.5, "rad")} + >>> cxr.change_basis(v, cxc.sph3d, M3, cxr.phys_basis, cxr.coord_basis, at=at) + {'r': Q(5, 'm / s'), 'theta': Q(1., 'rad / s'), 'phi': Q(1.99999..., 'rad / s')} + + Embedded two-sphere manifold — non-diagonal + :class:`~coordinax.manifolds.PullbackMetric`: - >>> embed_map = cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")) - >>> metric = cxm.InducedMetric(embed_map, cxm.EuclideanMetric(3)) + >>> M = cxm.EmbeddedManifold( + ... intrinsic=cxm.S2, ambient=cxm.R3, + ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), + ... ) >>> v = {"theta": u.Q(1.0, "km/s"), "phi": u.Q(2.0, "km/s")} >>> at = {"theta": u.Q(jnp.pi / 3, "rad"), "phi": u.Q(0.0, "rad")} - >>> cxr.change_basis(v, cxc.sph2, metric, cxr.phys_basis, cxr.coord_basis, at=at) + >>> cxr.change_basis(v, cxc.sph2, M, cxr.phys_basis, cxr.coord_basis, at=at) {'theta': Q(..., 'rad / s'), 'phi': Q(..., 'rad / s')} """ + del from_basis, to_basis # only used for dispatch + at = chart.check_data(at, keys=True) keys = chart.components - # Cholesky vielbein E = L^T, v = E^{-1} hat_v (triangular solve) - L = metric.cholesky(chart, at=at, usys=usys) + mm = cxmapi.metric_matrix(M, at, chart) + if isinstance(mm, DiagonalMetric): + h = jnp.sqrt(mm.diagonal) + return {k: v[k] / h[i] for i, k in enumerate(keys)} + # General case: Cholesky vielbein E = L^T, v = E^{-1} hat_v (triangular solve) + assert isinstance(mm, DenseMetric) # noqa: S101 + mat = mm.matrix + if isinstance(mat, QMatrix): + L_val = jnp.linalg.cholesky(mat.value) + L_units = UnitsMatrix(mat.unit._units**0.5) + L = QMatrix(L_val, unit=L_units) + else: + L_raw = jnp.linalg.cholesky(mat) + n = mat.shape[-1] + _dmls = u.unit("") + L = QMatrix( + L_raw, + unit=UnitsMatrix(tuple(tuple(_dmls for _ in range(n)) for _ in range(n))), + ) E = jnp.transpose(L, axes=(-2, -1)) # E = L^T, upper-triangular vielbein - hat_v_vec = QuantityMatrix.from_cdict(v, keys) + hat_v_vec = QMatrix.from_cdict(v, keys) v_vec = _qm_triangular_solve(E, hat_v_vec) return cxc.cdict(v_vec, keys) # ty: ignore[invalid-return-type] @@ -205,104 +241,23 @@ def change_basis( at: CDict, usys: OptUSys = None, ) -> CDict: - r"""Change from physical basis to coordinate basis using a general metric. - - Examples - -------- - >>> import jax.numpy as jnp - >>> import unxt as u - >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm - >>> import coordinax.representations as cxr - - >>> M = cxm.EmbeddedManifold( - ... intrinsic=cxm.HyperSphericalManifold(), - ... ambient=cxm.EuclideanManifold(3), - ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), - ... ) - >>> v = {"theta": u.Q(1.0, "km/s"), "phi": u.Q(2.0, "km/s")} - >>> at = {"theta": u.Q(jnp.pi / 3, "rad"), "phi": u.Q(0.0, "rad")} - >>> cxr.change_basis(v, cxc.sph2, M.metric, cxr.phys_basis, cxr.coord_basis, at=at) - {'theta': Q(..., 'rad / s'), 'phi': Q(..., 'rad / s')} - - """ - return cxrapi.change_basis( - v, chart, chart.M.metric, from_basis, to_basis, at=at, usys=usys - ) # ty: ignore[invalid-return-type] - - -# ============================================== -# Diagonal metrics - - -@plum.dispatch -def change_basis( - v: CDict, - chart: cxc.AbstractChart, - metric: cxm.AbstractDiagonalMetric, - from_basis: CoordinateBasis, - to_basis: PhysicalBasis, - /, - *, - at: CDict, - usys: OptUSys = None, -) -> CDict: - r"""Change from coordinate basis to physical basis using a metric. - - Examples - -------- - >>> import unxt as u - >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm - >>> import coordinax.representations as cxr - - Spherical chart, Quantity components: - - >>> metric = cxm.EuclideanMetric(3) - >>> v = {"r": u.Q(5, "m/s"), "theta": u.Q(1, "rad/s"), "phi": u.Q(2, "rad/s")} - >>> at = {"r": u.Q(3, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0, "rad")} - >>> cxr.change_basis(v, cxc.sph3d, metric, cxr.coord_basis, cxr.phys_basis, at=at) - {'r': Q(5, 'm / s'), 'theta': Q(3, 'm / s'), 'phi': Q(2.87655323, 'm / s')} - - """ - at = chart.check_data(at, keys=True) - gdiag = metric.scale_factors(chart, at=at, usys=usys) - return {k: jnp.sqrt(gdiag[i]) * v[k] for i, k in enumerate(chart.components)} - + r"""Change from physical basis to coordinate basis using the chart's manifold. -@plum.dispatch -def change_basis( - v: CDict, - chart: cxc.AbstractChart, - metric: cxm.AbstractDiagonalMetric, - from_basis: PhysicalBasis, - to_basis: CoordinateBasis, - /, - *, - at: CDict, - usys: OptUSys = None, -) -> CDict: - r"""Change from physical basis to coordinate basis using a metric. + Falls back to ``chart.M`` when no explicit manifold is supplied. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cxc - >>> import coordinax.manifolds as cxm >>> import coordinax.representations as cxr - Spherical chart, Quantity components: - - >>> metric = cxm.EuclideanMetric(3) >>> v = {"r": u.Q(5, "m/s"), "theta": u.Q(3, "m/s"), "phi": u.Q(2.876553, "m/s")} >>> at = {"r": u.Q(3, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0.5, "rad")} - >>> cxr.change_basis(v, cxc.sph3d, metric, cxr.phys_basis, cxr.coord_basis, at=at) - {'r': Q(5, 'm / s'), 'theta': Q(1., 'rad / s'), 'phi': Q(1.99999..., 'rad / s')} + >>> cxr.change_basis(v, cxc.sph3d, cxr.phys_basis, cxr.coord_basis, at=at) + {'r': Q(5, 'm / s'), 'theta': Q(1., 'rad / s'), 'phi': Q(1.99999984, 'rad / s')} """ - at = chart.check_data(at, keys=True) - gdiag = metric.scale_factors(chart, at=at, usys=usys) - return {k: v[k] / jnp.sqrt(gdiag[i]) for i, k in enumerate(chart.components)} + return cxrapi.change_basis( + v, chart, chart.M, from_basis, to_basis, at=at, usys=usys + ) # ty: ignore[invalid-return-type] # ============================================== @@ -324,8 +279,6 @@ def change_basis( possible from an unknown basis, so values are preserved and only the representation-level basis label changes. - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.charts as cxc >>> import coordinax.representations as cxr @@ -362,8 +315,6 @@ def change_basis( This conversion is only well-defined when all components share the same physical dimension. No numeric transform is applied; values are preserved. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc @@ -418,14 +369,12 @@ def change_basis( $$h_x = h_y = h_z = 1,$$ - so the coordinate basis vectors are already unit vectors - ($\hat{e}_i = \partial_i$) and the transformation matrix is the - identity ($H = I$). The coordinate basis **is** the physical basis, - so this conversion is always the identity map and ``v`` is returned - unchanged regardless of the direction of the conversion. + so the coordinate basis vectors are already unit vectors ($\hat{e}_i = + \partial_i$) and the transformation matrix is the identity ($H = I$). The + coordinate basis **is** the physical basis, so this conversion is always the + identity map and ``v`` is returned unchanged regardless of the direction of + the conversion. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.representations as cxr @@ -457,7 +406,7 @@ def change_basis( def change_basis( v: CDict, chart: cxc.Cart0D | cxc.Cart1D | cxc.Cart2D | cxc.Cart3D | cxc.CartND, - metric: cxm.EuclideanMetric, + M: cxm.AbstractManifold, from_basis: AbstractLinearBasis, to_basis: AbstractLinearBasis, /, @@ -469,43 +418,39 @@ def change_basis( $$h_x = h_y = h_z = 1,$$ - so the coordinate basis vectors are already unit vectors - ($\hat{e}_i = \partial_i$) and the transformation matrix is the - identity ($H = I$). The coordinate basis **is** the physical basis, - so this conversion is always the identity map and ``v`` is returned - unchanged regardless of the direction of the conversion. + so the coordinate basis vectors are already unit vectors ($\hat{e}_i = + \partial_i$) and the transformation matrix is the identity ($H = I$). The + coordinate basis **is** the physical basis, so this conversion is always the + identity map and ``v`` is returned unchanged regardless of the direction of + the conversion. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm >>> import coordinax.representations as cxr Coordinate basis to physical basis in a 2-D Cartesian chart — identity: >>> v = {"x": u.Q(3.0, "m/s"), "y": u.Q(4.0, "m/s")} >>> at = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m")} - >>> metric2 = cxm.EuclideanMetric(2) - >>> cxr.change_basis(v, cxc.cart2d, metric2, cxr.coord_basis, cxr.phys_basis, at=at) + >>> M2 = cxm.R2 + >>> cxr.change_basis(v, cxc.cart2d, M2, cxr.coord_basis, cxr.phys_basis, at=at) {'x': Q(3., 'm / s'), 'y': Q(4., 'm / s')} The reverse direction is equally a no-op: - >>> cxr.change_basis(v, cxc.cart2d, metric2, cxr.phys_basis, cxr.coord_basis, at=at) + >>> cxr.change_basis(v, cxc.cart2d, M2, cxr.phys_basis, cxr.coord_basis, at=at) {'x': Q(3., 'm / s'), 'y': Q(4., 'm / s')} Works for any Cartesian dimensionality; ``at`` is optional: >>> v3 = {"x": u.Q(1.0, "m/s"), "y": u.Q(2.0, "m/s"), "z": u.Q(3.0, "m/s")} - >>> metric3 = cxm.EuclideanMetric(3) - >>> cxr.change_basis(v3, cxc.cart3d, metric3, cxr.coord_basis, cxr.phys_basis) + >>> cxr.change_basis(v3, cxc.cart3d, cxm.R3, cxr.coord_basis, cxr.phys_basis) {'x': Q(1., 'm / s'), 'y': Q(2., 'm / s'), 'z': Q(3., 'm / s')} """ - # Check the metric is compatible with the chart - assert metric.ndim == chart.ndim # noqa: S101 - # Re-dispatch to metric-less version. - return cxrapi.change_basis(v, chart, from_basis, to_basis, **kw) # ty: ignore[invalid-return-type] + del chart, M, from_basis, to_basis, kw + return v # ---------------------------------------------- @@ -580,7 +525,7 @@ def change_basis( def change_basis( v: CDict, chart: cxc.Spherical3D, - metric: cxm.EuclideanMetric, + M: cxm.EuclideanManifold, from_basis: CoordinateBasis, to_basis: PhysicalBasis, /, @@ -590,26 +535,26 @@ def change_basis( ) -> CDict: r"""Change from coordinate basis to physical basis in a 3-D spherical chart. - Examples - -------- + Delegates to the chart-specific implementation for + `coordinax.manifolds.EuclideanManifold`. + >>> import unxt as u >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm >>> import coordinax.representations as cxr >>> v = {"r": u.Q(5, "m/s"), "theta": u.Q(1, "rad/s"), "phi": u.Q(2, "rad/s")} >>> at = {"r": u.Q(3, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0, "rad")} - >>> cxr.change_basis(v, cxc.sph3d, cxr.coord_basis, cxr.phys_basis, at=at) + >>> cxr.change_basis(v, cxc.sph3d, cxm.R3, cxr.coord_basis, cxr.phys_basis, at=at) {'r': Q(5, 'm / s'), 'theta': Q(3, 'm / s'), 'phi': Q(2.87655323, 'm / s')} >>> v = {"r": 5, "theta": 1, "phi": 2} # unitless >>> at = {"r": 3, "theta": 0.5, "phi": 0} # unitless - >>> cxr.change_basis(v, cxc.sph3d, cxr.coord_basis, cxr.phys_basis, at=at) + >>> cxr.change_basis(v, cxc.sph3d, cxm.R3, cxr.coord_basis, cxr.phys_basis, at=at) {'r': 5, 'theta': 3, 'phi': Array(2.87655323, dtype=float64, ...)} """ - # Check the metric is compatible with the chart - assert metric.ndim == chart.ndim # noqa: S101 - # Re-dispatch to metric-less version. + del M return cxrapi.change_basis(v, chart, from_basis, to_basis, at=at, usys=usys) # ty: ignore[invalid-return-type] @@ -681,7 +626,7 @@ def change_basis( def change_basis( v: CDict, chart: cxc.Spherical3D, - metric: cxm.EuclideanMetric, + M: cxm.EuclideanManifold, from_basis: PhysicalBasis, to_basis: CoordinateBasis, /, @@ -691,29 +636,26 @@ def change_basis( ) -> CDict: r"""Change from physical basis to coordinate basis in a 3-D spherical chart. - Examples - -------- + Delegates to the chart-specific implementation for + :class:`~coordinax.manifolds.EuclideanManifold`. + >>> import unxt as u >>> import coordinax.charts as cxc - >>> import coordinax.representations as cxr >>> import coordinax.manifolds as cxm - - >>> metric = cxm.EuclideanMetric(3) + >>> import coordinax.representations as cxr >>> v = {"r": u.Q(5, "m/s"), "theta": u.Q(3, "m/s"), "phi": u.Q(2.876553, "m/s")} >>> at = {"r": u.Q(3, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0.5, "rad")} - >>> cxr.change_basis(v, cxc.sph3d, metric, cxr.phys_basis, cxr.coord_basis, at=at) + >>> cxr.change_basis(v, cxc.sph3d, cxm.R3, cxr.phys_basis, cxr.coord_basis, at=at) {'r': Q(5, 'm / s'), 'theta': Q(1., 'rad / s'), 'phi': Q(1.99999984, 'rad / s')} >>> v = {"r": 5, "theta": 3, "phi": 2.876553} # unitless >>> at = {"r": 3, "theta": 0.5, "phi": 0.5} # unitless - >>> cxr.change_basis(v, cxc.sph3d, metric, cxr.phys_basis, cxr.coord_basis, at=at) + >>> cxr.change_basis(v, cxc.sph3d, cxm.R3, cxr.phys_basis, cxr.coord_basis, at=at) {'r': 5, 'theta': 1.0, 'phi': Array(1.99999984, dtype=float64, ...)} """ - # Check the metric is compatible with the chart - assert metric.ndim == chart.ndim # noqa: S101 - # Re-dispatch to metric-less version. + del M return cxrapi.change_basis(v, chart, from_basis, to_basis, at=at, usys=usys) # ty: ignore[invalid-return-type] @@ -738,12 +680,10 @@ def change_basis( """Change basis using source and/or target :class:`Representation` objects. This is a convenience overload: the caller may pass full - :class:`Representation` objects for ``from_rep``/``to_rep`` instead of - bare :class:`AbstractBasis` instances. The basis is extracted from each - argument and the appropriate :func:`change_basis` overload is called. + :class:`Representation` objects for ``from_rep``/``to_rep`` instead of bare + :class:`AbstractBasis` instances. The basis is extracted from each argument + and the appropriate :func:`change_basis` overload is called. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.representations as cxr diff --git a/src/coordinax/representations/_src/core.py b/src/coordinax/representations/_src/core.py index 9122c775..1fce11a4 100644 --- a/src/coordinax/representations/_src/core.py +++ b/src/coordinax/representations/_src/core.py @@ -25,8 +25,6 @@ def cmap(*fixed_args: Any, **fixed_kw: Any) -> Any: """Return a partial function for vector conversion. - Examples - -------- Convert a point from Cartesian coordinates to spherical coordinates: >>> import coordinax.representations as cxr @@ -77,8 +75,6 @@ def cmap(*fixed_args: Any, **fixed_kw: Any) -> Any: def cconvert(obj: None, /, *fixed_args: Any, **fixed_kw: Any) -> Any: r"""Return a partial function for vector conversion. - Examples - -------- Convert a point from Cartesian coordinates to spherical coordinates: >>> import coordinax.representations as cxr @@ -121,8 +117,6 @@ def cconvert( ) -> Any: r"""Convert point data between charts. - Examples - -------- Convert a point from Cartesian coordinates to spherical coordinates: >>> import coordinax.representations as cxr @@ -216,8 +210,6 @@ def cconvert( ) -> Any: r"""Convert point data between charts. - Examples - -------- Convert a point from Cartesian coordinates to spherical coordinates: >>> import coordinax.representations as cxr @@ -283,14 +275,11 @@ def cconvert( ) -> Any: r"""Convert point data between charts. - This function delegates to `coordinax.charts.pt_map`. - The representation arguments are checked to ensure they correspond to - canonical point data: + This function delegates to `coordinax.charts.pt_map`. The representation + arguments are checked to ensure they correspond to canonical point data: $$(\mathrm{PointGeometry},\, \mathrm{NoBasis},\, \mathrm{Location}).$$ - Examples - -------- Convert a point from Cartesian coordinates to spherical coordinates: >>> import coordinax.representations as cxr @@ -332,8 +321,6 @@ def cconvert( ) -> Any: r"""Convert tangent data between charts via Jacobian pushforward. - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.charts as cxc >>> import coordinax.representations as cxr @@ -426,8 +413,6 @@ def add( global Cartesian), ``rhs`` is converted into ``lhs_chart`` and added directly. - Examples - -------- >>> import coordinax.representations as cxr >>> import coordinax.charts as cxc >>> import unxt as u @@ -465,8 +450,6 @@ def subtract( global Cartesian), ``rhs`` is converted into ``lhs_chart`` and subtracted directly. - Examples - -------- >>> import coordinax.representations as cxr >>> import coordinax.charts as cxc >>> import unxt as u diff --git a/src/coordinax/representations/_src/guess.py b/src/coordinax/representations/_src/guess.py index 1bfd33cf..e40eb7db 100644 --- a/src/coordinax/representations/_src/guess.py +++ b/src/coordinax/representations/_src/guess.py @@ -38,10 +38,7 @@ def guess_geometry_kind(obj: AbstractGeometry, /) -> AbstractGeometry: """Infer geometry kind from an AbstractGeometry object. - Examples - -------- >>> import coordinax.representations as cxr - >>> geom = cxr.PointGeometry() >>> cxr.guess_geometry_kind(geom) is geom True @@ -74,8 +71,6 @@ def guess_geometry_kind( ) -> AbstractGeometry: """Infer geometry kind from the physical dimension of a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -110,8 +105,6 @@ def guess_geometry_kind( def guess_geometry_kind(obj: u.AbstractQuantity, /) -> AbstractGeometry: """Infer geometry kind from the physical dimension of a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -137,8 +130,6 @@ def guess_geometry_kind(obj: u.AbstractQuantity, /) -> AbstractGeometry: def guess_geometry_kind(obj: CDict, /) -> AbstractGeometry: """Infer geometry kind from the physical dimensions of a component dictionary. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -173,8 +164,6 @@ def guess_geometry_kind(obj: CDict, /) -> AbstractGeometry: def guess_geometry_kind(obj: CDict, chart: cxc.AbstractChart, /) -> AbstractGeometry: """Infer geometry kind from the physical dimensions of a component dictionary and chart. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr >>> import coordinax.charts as cxc @@ -197,8 +186,6 @@ def guess_geometry_kind( ) -> AbstractGeometry: """Infer geometry kind from the physical dimensions of a component dictionary and chart. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr >>> import coordinax.charts as cxc @@ -232,10 +219,7 @@ def guess_geometry_kind( def guess_basis_kind(obj: AbstractBasis, /) -> AbstractBasis: """Infer basis kind from an AbstractBasis object. - Examples - -------- >>> import coordinax.representations as cxr - >>> basis = cxr.NoBasis() >>> cxr.guess_basis_kind(basis) is basis True @@ -260,8 +244,6 @@ def guess_basis_kind( ) -> AbstractBasis: """Infer basis kind from the physical dimension of a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -284,11 +266,8 @@ def guess_basis_kind( def guess_basis_kind(obj: u.AbstractQuantity, /) -> AbstractBasis: """Infer basis kind from the physical dimension of a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr - >>> cxr.guess_basis_kind(u.Q(1.0, "m")) no_basis @@ -302,8 +281,6 @@ def guess_basis_kind(obj: u.AbstractQuantity, /) -> AbstractBasis: def guess_basis_kind(obj: CDict, /) -> AbstractBasis: """Infer basis kind from the physical dimensions of a component dictionary. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -335,10 +312,7 @@ def guess_basis_kind(obj: CDict, /) -> AbstractBasis: def guess_semantic_kind(obj: AbstractSemanticKind, /) -> AbstractSemanticKind: """Infer semantic kind from an AbstractSemanticKind object. - Examples - -------- >>> import coordinax.representations as cxr - >>> sem = cxr.Location() >>> cxr.guess_semantic_kind(sem) is sem True @@ -370,8 +344,6 @@ def guess_semantic_kind( ) -> AbstractSemanticKind: """Infer semantic kind from the physical dimension of a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -406,11 +378,8 @@ def guess_semantic_kind( def guess_semantic_kind(obj: u.AbstractQuantity, /) -> AbstractSemanticKind: """Infer semantic kind from the physical dimension of a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr - >>> cxr.guess_semantic_kind(u.Q(1.0, "m")) loc @@ -424,8 +393,6 @@ def guess_semantic_kind(obj: u.AbstractQuantity, /) -> AbstractSemanticKind: def guess_semantic_kind(obj: CDict, /) -> AbstractSemanticKind: """Infer semantic kind from the physical dimensions of a component dictionary. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -459,11 +426,8 @@ def guess_semantic_kind(obj: CDict, /) -> AbstractSemanticKind: def guess_rep(obj: Representation, /) -> Representation: """Infer representation from a Representation object. - Examples - -------- >>> import coordinax.representations as cxr - - >>> rep = cxr.Representation(cxr.PointGeometry(), cxr.NoBasis(), cxr.Location()) + >>> rep = cxr.point >>> cxr.guess_rep(rep) is rep True @@ -475,10 +439,7 @@ def guess_rep(obj: Representation, /) -> Representation: def guess_rep(obj: PointGeometry, /) -> Representation: """Infer representation from a PointGeometry object. - Examples - -------- >>> import coordinax.representations as cxr - >>> geom = cxr.PointGeometry() >>> cxr.guess_rep(geom) point @@ -498,11 +459,8 @@ def guess_rep( ) -> Representation: """Infer point representation from data and an already-inferred PointGeometry. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr - >>> rep = cxr.guess_rep(u.dimension("length"), cxr.point_geom) >>> rep point @@ -522,8 +480,6 @@ def guess_rep( ) -> Representation: """Infer tangent representation from data and an already-inferred TangentGeometry. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -554,8 +510,6 @@ def guess_rep( ) -> Representation: """Infer representation from the physical dimension of a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -595,8 +549,6 @@ def guess_rep( def guess_rep(obj: Any, chart: cxc.AbstractChart, /) -> Representation: """Infer representation from the physical dimension of a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -620,8 +572,6 @@ def guess_rep( ) -> Representation: """Infer representation. - Examples - -------- >>> import unxt as u >>> import coordinax.representations as cxr @@ -639,8 +589,6 @@ def guess_rep( ) -> Representation: """Infer representation. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.representations as cxr diff --git a/src/coordinax/representations/_src/register_cx.py b/src/coordinax/representations/_src/register_cx.py index 6289a11f..4da30db8 100644 --- a/src/coordinax/representations/_src/register_cx.py +++ b/src/coordinax/representations/_src/register_cx.py @@ -29,8 +29,6 @@ def pt_map( ) -> Any: r"""Convert point data between charts. - Examples - -------- Convert a point from Cartesian coordinates to spherical coordinates: >>> import coordinax.representations as cxr @@ -119,8 +117,6 @@ def pt_map( ) -> Any: r"""Convert point data between charts. - Examples - -------- Convert a point from Cartesian coordinates to spherical coordinates: >>> import coordinax.representations as cxr @@ -193,8 +189,6 @@ def pt_map( ) -> Any: r"""Convert point data between charts. - Examples - -------- Convert a point from Cartesian coordinates to spherical coordinates: >>> import coordinax.representations as cxr diff --git a/src/coordinax/representations/_src/tangent_map.py b/src/coordinax/representations/_src/tangent_map.py index 39bfd919..f9a6311b 100644 --- a/src/coordinax/representations/_src/tangent_map.py +++ b/src/coordinax/representations/_src/tangent_map.py @@ -17,7 +17,7 @@ from .custom_types import CDict, OptUSys from .geom import TangentGeometry from .rep import Representation -from coordinax.internal import QuantityMatrix, pack_nonuniform_unit +from coordinax.internal import QMatrix, pack_nonuniform_unit # --------------------------------------------------------------------------- # Validation helpers @@ -40,28 +40,28 @@ def _check_linear_basis(rep: Representation, label: str) -> None: # --------------------------------------------------------------------------- -# Shared helper: apply a QuantityMatrix Jacobian to a tangent vector CDict +# Shared helper: apply a QMatrix Jacobian to a tangent vector CDict # --------------------------------------------------------------------------- def _apply_jac( - J: Array | QuantityMatrix, + J: Array | QMatrix, from_components: tuple[str, ...], to_components: tuple[str, ...], v: CDict, ) -> CDict: - """Apply a 2-D QuantityMatrix Jacobian to a tangent CDict. + """Apply a 2-D QMatrix Jacobian to a tangent CDict. If the components of ``v`` are plain arrays, the output is a plain-array CDict (using ``J.value @ v_arr``). If any component of ``v`` is a {class}`~unxt.AbstractQuantity`, ``v`` is packed into a 1-D - {class}`~coordinax.internal.QuantityMatrix` and the result is computed + {class}`~coordinax.internal.QMatrix` and the result is computed via ``qnp.matmul(J, v_qm)``, which handles per-element unit conversion. Parameters ---------- J - QuantityMatrix of shape ``(n_out, n_in)`` returned by ``jac_pt_map``. + QMatrix of shape ``(n_out, n_in)`` returned by ``jac_pt_map``. from_components Ordered component names for the input chart (columns of J). to_components @@ -78,15 +78,15 @@ def _apply_jac( """ if isinstance(v[from_components[0]], u.AbstractQuantity): v_arr, v_units = pack_nonuniform_unit(v, keys=from_components) - v_qm = QuantityMatrix(v_arr, unit=v_units) - w = qnp.matmul(J, v_qm) # (n_out,) QuantityMatrix + v_qm = QMatrix(v_arr, unit=v_units) + w = qnp.matmul(J, v_qm) # (n_out,) QMatrix return {key: u.Q(w.value[i], w.unit[i]) for i, key in enumerate(to_components)} v_arr = jnp.stack([jnp.asarray(v[k]) for k in from_components]) - # When J is a QuantityMatrix, use J.value to avoid the Quax fallback path - # that returns a QuantityMatrix with J's own 2D unit structure (wrong). + # When J is a QMatrix, use J.value to avoid the Quax fallback path + # that returns a QMatrix with J's own 2D unit structure (wrong). # Plain-array velocity is dimensionless, so numeric-only application is correct. - j_arr = J.value if isinstance(J, QuantityMatrix) else J + j_arr = J.value if isinstance(J, QMatrix) else J result = j_arr @ v_arr return {key: result[i] for i, key in enumerate(to_components)} @@ -112,8 +112,6 @@ def tangent_map( Applies the Jacobian of the chart transition map to the tangent vector components ``v``, evaluated at the base point ``at``. - Examples - -------- Convert a tangent vector from Cartesian to polar 2D at the point (1, 0): >>> import jax.numpy as jnp @@ -154,8 +152,6 @@ def tangent_map( 2. apply the chart Jacobian pushforward, 3. convert target components back to physical basis. - Examples - -------- Convert a physical-basis tangent vector from Cartesian to spherical 3D: >>> import unxt as u @@ -215,8 +211,6 @@ def tangent_map( Applies the Jacobian of the chart transition map to the tangent vector components ``v``, evaluated at the base point ``at``. - Examples - -------- Convert a tangent vector from Cartesian to polar 2D at the point (1, 0): >>> import jax.numpy as jnp @@ -256,8 +250,6 @@ def tangent_map( Applies the Jacobian of the chart transition map to the tangent vector components ``v``, evaluated at the base point ``at``. - Examples - -------- Convert a tangent vector from Cartesian to polar 2D at the point (1, 0): >>> import jax.numpy as jnp diff --git a/src/coordinax/transforms/_src/actions/add.py b/src/coordinax/transforms/_src/actions/add.py index 9595d557..6524453c 100644 --- a/src/coordinax/transforms/_src/actions/add.py +++ b/src/coordinax/transforms/_src/actions/add.py @@ -160,8 +160,6 @@ def __str__(self) -> str: def from_(cls: type[AbstractAdd], obj: AbstractAdd, /) -> AbstractAdd: """Construct a AbstractAdd from another AbstractAdd. - Examples - -------- >>> import coordinax.main as cx >>> shift1 = cxfm.Translate.from_([1, 2, 3], "km") >>> cxfm.Translate.from_(shift1) is shift1 @@ -177,11 +175,8 @@ def from_(cls: type[AbstractAdd], obj: AbstractAdd, /) -> AbstractAdd: def from_(cls: type[AbstractAdd], q: u.AbstractQuantity, /) -> AbstractAdd: """Construct an AbstractAdd subclass from a Quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.transforms as cxfm - >>> cxfm.Translate.from_(u.Q([1, 2, 3], "km")) Translate( {'x': Q(1, 'km'), 'y': Q(2, 'km'), 'z': Q(3, 'km')}, chart=Cart3D(M=Rn(3)) @@ -197,10 +192,7 @@ def from_(cls: type[AbstractAdd], q: u.AbstractQuantity, /) -> AbstractAdd: def from_(cls: type[AbstractAdd], x: ArrayLike, unit: str) -> AbstractAdd: """Construct an Add operator from an array-like offset and unit. - Examples - -------- >>> import coordinax.transforms as cxfm - >>> cxfm.Translate.from_([1, 2, 3], "km") Translate( {'x': Q(1, 'km'), 'y': Q(2, 'km'), 'z': Q(3, 'km')}, chart=Cart3D(M=Rn(3)) @@ -220,8 +212,6 @@ def simplify(op: AbstractAdd, /, **kw: Any) -> AbstractAdd | Identity: A translation with zero delta simplifies to Identity. - Examples - -------- >>> import coordinax.transforms as cxfm >>> op = cxfm.Translate.from_([1, 2, 3], "km") diff --git a/src/coordinax/transforms/_src/actions/base.py b/src/coordinax/transforms/_src/actions/base.py index f9eeb738..10db94fa 100644 --- a/src/coordinax/transforms/_src/actions/base.py +++ b/src/coordinax/transforms/_src/actions/base.py @@ -240,8 +240,6 @@ def from_( def from_(cls: type[AbstractTransform], obj: Mapping[str, Any], /) -> AbstractTransform: """Construct from a mapping. - Examples - -------- >>> import coordinax.transforms as cxfm >>> cxfm.Composed.from_({"transforms": (cxfm.Identity(), cxfm.Identity())}) Composed((Identity(), Identity())) @@ -259,10 +257,7 @@ def from_( ) -> AbstractTransform: """Construct from a Quantity's value and unit. - Examples - -------- >>> import coordinax.transforms as cxfm - >>> op = cxfm.Translate.from_([1, 1, 1], "km") >>> print(op) Translate( diff --git a/src/coordinax/transforms/_src/actions/composed.py b/src/coordinax/transforms/_src/actions/composed.py index c605dd33..ddb8987d 100644 --- a/src/coordinax/transforms/_src/actions/composed.py +++ b/src/coordinax/transforms/_src/actions/composed.py @@ -28,8 +28,6 @@ def convert_to_transforms_tuple(inp: Any, /) -> tuple[AbstractTransform, ...]: """Convert to a tuple of transforms for `Pipe`. - Examples - -------- >>> import coordinax.transforms as cxfm >>> op1 = cxfm.Rotate([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) @@ -202,8 +200,6 @@ def __pdoc__(self, **kw: Any) -> wl.AbstractDoc: def compose(*transforms: AbstractTransform) -> Composed: """Compose multiple transforms into a `Composed` transform. - Examples - -------- >>> import coordinax.transforms as cxfm >>> shift = cxfm.Translate.from_([1, 2, 3], "km") @@ -239,8 +235,6 @@ def act( ) -> Array: """Apply Composed to an ArrayLike by sequentially applying each transform. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc @@ -274,8 +268,6 @@ def act( ) -> CDict: """Apply Composed to a CDict by sequentially applying each transform. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.transforms as cxfm @@ -321,8 +313,6 @@ def act( ) -> u.AbstractQuantity: """Apply Composed to a Quantity by sequentially applying each transform. - Examples - -------- >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.transforms as cxfm @@ -347,8 +337,6 @@ def act( ) -> u.AbstractQuantity: """Apply Composed to a Quantity by sequentially applying each transform. - Examples - -------- >>> import unxt as u >>> import coordinax.transforms as cxfm @@ -373,8 +361,6 @@ def act( def simplify(op: Composed, /) -> AbstractTransform: """Simplify a Composed transform. - Examples - -------- >>> import coordinax.transforms as cxfm >>> shift = cxfm.Translate.from_([1, 2, 3], "km") diff --git a/src/coordinax/transforms/_src/actions/identity.py b/src/coordinax/transforms/_src/actions/identity.py index 1db2c98a..00a37bc1 100644 --- a/src/coordinax/transforms/_src/actions/identity.py +++ b/src/coordinax/transforms/_src/actions/identity.py @@ -108,16 +108,13 @@ def inverse(self) -> "Identity": def simplify(op: Identity, /, **__: Any) -> Identity: """Simplify a {class}`coordinax.transforms.Identity` operator. - Examples - -------- - The {class}`coordinax.transforms.Identity` operator is the simplest operator and + The `coordinax.transforms.Identity` operator is the simplest operator and cannot be simplified further: >>> import coordinax.transforms as cxfm - >>> op = cxfm.Identity() - >>> simplified = cxfm.simplify(op) - >>> simplified == op + >>> op = cxfm.identity + >>> cxfm.simplify(op) is op True """ @@ -133,13 +130,11 @@ def simplify(op: Identity, /, **__: Any) -> Identity: def act(op: Identity, tau: Any, x: Any, /, *args: Any, **kw: Any) -> Any: """Identity operator - returns input unchanged. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.transforms as cxfm - >>> op = cxfm.Identity() + >>> op = cxfm.identity >>> q = [1, 2, 3] >>> cxfm.act(op, None, q) is q diff --git a/src/coordinax/transforms/_src/actions/reflect.py b/src/coordinax/transforms/_src/actions/reflect.py index 749d315b..a365d078 100644 --- a/src/coordinax/transforms/_src/actions/reflect.py +++ b/src/coordinax/transforms/_src/actions/reflect.py @@ -43,18 +43,12 @@ class Reflect(AbstractTransform): r"""Operator for Euclidean hyperplane reflections. - A reflection across the hyperplane orthogonal to a nonzero normal vector - $ - n - $ acts on Cartesian coordinates by the Householder matrix - - $$ - H_n = I - 2\hat{n}\hat{n}^T, - $$ - - where $ - \hat{n} = n / \lVert n \rVert - $. + A reflection across the hyperplane orthogonal to a nonzero normal vector $n$ + acts on Cartesian coordinates by the Householder matrix + + $$ H_n = I - 2\hat{n}\hat{n}^T, $$ + + where $ \hat{n} = n / \lVert n \rVert $. Examples -------- diff --git a/src/coordinax/transforms/_src/actions/register_apply.py b/src/coordinax/transforms/_src/actions/register_apply.py index 9331c064..24a176b6 100644 --- a/src/coordinax/transforms/_src/actions/register_apply.py +++ b/src/coordinax/transforms/_src/actions/register_apply.py @@ -16,7 +16,7 @@ import coordinax.representations as cxr from .base import AbstractTransform from .custom_types import CDict -from coordinax.internal import QuantityMatrix, pack_nonuniform_unit, pack_uniform_unit +from coordinax.internal import QMatrix, pack_nonuniform_unit, pack_uniform_unit _MSG_CHARTS_MATCH: Final = ( "inferred chart guess_chart(x)={0.__class__.__name__} " @@ -77,8 +77,6 @@ def act( with a Cartesian chart (e.g. `coordinax.charts.Cartesian3D`) and `coordinax.representations.PointGeometry` geometry. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.main as cx @@ -120,8 +118,6 @@ def act( with a Cartesian chart (e.g. `coordinax.charts.Cartesian3D`) and `coordinax.representations.PointGeometry` geometry. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.main as cx @@ -150,8 +146,6 @@ def act(op: AbstractTransform, tau: Any, x: AbcQ, /, **kw: Any) -> AbcQ: `coordinax.charts.Cartesian3D`) and `coordinax.representations.PointGeometry` geometry. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.main as cx @@ -180,8 +174,6 @@ def act( ) -> AbcQ: """Apply operator, routing through dictionary-based implementation. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.transforms as cxfm @@ -218,8 +210,6 @@ def act( ) -> AbcQ: """Apply operator, routing through dictionary-based implementation. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc @@ -246,9 +236,9 @@ def act( # =================================================================== -# On QuantityMatrix inputs +# On QMatrix inputs # -# Precedence=2 on all QuantityMatrix dispatches so they are preferred over +# Precedence=2 on all QMatrix dispatches so they are preferred over # the (SpecificTransform, AbstractQuantity) dispatches in rotate.py, # translate.py, and composed.py (precedence=0) AND over the Identity # catch-all (precedence=1). Without this, plum sees e.g. @@ -257,23 +247,19 @@ def act( @plum.dispatch(precedence=2) # ty: ignore[no-matching-overload] -def act( - op: AbstractTransform, tau: Any, x: QuantityMatrix, /, **kw: Any -) -> QuantityMatrix: - """Apply an operator to a ``QuantityMatrix``. +def act(op: AbstractTransform, tau: Any, x: QMatrix, /, **kw: Any) -> QMatrix: + """Apply an operator to a ``QMatrix``. The chart is inferred from the matrix size and the representation defaults to ``point``. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.transforms as cxfm - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix >>> op = cxfm.Rotate.from_euler("z", u.Q(90, "deg")) - >>> qm = QuantityMatrix( + >>> qm = QMatrix( ... jnp.array([1.0, 0.0, 0.0]), ... unit=(u.unit("km"), u.unit("km"), u.unit("km")), ... ) @@ -284,29 +270,27 @@ def act( """ chart = cxc.guess_chart(x) out = cxfmapi.act(op, tau, x, chart, cxr.point, **kw) - return cast("QuantityMatrix", out) + return cast("QMatrix", out) @plum.dispatch(precedence=2) # ty: ignore[no-matching-overload] def act( op: AbstractTransform, tau: Any, - x: QuantityMatrix, + x: QMatrix, chart: cxc.AbstractChart, /, **kw: Any, -) -> QuantityMatrix: - """Apply an operator to a ``QuantityMatrix`` with explicit chart. +) -> QMatrix: + """Apply an operator to a ``QMatrix`` with explicit chart. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.transforms as cxfm - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix - >>> qm = QuantityMatrix( + >>> qm = QMatrix( ... jnp.array([1.0, 0.0, 0.0]), ... unit=("km", "km", "km"), ... ) @@ -323,35 +307,33 @@ def act( """ out = cxfmapi.act(op, tau, x, chart, cxr.point, **kw) - return cast("QuantityMatrix", out) + return cast("QMatrix", out) @plum.dispatch(precedence=2) # ty: ignore[no-matching-overload] def act( op: AbstractTransform, tau: Any, - x: QuantityMatrix, + x: QMatrix, chart: cxc.AbstractChart, rep: cxr.Representation, /, **kw: Any, -) -> QuantityMatrix: - """Apply an operator to a ``QuantityMatrix`` with explicit chart and rep. +) -> QMatrix: + """Apply an operator to a ``QMatrix`` with explicit chart and rep. Routes through the CDict-based implementation, then repacks the result - into a ``QuantityMatrix``. + into a ``QMatrix``. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.charts as cxc >>> import coordinax.transforms as cxfm >>> import coordinax.representations as cxr - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix >>> op = cxfm.Rotate.from_euler("z", u.Q(90, "deg")) - >>> qm = QuantityMatrix( + >>> qm = QMatrix( ... jnp.array([1.0, 0.0, 0.0]), ... unit=("km", "km", "km"), ... ) @@ -360,13 +342,13 @@ def act( Array([0., 1., 0.], dtype=float64) """ - # Convert QuantityMatrix → CDict + # Convert QMatrix → CDict v = cxc.cdict(x, chart) # Act on the CDict nv = cxfmapi.act(op, tau, v, chart, rep, **kw) - # Repack CDict → QuantityMatrix + # Repack CDict → QMatrix arr, units = pack_nonuniform_unit(nv, keys=chart.components) - return QuantityMatrix(arr, unit=units) + return QMatrix(arr, unit=units) # =================================================================== @@ -377,8 +359,6 @@ def act( def act(op: AbstractTransform, tau: Any, x: CDict, /, **kw: Any) -> CDict: """Apply operator to a CDict representation of a vector. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.transforms as cxfm @@ -404,8 +384,6 @@ def act( ) -> CDict: """Apply operator to a CDict representation of a vector. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.transforms as cxfm diff --git a/src/coordinax/transforms/_src/actions/rotate.py b/src/coordinax/transforms/_src/actions/rotate.py index 2313a215..c6d8a9ae 100644 --- a/src/coordinax/transforms/_src/actions/rotate.py +++ b/src/coordinax/transforms/_src/actions/rotate.py @@ -323,11 +323,8 @@ def __matmul__(self: "Rotate", other: Any, /) -> Any: def from_(cls: type[Rotate], obj: Rotate, /) -> Rotate: """Construct a Rotate from another Rotate. - Examples - -------- >>> import quaxed.numpy as jnp >>> import coordinax.transforms as cxfm - >>> R = cxfm.Rotate(jnp.eye(3)) >>> cxfm.Rotate.from_(R) is R True @@ -340,11 +337,9 @@ def from_(cls: type[Rotate], obj: Rotate, /) -> Rotate: def from_(cls: type[Rotate], obj: Callable[..., Any], /) -> Rotate: """Construct a Rotate from a callable. - The callable must have a return type annotation with shape ending in NxN - (a square matrix). + The callable must have a return type annotation with shape ending in NxN (a + square matrix). - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.transforms as cxfm >>> from jaxtyping import Array, Real @@ -394,12 +389,9 @@ def from_(cls: type[Rotate], obj: Callable[..., Any], /) -> Rotate: def from_(cls: type[Rotate], obj: AbcQ, /) -> Rotate: """Construct a Rotate from a Quantity. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.transforms as cxfm - >>> cxfm.Rotate.from_(u.Q(jnp.eye(3), "")) Rotate(f64[3,3](jax)) @@ -411,11 +403,8 @@ def from_(cls: type[Rotate], obj: AbcQ, /) -> Rotate: def from_(cls: type[Rotate], obj: ArrayLike, /) -> Rotate: """Construct a Rotate from an Array. - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.transforms as cxfm - >>> cxfm.Rotate.from_(jnp.eye(3)) Rotate(f64[3,3](jax)) @@ -427,8 +416,6 @@ def from_(cls: type[Rotate], obj: ArrayLike, /) -> Rotate: def from_(cls: type[Rotate], obj: jtransform.Rotation, /) -> Rotate: """Initialize from a `jax.scipy.spatial.transform.Rotation`. - Examples - -------- >>> import jax.numpy as jnp >>> from jax.scipy.spatial.transform import Rotation >>> import coordinax.main as cx @@ -451,8 +438,6 @@ def from_(cls: type[Rotate], obj: jtransform.Rotation, /) -> Rotate: def simplify(op: Rotate, /, **kw: Any) -> AbstractTransform: """Simplify the Galilean rotation operator. - Examples - -------- >>> import quaxed.numpy as jnp >>> import coordinax.main as cx diff --git a/src/coordinax/transforms/_src/actions/translate.py b/src/coordinax/transforms/_src/actions/translate.py index 65f30fad..45a0df62 100644 --- a/src/coordinax/transforms/_src/actions/translate.py +++ b/src/coordinax/transforms/_src/actions/translate.py @@ -139,8 +139,6 @@ def act( The array is interpreted as Cartesian coordinates. The delta is converted to the same unit system to perform the addition. - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.transforms as cxfm @@ -210,8 +208,6 @@ def act( The array is interpreted as Cartesian coordinates. The delta is converted to the same unit system to perform the addition. - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.transforms as cxfm >>> import coordinax.representations as cxr @@ -266,8 +262,6 @@ def act( Dispatches on ``op.semantic_kind`` to determine which representations are shifted. - Examples - -------- >>> import coordinax.transforms as cxfm >>> import unxt as u diff --git a/src/coordinax/vectors/_src/bundle.py b/src/coordinax/vectors/_src/bundle.py index 46a6c8ff..1e8e0af9 100644 --- a/src/coordinax/vectors/_src/bundle.py +++ b/src/coordinax/vectors/_src/bundle.py @@ -551,8 +551,6 @@ def __str__(self) -> str: def from_(cls: type[Coordinate], pv: Coordinate, /) -> Coordinate: """Identity: return the same Coordinate unchanged. - Examples - -------- >>> import coordinax.main as cx >>> pv = cx.Coordinate(point=cx.Point.from_([1.0, 2.0, 3.0], "m")) >>> cx.Coordinate.from_(pv) is pv @@ -566,8 +564,6 @@ def from_(cls: type[Coordinate], pv: Coordinate, /) -> Coordinate: def from_(cls: type[Coordinate], p: Point, /) -> Coordinate: """Wrap a single ``Point`` as a point-only bundle (no field vectors). - Examples - -------- >>> import coordinax.main as cx >>> p = cx.Point.from_([1.0, 2.0, 3.0], "m") >>> pv = cx.Coordinate.from_(p) @@ -591,8 +587,6 @@ def from_( The mapping may contain a ``"point"`` key for the base; the explicit ``point`` keyword argument takes precedence if both are supplied. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx >>> import coordinax.charts as cxc diff --git a/src/coordinax/vectors/_src/point.py b/src/coordinax/vectors/_src/point.py index 278fc6f4..e27263bf 100644 --- a/src/coordinax/vectors/_src/point.py +++ b/src/coordinax/vectors/_src/point.py @@ -221,8 +221,6 @@ def __pdoc__(self, *, vector_form: bool = False, **kw: Any) -> wl.AbstractDoc: def from_(cls: type[Point], obj: Point, /) -> Point: """Construct a point from another point. - Examples - -------- >>> import coordinax.main as cx >>> vec1 = cx.Point.from_([1, 2, 3], "m") >>> vec2 = cx.Point.from_(vec1) @@ -246,8 +244,6 @@ def from_( ) -> Point: """Construct a vector from an object, and chart and rep info. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.main as cx @@ -346,8 +342,6 @@ def from_(cls: type[Point], obj: Any, /) -> Any: Note that this is a pretty limited constructor since it often lacks the necessary information to do a proper construction. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.main as cx @@ -444,8 +438,6 @@ def from_( ) -> Any: """Construct a vector from an array, unit, chart, and rep. - Examples - -------- >>> import jax.numpy as jnp >>> import coordinax.main as cx @@ -478,8 +470,6 @@ def from_( ) -> Point: """Construct a point from another point, replacing its frame. - Examples - -------- >>> import coordinax.main as cx >>> import coordinax.frames as cxf @@ -506,8 +496,6 @@ def from_( ) -> Point: """Construct a point from any object with a frame. - Examples - -------- >>> import coordinax.main as cx >>> import coordinax.frames as cxf >>> import unxt as u @@ -534,8 +522,6 @@ def from_( ) -> Point: """Construct a point from an object, chart, and frame. - Examples - -------- >>> import coordinax.main as cx >>> import coordinax.frames as cxf >>> import coordinax.charts as cxc @@ -564,8 +550,6 @@ def from_( ) -> Point: """Construct a point from an object, chart, representation, and frame. - Examples - -------- >>> import coordinax.main as cx >>> import coordinax.frames as cxf >>> import coordinax.charts as cxc @@ -596,8 +580,6 @@ def from_( ) -> Point: """Construct a point from an array, unit, and frame. - Examples - -------- >>> import coordinax.main as cx >>> import coordinax.frames as cxf diff --git a/src/coordinax/vectors/_src/register_cx.py b/src/coordinax/vectors/_src/register_cx.py index a0959264..9fdefa28 100644 --- a/src/coordinax/vectors/_src/register_cx.py +++ b/src/coordinax/vectors/_src/register_cx.py @@ -35,8 +35,6 @@ def cconvert( ) -> Point: """Convert a point from one chart to another. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx @@ -152,8 +150,6 @@ def pt_map( def cdict(obj: Point, /) -> CDict: """Extract component dictionary from a Point. - Examples - -------- >>> import coordinax.main as cx >>> import unxt as u >>> vec = cx.Point.from_(u.Q([1, 2, 3], "m")) @@ -169,14 +165,12 @@ def cdict(obj: Point, /) -> CDict: def cdict(obj: Tangent, /) -> CDict: """Extract component dictionary from a Tangent. - Examples - -------- >>> import coordinax.main as cx >>> import coordinax.representations as cxr + >>> import coordinax.charts as cxc >>> import unxt as u - >>> vec = cx.Tangent.from_( - ... u.Q([1.0, 2.0, 3.0], "m/s"), "m/s", cx.cart3d, cxr.coord_basis, cxr.vel - ... ) + >>> d = {"x": u.Q(1.0, "m/s"), "y": u.Q(2.0, "m/s"), "z": u.Q(3.0, "m/s")} + >>> vec = cx.Tangent.from_(d, cxc.cart3d, cxr.coord_vel) >>> d = cx.cdict(vec) >>> list(d.keys()) ['x', 'y', 'z'] @@ -204,8 +198,6 @@ def cconvert( (Jacobian pushforward) is evaluated. It may be a `Point` instance (whose ``.data`` is used) or a raw ``CDict``. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx >>> import coordinax.charts as cxc @@ -259,8 +251,6 @@ def cconvert( ) -> Tangent: """Convert a tangent Tangent from one chart to another (explicit from-chart). - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx >>> import coordinax.charts as cxc @@ -296,16 +286,14 @@ def change_basis( ) -> Tangent: """Change the basis of a `Tangent` vector. - Converts the component data from the current basis to ``to_basis`` using - the registered ``change_basis`` overload for dicts, then returns a new - `Tangent` with the updated data and basis. + Converts the component data from the current basis to ``to_basis`` using the + registered ``change_basis`` overload for dicts, then returns a new `Tangent` + with the updated data and basis. - The ``at`` parameter provides the base point at which the scale factors - are evaluated. It may be a `Point` instance (whose ``.data`` is used) or - a raw ``CDict``. + The ``at`` parameter provides the base point at which the scale factors are + evaluated. It may be a `Point` instance (whose ``.data`` is used) or a raw + ``CDict``. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx >>> import coordinax.charts as cxc @@ -364,11 +352,9 @@ def change_basis( `Tangent` carries the same chart and frame as the input `Point`, and its basis is ``to_basis``. - The ``at`` and ``usys`` parameters are accepted for API consistency but - are not used. + The ``at`` and ``usys`` parameters are accepted for API consistency but are + not used. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx >>> import coordinax.representations as cxr @@ -396,16 +382,13 @@ def change_basis( def add(lhs: Point, rhs: Point, /) -> Point: """Add two points. - For non-Cartesian charts the operation converts both operands to the - ambient Cartesian chart, adds there, and converts the result back - to the ``lhs`` chart. For Cartesian charts the addition is direct. + For non-Cartesian charts the operation converts both operands to the ambient + Cartesian chart, adds there, and converts the result back to the ``lhs`` + chart. For Cartesian charts the addition is direct. The result keeps the ``lhs`` chart and representation. - Examples - -------- >>> import coordinax.main as cx - >>> v1 = cx.Point.from_([1, 2, 3], "m") >>> v2 = cx.Point.from_([4, 5, 6], "m") >>> print(cxr.add(v1, v2)) @@ -428,16 +411,13 @@ def add(lhs: Point, rhs: Point, /) -> Point: def subtract(lhs: Point, rhs: Point, /) -> Point: """Subtract two vectors. - For non-Cartesian charts the operation converts both operands to the - ambient Cartesian chart, subtracts there, and converts the result back - to the ``lhs`` chart. For Cartesian charts the subtraction is direct. + For non-Cartesian charts the operation converts both operands to the ambient + Cartesian chart, subtracts there, and converts the result back to the + ``lhs`` chart. For Cartesian charts the subtraction is direct. The result keeps the ``lhs`` chart and representation. - Examples - -------- >>> import coordinax.main as cx - >>> v1 = cx.Point.from_([4, 5, 6], "m") >>> v2 = cx.Point.from_([1, 2, 3], "m") >>> print(cxr.subtract(v1, v2)) @@ -462,12 +442,10 @@ def subtract(lhs: Point, rhs: Point, /) -> Point: def add(lhs: Tangent, rhs: Tangent, /) -> Tangent: """Add two tangent vectors component-wise. - Tangent spaces are genuine vector spaces: addition is component-wise in - any chart basis (no Cartesian round-trip is needed or correct). Both - operands must share the same representation (chart + basis + semantic). + Tangent spaces are genuine vector spaces: addition is component-wise in any + chart basis (no Cartesian round-trip is needed or correct). Both operands + must share the same representation (chart + basis + semantic). - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx >>> import coordinax.charts as cxc @@ -505,8 +483,6 @@ def subtract(lhs: Tangent, rhs: Tangent, /) -> Tangent: any chart basis (no Cartesian round-trip is needed or correct). Both operands must share the same representation (chart + basis + semantic). - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx >>> import coordinax.charts as cxc @@ -544,8 +520,6 @@ def subtract(lhs: Tangent, rhs: Tangent, /) -> Tangent: def act(op: cxfm.AbstractTransform, tau: Any, x: Tangent, /, **kw: Any) -> Tangent: """Act a frame transform on a tangent Tangent. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.main as cx @@ -571,8 +545,6 @@ def act(op: cxfm.AbstractTransform, tau: Any, x: Tangent, /, **kw: Any) -> Tange def act(op: cxfm.AbstractTransform, tau: Any, x: Point, /, **kw: Any) -> Point: """Act a frame transform on a Point. - Examples - -------- >>> import jax.numpy as jnp >>> import unxt as u >>> import coordinax.main as cx @@ -601,8 +573,6 @@ def act( ) -> Coordinate: """Act a frame transform on a Coordinate (point + all fibres). - Examples - -------- >>> import coordinax.main as cx >>> import coordinax.frames as cxf >>> import coordinax.charts as cxc @@ -649,14 +619,13 @@ def act( def _replace_coordinate(coord: Coordinate, /, **kwargs: Any) -> Coordinate: """Replace fields on a Coordinate. - Supports replacing ``point``, any named fibre field, or ``frame``. - When ``frame`` is supplied, it is forwarded to both the base ``point`` and - every fibre field so the bundle stays internally consistent. + Supports replacing ``point``, any named fibre field, or ``frame``. When + ``frame`` is supplied, it is forwarded to both the base ``point`` and every + fibre field so the bundle stays internally consistent. Unknown keys are rejected with a ``TypeError``. - Examples - -------- + >>> import coordinax.main as cx >>> import coordinax.frames as cxf @@ -713,8 +682,6 @@ def cconvert( Delegates to {meth}`Coordinate.cconvert`. - Examples - -------- >>> import coordinax.main as cx >>> import coordinax.charts as cxc diff --git a/src/coordinax/vectors/_src/register_dataclassish.py b/src/coordinax/vectors/_src/register_dataclassish.py index f91556fc..479f0870 100644 --- a/src/coordinax/vectors/_src/register_dataclassish.py +++ b/src/coordinax/vectors/_src/register_dataclassish.py @@ -16,8 +16,6 @@ def replace(obj: Point, /, **kwargs: Any) -> Point: """Replace fields of a point. - Examples - -------- >>> import dataclassish >>> import unxt as u >>> import coordinax.main as cx diff --git a/src/coordinax/vectors/_src/register_manifolds.py b/src/coordinax/vectors/_src/register_manifolds.py index c4cd35f8..0cb7891a 100644 --- a/src/coordinax/vectors/_src/register_manifolds.py +++ b/src/coordinax/vectors/_src/register_manifolds.py @@ -30,8 +30,7 @@ def pt_project( >>> q = cx.Point.from_( ... {"r": u.Q(1, "m"), "theta": u.Q(2, "rad"), "phi": u.Q(3, "rad")}, ... cx.sph3d) - >>> M = cxm.HyperSphericalManifold(2) - >>> cxm.pt_project(q, M) + >>> cxm.pt_project(q, cxm.S2) Point({'theta': Q(2, 'rad'), 'phi': Q(3, 'rad')}, chart=SphericalTwoSphere(M=Sn(2))) """ diff --git a/src/coordinax/vectors/_src/register_quax.py b/src/coordinax/vectors/_src/register_quax.py index 0dc1738c..4dda7521 100644 --- a/src/coordinax/vectors/_src/register_quax.py +++ b/src/coordinax/vectors/_src/register_quax.py @@ -129,12 +129,10 @@ def neg_p_absvec(operand: Point, /) -> Point: def add_p_absvecs(lhs: Point, rhs: Point, /, **kw: Any) -> Point: r"""Element-wise addition of two points. - For non-Cartesian charts the operation converts both operands to the - ambient Cartesian chart, adds there, and converts the result back - to the ``lhs`` chart. For Cartesian charts the addition is direct. + For non-Cartesian charts the operation converts both operands to the ambient + Cartesian chart, adds there, and converts the result back to the ``lhs`` + chart. For Cartesian charts the addition is direct. - Examples - -------- >>> import quaxed.numpy as jnp >>> import coordinax.main as cx @@ -152,14 +150,12 @@ def add_p_absvecs(lhs: Point, rhs: Point, /, **kw: Any) -> Point: def sub_p_absvecs(lhs: Point, rhs: Point, /, **kw: Any) -> Point: r"""Element-wise subtraction of two points. - For non-Cartesian charts the operation converts both operands to the - ambient Cartesian chart, subtracts there, and converts the result back - to the ``lhs`` chart. For Cartesian charts the subtraction is direct. + For non-Cartesian charts the operation converts both operands to the ambient + Cartesian chart, subtracts there, and converts the result back to the + ``lhs`` chart. For Cartesian charts the subtraction is direct. The result keeps the ``lhs`` chart and representation. - Examples - -------- >>> import quaxed.numpy as jnp >>> import coordinax.main as cx diff --git a/src/coordinax/vectors/_src/register_unxt.py b/src/coordinax/vectors/_src/register_unxt.py index b10730e8..71979f89 100644 --- a/src/coordinax/vectors/_src/register_unxt.py +++ b/src/coordinax/vectors/_src/register_unxt.py @@ -13,7 +13,7 @@ import unxt as u from .point import Point -from coordinax.internal import QuantityMatrix, pack_nonuniform_unit +from coordinax.internal import QMatrix, pack_nonuniform_unit @final @@ -45,8 +45,6 @@ class ToUnitsOptions(Enum): def uconvert(usys: u.AbstractUnitSystem, vec: Point, /) -> Point: """Convert the point to the given units. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx @@ -68,8 +66,6 @@ def uconvert( ) -> Point: """Convert the point to the given units. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx @@ -100,8 +96,6 @@ def uconvert( def uconvert(units: Mapping[str, Any], vec: Point, /) -> Point: """Convert the point to the given units. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx @@ -138,17 +132,10 @@ def uconvert(units: Mapping[str, Any], vec: Point, /) -> Point: def uconvert(flag: Literal[ToUnitsOptions.consistent], vec: Point, /) -> Point: """Convert the point to a self-consistent set of units. - Parameters - ---------- - flag - The point is converted to consistent units by looking for the first - quantity with each physical type and converting all components to - the units of that quantity. - vec - The point to convert. + The point is converted to consistent units by looking for the first quantity + with each physical type and converting all components to the units of that + quantity. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx @@ -183,17 +170,8 @@ def uconvert(flag: Literal[ToUnitsOptions.consistent], vec: Point, /) -> Point: @plum.dispatch def uconvert(usys: str, vec: Point, /) -> Point: - """Convert the vector to the given units system. - - Parameters - ---------- - usys - The units system to convert to, as a string. - vec - The vector to convert. + """Convert the Point to the given units system. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx @@ -216,8 +194,6 @@ def uconvert(usys: str, vec: Point, /) -> Point: def point_to_q(obj: Point, /) -> u.AbstractQuantity: """`coordinax.Point` -> `unxt.Quantity`. - Examples - -------- >>> import unxt as u >>> import coordinax.main as cx >>> from plum import convert @@ -230,13 +206,13 @@ def point_to_q(obj: Point, /) -> u.AbstractQuantity: ... {"r": u.Q(1, "km"), "theta": u.Q(2, "deg"), "phi": u.Q(3, "deg")}, ... cx.sph3d) >>> convert(vec, u.AbstractQuantity) - QuantityMatrix([1, 2, 3], '(km, deg, deg)') + QMatrix([1, 2, 3], '(km, deg, deg)') >>> vec = cx.Point.from_( ... {"rho": u.Q(1, "km"), "phi": u.Q(2, "deg"), "z": u.Q(3, "m")}, ... cx.cyl3d) >>> convert(vec, u.AbstractQuantity) - QuantityMatrix([1, 2, 3], '(km, deg, m)') + QMatrix([1, 2, 3], '(km, deg, m)') """ # Pack the the data into value, unit tuple @@ -248,4 +224,4 @@ def point_to_q(obj: Point, /) -> u.AbstractQuantity: unit = u.unit("") if units[0] is None else units[0] return u.Q(vals, unit) - return QuantityMatrix(vals, units) + return QMatrix(vals, units) diff --git a/tests/integration/angles/test_plum.py b/tests/integration/angles/test_plum.py index ae18c559..64fdc209 100644 --- a/tests/integration/angles/test_plum.py +++ b/tests/integration/angles/test_plum.py @@ -12,7 +12,7 @@ def test_promotion_rule(a): """Test the promotion rule for angles.""" # Quantities - q = u.Q(1.0, "rad") + q = u.Q(1, "rad") # Explicit promotion test a_p, q_p = promote(a, q) diff --git a/tests/integration/charts/test_jax.py b/tests/integration/charts/test_jax.py index 1b289770..7bb087be 100644 --- a/tests/integration/charts/test_jax.py +++ b/tests/integration/charts/test_jax.py @@ -27,8 +27,8 @@ # Helpers # --------------------------------------------------------------------------- -_pos_floats = st.floats(min_value=0.5, max_value=5.0, width=32, allow_nan=False) -_any_floats = st.floats(min_value=-5.0, max_value=5.0, width=32, allow_nan=False) +_pos_floats = st.floats(min_value=0.5, max_value=5, width=32, allow_nan=False) +_any_floats = st.floats(min_value=-5, max_value=5, width=32, allow_nan=False) def _cart3d_to_sph3d(x, y, z): @@ -62,7 +62,7 @@ class TestJITCompatibility: def test_jit_cart3d_to_sph3d_matches_eager(self) -> None: """jit(cart3d → sph3d) gives the same result as the eager call.""" - x, y, z = u.Q(3.0, "m"), u.Q(4.0, "m"), u.Q(0.0, "m") + x, y, z = u.Q(3, "m"), u.Q(4, "m"), u.Q(0, "m") r_eager, theta_eager, phi_eager = _cart3d_to_sph3d(x, y, z) r_jit, theta_jit, phi_jit = jax.jit(_cart3d_to_sph3d)(x, y, z) @@ -77,7 +77,7 @@ def test_jit_cart3d_to_sph3d_matches_eager(self) -> None: def test_jit_cart3d_to_cyl3d_matches_eager(self) -> None: """jit(cart3d → cyl3d) gives the same result as the eager call.""" - x, y, z = u.Q(3.0, "m"), u.Q(4.0, "m"), u.Q(5.0, "m") + x, y, z = u.Q(3, "m"), u.Q(4, "m"), u.Q(5, "m") rho_eager, phi_eager, z_eager = _cart3d_to_cyl3d(x, y, z) rho_jit, phi_jit, z_jit = jax.jit(_cart3d_to_cyl3d)(x, y, z) @@ -97,11 +97,11 @@ def identity(x, y, z): p = {"x": x, "y": y, "z": z} return cxc.pt_map(p, cxc.cart3d, cxc.cart3d) - x, y, z = u.Q(1.0, "m"), u.Q(2.0, "m"), u.Q(3.0, "m") + x, y, z = u.Q(1, "m"), u.Q(2, "m"), u.Q(3, "m") result = jax.jit(identity)(x, y, z) - assert u.ustrip("m", result["x"]) == pytest.approx(1.0) - assert u.ustrip("m", result["y"]) == pytest.approx(2.0) - assert u.ustrip("m", result["z"]) == pytest.approx(3.0) + assert u.ustrip("m", result["x"]) == pytest.approx(1) + assert u.ustrip("m", result["y"]) == pytest.approx(2) + assert u.ustrip("m", result["z"]) == pytest.approx(3) @given(x=_any_floats, y=_any_floats, z=_pos_floats) @settings(deadline=None) @@ -123,9 +123,9 @@ class TestVmapCompatibility: def test_vmap_cart3d_to_sph3d_matches_individual(self) -> None: """vmap(cart3d → sph3d) over a batch matches element-wise results.""" - xs = u.Q(jnp.array([1.0, 0.0, 0.0]), "m") - ys = u.Q(jnp.array([0.0, 1.0, 0.0]), "m") - zs = u.Q(jnp.array([0.0, 0.0, 1.0]), "m") + xs = u.Q(jnp.array([1, 0, 0]), "m") + ys = u.Q(jnp.array([0, 1, 0]), "m") + zs = u.Q(jnp.array([0, 0, 1]), "m") r_batch, _, _ = jax.vmap(_cart3d_to_sph3d)(xs, ys, zs) @@ -147,9 +147,9 @@ def test_vmap_cart3d_to_cyl3d_shape(self) -> None: def test_vmap_matches_looped_individual(self) -> None: """Vmap result == list of individual calls stacked together.""" - xs = u.Q(jnp.array([3.0, 0.0, 0.0]), "m") - ys = u.Q(jnp.array([0.0, 4.0, 0.0]), "m") - zs = u.Q(jnp.array([0.0, 0.0, 5.0]), "m") + xs = u.Q(jnp.array([3, 0, 0]), "m") + ys = u.Q(jnp.array([0, 4, 0]), "m") + zs = u.Q(jnp.array([0, 0, 5]), "m") r_vmap, theta_vmap, _phi_vmap = jax.vmap(_cart3d_to_sph3d)(xs, ys, zs) @@ -200,18 +200,18 @@ def test_grad_r_wrt_x_at_unit_x(self) -> None: """dr/dx = 1 at (1, 0, 0) m — agrees with the analytic Jacobian.""" def r_value(x_val): - p = {"x": u.Q(x_val, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} + p = {"x": u.Q(x_val, "m"), "y": u.Q(0, "m"), "z": u.Q(0, "m")} return cxc.pt_map(p, cxc.cart3d, cxc.sph3d)["r"].value dr_dx = jax.grad(r_value)(1.0) - assert float(dr_dx) == pytest.approx(1.0, rel=1e-5) + assert float(dr_dx) == pytest.approx(1, rel=1e-5) # drho/dx = x/rho. At (3,4,0): rho=5, so drho/dx = 3/5 = 0.6. def test_grad_rho_wrt_x_at_known_point(self) -> None: """drho/dx = x/rho — checked at (3, 4, 0) m.""" def rho_value(x_val): - p = {"x": u.Q(x_val, "m"), "y": u.Q(4.0, "m"), "z": u.Q(0.0, "m")} + p = {"x": u.Q(x_val, "m"), "y": u.Q(4, "m"), "z": u.Q(0, "m")} return cxc.pt_map(p, cxc.cart3d, cxc.cyl3d)["rho"].value drho_dx = jax.grad(rho_value)(3.0) @@ -222,11 +222,11 @@ def test_grad_polar_r_wrt_x(self) -> None: """d(polar r)/dx = 1 at (1, 0) m.""" def r_polar(x_val): - p = {"x": u.Q(x_val, "m"), "y": u.Q(0.0, "m")} + p = {"x": u.Q(x_val, "m"), "y": u.Q(0, "m")} return cxc.pt_map(p, cxc.cart2d, cxc.polar2d)["r"].value dr_dx = jax.grad(r_polar)(1.0) - assert float(dr_dx) == pytest.approx(1.0, rel=1e-5) + assert float(dr_dx) == pytest.approx(1, rel=1e-5) @given(x=_pos_floats, z=_pos_floats) @settings(deadline=None) @@ -234,7 +234,7 @@ def test_grad_r_equals_x_over_r_property(self, x: float, z: float) -> None: """Property: dr/dx = x/r (analytical Jacobian of spherical r).""" def r_value(x_val): - p = {"x": u.Q(x_val, "m"), "y": u.Q(0.0, "m"), "z": u.Q(z, "m")} + p = {"x": u.Q(x_val, "m"), "y": u.Q(0, "m"), "z": u.Q(z, "m")} return cxc.pt_map(p, cxc.cart3d, cxc.sph3d)["r"].value dr_dx = jax.grad(r_value)(x) @@ -252,7 +252,7 @@ class TestComposedTransforms: def test_jit_vmap(self) -> None: """jit(vmap(fn)) works and gives the same result as vmap(fn).""" - xs = u.Q(jnp.array([1.0, 2.0, 3.0]), "m") + xs = u.Q(jnp.array([1, 2, 3]), "m") ys = u.Q(jnp.zeros(3), "m") zs = u.Q(jnp.zeros(3), "m") @@ -265,7 +265,7 @@ def test_jit_grad(self) -> None: """jit(grad(fn)) works and gives the same result as grad(fn).""" def r_value(x_val): - p = {"x": u.Q(x_val, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} + p = {"x": u.Q(x_val, "m"), "y": u.Q(0, "m"), "z": u.Q(0, "m")} return cxc.pt_map(p, cxc.cart3d, cxc.sph3d)["r"].value dr_dx_grad = jax.grad(r_value)(1.0) diff --git a/tests/integration/distances/test_distances_plum.py b/tests/integration/distances/test_distances_plum.py index 0c380b07..75331aa8 100644 --- a/tests/integration/distances/test_distances_plum.py +++ b/tests/integration/distances/test_distances_plum.py @@ -12,7 +12,7 @@ def test_promote_distance(d): """Test the promotion rule for distance.""" # Quantities - q = u.Q(1.0, "kpc") + q = u.Q(1, "kpc") # Explicit promotion test d_p, q_p = plum.promote(d, q) diff --git a/tests/integration/distances/test_quax.py b/tests/integration/distances/test_quax.py index 1863c94a..1f2fbdaa 100644 --- a/tests/integration/distances/test_quax.py +++ b/tests/integration/distances/test_quax.py @@ -33,15 +33,12 @@ # Non-negative, bounded to avoid overflow _dist_kpc = cxst.distances( - unit="kpc", - elements=st.floats(min_value=0.0, max_value=3.0, width=32), + unit="kpc", elements=st.floats(min_value=0, max_value=3, width=32) ) # Wider range for vmap/grad tests _dist_kpc_arr = cxst.distances( - unit="kpc", - shape=(4,), - elements=st.floats(min_value=0.0, max_value=2.0, width=32), + unit="kpc", shape=(4,), elements=st.floats(min_value=0, max_value=2, width=32) ) @@ -54,9 +51,9 @@ class TestQuaxedUnary: ("value", "fn", "expected"), [ (1.5, qnp.abs, 1.5), - (1.7, qnp.floor, 1.0), - (1.2, qnp.ceil, 2.0), - (1.6, qnp.round, 2.0), + (1.7, qnp.floor, 1), + (1.2, qnp.ceil, 2), + (1.6, qnp.round, 2), ], ) def test_known_value_returns_distance( @@ -105,10 +102,7 @@ class TestQuaxedBinary: @pytest.mark.parametrize( ("a_val", "b_val", "fn", "expected"), - [ - (1.0, 0.5, qnp.add, 1.5), - (1.5, 0.5, qnp.subtract, 1.0), - ], + [(1, 0.5, qnp.add, 1.5), (1.5, 0.5, qnp.subtract, 1)], ) def test_known_value( self, a_val: float, b_val: float, fn: object, expected: float @@ -119,16 +113,16 @@ def test_known_value( def test_multiply_by_scalar(self) -> None: """Multiplying a Distance by a plain scalar returns a Distance.""" - result = qnp.multiply(cxd.Distance(1.5, "kpc"), 2.0) + result = qnp.multiply(cxd.Distance(1.5, "kpc"), 2) assert isinstance(result, cxd.Distance) - assert jnp.allclose(result.value, 3.0) + assert jnp.allclose(result.value, 3) @given( a=cxst.distances( - unit="kpc", elements=st.floats(min_value=0.0, max_value=3.0, width=32) + unit="kpc", elements=st.floats(min_value=0, max_value=3, width=32) ), b=cxst.distances( - unit="kpc", elements=st.floats(min_value=0.0, max_value=3.0, width=32) + unit="kpc", elements=st.floats(min_value=0, max_value=3, width=32) ), ) def test_add_commutativity(self, a: cxd.Distance, b: cxd.Distance) -> None: @@ -146,10 +140,10 @@ class TestQuaxedReductions: @pytest.mark.parametrize( ("values", "fn", "expected"), [ - ([1.0, 2.0, 3.0], qnp.sum, 6.0), - ([0.0, 2.0, 4.0], qnp.mean, 2.0), - ([3.0, 1.0, 2.0], qnp.min, 1.0), - ([3.0, 1.0, 2.0], qnp.max, 3.0), + ([1, 2, 3], qnp.sum, 6), + ([0, 2, 4], qnp.mean, 2), + ([3, 1, 2], qnp.min, 1), + ([3, 1, 2], qnp.max, 3), ], ) def test_known_value( @@ -171,26 +165,26 @@ class TestQuaxedArrayOps: def test_stack(self) -> None: """Stack creates a Distance array from scalar Distances.""" - result = qnp.stack([cxd.Distance(1.0, "kpc"), cxd.Distance(2.0, "kpc")]) + result = qnp.stack([cxd.Distance(1, "kpc"), cxd.Distance(2, "kpc")]) assert isinstance(result, cxd.Distance) assert result.shape == (2,) - assert jnp.allclose(result.value, jnp.array([1.0, 2.0])) + assert jnp.allclose(result.value, jnp.array([1, 2])) def test_concatenate(self) -> None: result = qnp.concatenate( - [cxd.Distance([1.0, 2.0], "kpc"), cxd.Distance([3.0, 4.0], "kpc")] + [cxd.Distance([1, 2], "kpc"), cxd.Distance([3, 4], "kpc")] ) assert isinstance(result, cxd.Distance) assert result.shape == (4,) - assert jnp.allclose(result.value, jnp.array([1.0, 2.0, 3.0, 4.0])) + assert jnp.allclose(result.value, jnp.array([1, 2, 3, 4])) def test_sort(self) -> None: - result = qnp.sort(cxd.Distance([3.0, 1.0, 2.0], "kpc")) + result = qnp.sort(cxd.Distance([3, 1, 2], "kpc")) assert isinstance(result, cxd.Distance) - assert jnp.allclose(result.value, jnp.array([1.0, 2.0, 3.0])) + assert jnp.allclose(result.value, jnp.array([1, 2, 3])) def test_reshape(self) -> None: - result = qnp.reshape(cxd.Distance([[1.0, 2.0], [3.0, 4.0]], "kpc"), (4,)) + result = qnp.reshape(cxd.Distance([[1, 2], [3, 4]], "kpc"), (4,)) assert isinstance(result, cxd.Distance) assert result.shape == (4,) @@ -201,10 +195,10 @@ def test_broadcast_to(self) -> None: assert jnp.all(result.value == 1.5) def test_diff(self) -> None: - result = qnp.diff(cxd.Distance([1.0, 3.0, 6.0], "kpc")) + result = qnp.diff(cxd.Distance([1, 3, 6], "kpc")) assert isinstance(result, cxd.Distance) assert result.shape == (2,) - assert jnp.allclose(result.value, jnp.array([2.0, 3.0])) + assert jnp.allclose(result.value, jnp.array([2, 3])) @pytest.mark.parametrize( ("cond", "a_val", "b_val", "expected"), @@ -218,9 +212,7 @@ def test_where_scalar_cond( ) -> None: """Where with a scalar condition selects the correct Distance.""" result = qnp.where( - jnp.array(cond), - cxd.Distance(a_val, "kpc"), - cxd.Distance(b_val, "kpc"), + jnp.array(cond), cxd.Distance(a_val, "kpc"), cxd.Distance(b_val, "kpc") ) assert isinstance(result, cxd.Distance) assert jnp.allclose(result.value, expected) @@ -229,11 +221,11 @@ def test_where_array_cond(self) -> None: """Where with a boolean array selects element-wise.""" result = qnp.where( jnp.array([True, False, True]), - cxd.Distance([1.0, 2.0, 3.0], "kpc"), - cxd.Distance([4.0, 5.0, 6.0], "kpc"), + cxd.Distance([1, 2, 3], "kpc"), + cxd.Distance([4, 5, 6], "kpc"), ) assert isinstance(result, cxd.Distance) - assert jnp.allclose(result.value, jnp.array([1.0, 5.0, 3.0])) + assert jnp.allclose(result.value, jnp.array([1, 5, 3])) class TestQuaxQuaxify: @@ -247,14 +239,11 @@ class TestQuaxQuaxify: def test_raw_jax_add_raises_without_quaxify(self) -> None: """Raw jax.numpy.add raises TypeError on a Distance — no dispatch rules.""" with pytest.raises(TypeError): - jax.numpy.add(cxd.Distance(1.0, "kpc"), cxd.Distance(0.5, "kpc")) + jax.numpy.add(cxd.Distance(1, "kpc"), cxd.Distance(0.5, "kpc")) @pytest.mark.parametrize( ("fn", "a_val", "b_val", "expected"), - [ - (jax.numpy.add, 1.0, 0.5, 1.5), - (jax.numpy.subtract, 1.5, 0.5, 1.0), - ], + [(jax.numpy.add, 1, 0.5, 1.5), (jax.numpy.subtract, 1.5, 0.5, 1)], ) def test_quaxify_binary_known_value( self, fn: object, a_val: float, b_val: float, expected: float @@ -272,25 +261,25 @@ def test_quaxify_user_function(self) -> None: def double(x): return jax.numpy.add(x, x) - result = quax.quaxify(double)(cxd.Distance(1.0, "kpc")) + result = quax.quaxify(double)(cxd.Distance(1, "kpc")) assert isinstance(result, cxd.Distance) - assert jnp.allclose(result.value, 2.0) + assert jnp.allclose(result.value, 2) def test_quaxify_jit(self) -> None: """jax.jit(quaxify(fn)) works on a Distance.""" result = jax.jit(quax.quaxify(jax.numpy.add))( - cxd.Distance(1.0, "kpc"), cxd.Distance(0.5, "kpc") + cxd.Distance(1, "kpc"), cxd.Distance(0.5, "kpc") ) assert isinstance(result, cxd.Distance) assert jnp.allclose(result.value, 1.5) def test_quaxify_vmap(self) -> None: """jax.vmap(quaxify(fn)) maps over a Distance array.""" - arr = cxd.Distance([1.0, 2.0, 3.0], "kpc") + arr = cxd.Distance([1, 2, 3], "kpc") result = jax.vmap(quax.quaxify(jax.numpy.add))(arr, arr) assert isinstance(result, cxd.Distance) assert result.value.shape == (3,) - assert jnp.allclose(result.value, jnp.array([2.0, 4.0, 6.0])) + assert jnp.allclose(result.value, jnp.array([2, 4, 6])) @given(d=_dist_kpc) def test_quaxify_preserves_distance_type(self, d: cxd.Distance) -> None: @@ -307,9 +296,9 @@ class TestJAXTransformsWithQuaxed: ("loss_fn", "d_val", "expected_grad"), [ # d/dx sum(x) = 1 - (lambda x: qnp.sum(x).value, 1.0, 1.0), + (lambda x: qnp.sum(x).value, 1.0, 1), # d/dx 2*x = 2 (via x + x) - (lambda x: qnp.add(x, x).value, 1.0, 2.0), + (lambda x: qnp.add(x, x).value, 1.0, 2), ], ) def test_grad_known_value( @@ -326,14 +315,11 @@ def test_grad_sum_is_one(self, d: cxd.Distance) -> None: """d/dx sum(x) == 1 for any Distance scalar.""" g = jax.grad(lambda x: qnp.sum(x).value)(d) assert isinstance(g, cxd.Distance) - assert jnp.allclose(g.value, 1.0, atol=1e-5) + assert jnp.allclose(g.value, 1, atol=1e-5) @pytest.mark.parametrize( ("fn", "d_val", "expected_val"), - [ - (lambda x: qnp.add(x, x), 1.0, 2.0), - (lambda x: qnp.multiply(x, 3.0), 2.0, 6.0), - ], + [(lambda x: qnp.add(x, x), 1, 2), (lambda x: qnp.multiply(x, 3), 2, 6)], ) def test_jit(self, fn: object, d_val: float, expected_val: float) -> None: """jax.jit works on functions using qnp over Distance.""" diff --git a/tests/unit/angles/test_angle.py b/tests/unit/angles/test_angle.py index b0d8f50d..571ec567 100644 --- a/tests/unit/angles/test_angle.py +++ b/tests/unit/angles/test_angle.py @@ -27,10 +27,10 @@ _bounded_f32 = st.floats(min_value=-1e10, max_value=1e10, width=32) # Float32 values in [0, 1] — used for algebraic property tests -_unit_f32 = st.floats(min_value=0.0, max_value=1.0, width=32) +_unit_f32 = st.floats(min_value=0, max_value=1, width=32) # Float32 values in [-1, 1] — safe trig domain for grad tests -_trig_f32 = st.floats(min_value=-1.0, max_value=1.0, width=32) +_trig_f32 = st.floats(min_value=-1, max_value=1, width=32) # --------------------------------------------------------------------------- # Named constants for float32 overflow boundary @@ -93,10 +93,7 @@ def test_shape_matches_requested( @pytest.mark.parametrize( ("value", "unit_str", "expected_shape"), - [ - (1, "rad", ()), - ([1.0, 2.0, 3.0], "deg", (3,)), - ], + [(1, "rad", ()), ([1, 2, 3], "deg", (3,))], ) def test_construct_from_python( self, value: object, unit_str: str, expected_shape: tuple[int, ...] @@ -108,7 +105,7 @@ def test_construct_from_python( def test_construct_from_jnp_array(self) -> None: """Angles can be constructed from a JAX array.""" - arr = jnp.array([0.0, jnp.pi / 2, jnp.pi]) + arr = jnp.array([0, jnp.pi / 2, jnp.pi]) a = cxa.Angle(arr, "rad") assert isinstance(a, cxa.Angle) assert a.shape == (3,) @@ -116,7 +113,7 @@ def test_construct_from_jnp_array(self) -> None: def test_invalid_unit_raises(self) -> None: """Non-angular units are rejected at construction time.""" with pytest.raises(ValueError, match="angular dimensions"): - cxa.Angle(1.0, "m") + cxa.Angle(1, "m") class TestAngleConversion: @@ -124,10 +121,7 @@ class TestAngleConversion: @pytest.mark.parametrize( ("value", "from_unit", "to_unit", "expected"), - [ - (180.0, "deg", "rad", jnp.pi), - (jnp.pi, "rad", "deg", 180.0), - ], + [(180, "deg", "rad", jnp.pi), (jnp.pi, "rad", "deg", 180)], ) def test_unit_conversion( self, @@ -205,8 +199,8 @@ def test_wrap_idempotent(self, angle: cxa.Angle) -> None: @pytest.mark.parametrize( ("value", "expected"), [ - (370.0, 10.0), # one full turn above the range - (-10.0, 350.0), # one step below zero + (370, 10), # one full turn above the range + (-10, 350), # one step below zero ], ) def test_wrap_known_values(self, value: float, expected: float) -> None: @@ -235,7 +229,7 @@ def test_add_returns_angle(self, angle: cxa.Angle) -> None: def test_sub_self_is_zero(self, angle: cxa.Angle) -> None: result = angle - angle assert isinstance(result, cxa.Angle) - assert jnp.allclose(result.value, 0.0) + assert jnp.allclose(result.value, 0) @given(angle=cxst.angles()) def test_neg_flips_sign(self, angle: cxa.Angle) -> None: @@ -275,7 +269,7 @@ def test_add_overflow_produces_inf(self, angle: cxa.Angle) -> None: @given(angle=cxst.angles()) def test_mul_quantity_leaves_angle_type(self, angle: cxa.Angle) -> None: """Angle x dimensioned Quantity yields a plain Quantity, not an Angle.""" - result = angle * u.Q(2.0, "s") + result = angle * u.Q(2, "s") assert isinstance(result, u.AbstractQuantity) assert not isinstance(result, AbstractAngle) diff --git a/tests/unit/charts/test_base.py b/tests/unit/charts/test_base.py index 17873ed9..e5636d1e 100644 --- a/tests/unit/charts/test_base.py +++ b/tests/unit/charts/test_base.py @@ -68,47 +68,42 @@ class TestAbstractChartCheckData: def test_check_data_passes_for_valid_data(self) -> None: """check_data passes for valid data matching chart components.""" - data = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")} + data = {"x": u.Q(1, "m"), "y": u.Q(2, "m"), "z": u.Q(3, "m")} cxc.cart3d.check_data(data) def test_check_data_raises_for_missing_component(self) -> None: """check_data raises when a component is missing.""" - data = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m")} # missing z + data = {"x": u.Q(1, "m"), "y": u.Q(2, "m")} # missing z with pytest.raises(ValueError, match="Data keys do not match"): cxc.cart3d.check_data(data) def test_check_data_raises_for_extra_component(self) -> None: """check_data raises when an extra component is present.""" - data = { - "x": u.Q(1.0, "m"), - "y": u.Q(2.0, "m"), - "z": u.Q(3.0, "m"), - "w": u.Q(4.0, "m"), - } + data = {"x": u.Q(1, "m"), "y": u.Q(2, "m"), "z": u.Q(3, "m"), "w": u.Q(4, "m")} with pytest.raises(ValueError, match="Data keys do not match"): cxc.cart3d.check_data(data) def test_check_data_skips_key_check_when_disabled(self) -> None: """check_data skips key check when keys=False.""" - data = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m")} # missing z + data = {"x": u.Q(1, "m"), "y": u.Q(2, "m")} # missing z cxc.cart3d.check_data(data, keys=False) # should not raise # --- dimensions check --- def test_check_data_passes_for_valid_dimensions(self) -> None: """check_data passes when data dimensions match chart coord_dimensions.""" - data = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")} + data = {"x": u.Q(1, "m"), "y": u.Q(2, "m"), "z": u.Q(3, "m")} cxc.cart3d.check_data(data, values=True) def test_check_data_raises_for_wrong_dimension(self) -> None: """check_data raises when a value's dimension doesn't match the chart.""" - data = {"x": u.Q(1.0, "s"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")} + data = {"x": u.Q(1, "s"), "y": u.Q(2, "m"), "z": u.Q(3, "m")} with pytest.raises(ValueError, match="Data dimension for 'x' does not match"): cxc.cart3d.check_data(data, values=True) def test_check_data_ignores_values_by_default(self) -> None: """check_data does not check value dimensions when values=False (default).""" - data = {"x": u.Q(1.0, "s"), "y": u.Q(2.0, "m"), "z": u.Q(3.0, "m")} + data = {"x": u.Q(1, "s"), "y": u.Q(2, "m"), "z": u.Q(3, "m")} cxc.cart3d.check_data(data) # should not raise diff --git a/tests/unit/charts/test_cdict.py b/tests/unit/charts/test_cdict.py index 1527a43f..e294f108 100644 --- a/tests/unit/charts/test_cdict.py +++ b/tests/unit/charts/test_cdict.py @@ -33,7 +33,7 @@ def test_cdict_from_quantity(q): def test_cdict_from_quantity_chart_error(): """cdict(Quantity, chart) with mismatched chart should raise ValueError.""" - q = u.Q([1.0, 2.0], "m") # 2 components + q = u.Q([1, 2], "m") # 2 components with pytest.raises( ValueError, match=r"Quantity last dimension 2 does not match provided keys 3." ): diff --git a/tests/unit/charts/test_checks.py b/tests/unit/charts/test_checks.py index 66a62c5e..85ad8eaf 100644 --- a/tests/unit/charts/test_checks.py +++ b/tests/unit/charts/test_checks.py @@ -31,8 +31,8 @@ class TestPolarRange: @given( data=st.data(), - lower=float32s(min_value=0.0, max_value=1.0), - upper=float32s(min_value=2.0, max_value=PI_F32), + lower=float32s(min_value=0, max_value=1), + upper=float32s(min_value=2, max_value=PI_F32), quantity_cls=angle_classes, ) @settings(deadline=None) @@ -51,7 +51,7 @@ def test_angular_within_bounds_passes( result = checks.polar_range(angle) assert jnp.array_equal(result.value, angle.value) - @given(ust.quantities("m", elements=float32s(min_value=0.0, max_value=PI_F32))) + @given(ust.quantities("m", elements=float32s(min_value=0, max_value=PI_F32))) @settings(deadline=None) def test_non_angular_units_raises(self, x: u.AbstractQuantity) -> None: """Non-angular quantities always raise, regardless of value.""" @@ -67,8 +67,8 @@ def test_angular_outside_bounds_raises(self, data: st.DataObject) -> None: ust.quantities( "rad", elements=st.one_of( - float32s(min_value=-10.0, max_value=0.0, exclude_max=True), - float32s(min_value=PI_F32, max_value=10.0, exclude_min=True), + float32s(min_value=-10, max_value=0, exclude_max=True), + float32s(min_value=PI_F32, max_value=10, exclude_min=True), ), quantity_cls=data.draw(angle_classes), ), @@ -98,7 +98,7 @@ def test_positive_values_pass(self, x: u.AbstractQuantity) -> None: def test_zero_raises(self) -> None: """Zero should raise an error.""" - x = u.Q(0.0, "m") + x = u.Q(0, "m") with pytest.raises( (eqx.EquinoxRuntimeError, ValueError), match="must be non-negative and non-zero", @@ -107,7 +107,7 @@ def test_zero_raises(self) -> None: def test_negative_raises(self) -> None: """Negative values should raise an error.""" - x = u.Q(-1.0, "m") + x = u.Q(-1, "m") with pytest.raises( (eqx.EquinoxRuntimeError, ValueError), match="must be non-negative and non-zero", @@ -116,7 +116,7 @@ def test_negative_raises(self) -> None: def test_array_with_zero_raises(self) -> None: """Arrays containing zero should raise an error.""" - x = u.Q([1.0, 0.0, 2.0], "m") + x = u.Q([1, 0, 2], "m") with pytest.raises( (eqx.EquinoxRuntimeError, ValueError), match="must be non-negative and non-zero", @@ -125,7 +125,7 @@ def test_array_with_zero_raises(self) -> None: def test_array_with_negative_raises(self) -> None: """Arrays containing negative values should raise an error.""" - x = u.Q([1.0, -1.0, 2.0], "m") + x = u.Q([1, -1, 2], "m") with pytest.raises( (eqx.EquinoxRuntimeError, ValueError), match="must be non-negative and non-zero", @@ -136,13 +136,13 @@ def test_array_with_negative_raises(self) -> None: class TestLeq: """Tests for leq (less than or equal) check.""" - @given(data=st.data(), max_val=float32s(min_value=1.0, max_value=100.0)) + @given(data=st.data(), max_val=float32s(min_value=1, max_value=100)) @settings(deadline=None) def test_values_below_max_pass(self, data: st.DataObject, max_val: float) -> None: """Values <= max should pass through unchanged.""" x = data.draw( ust.quantities( - "m", shape=(), elements=float32s(min_value=0.0, max_value=max_val) + "m", shape=(), elements=float32s(min_value=0, max_value=max_val) ) ) max_q = u.Q(max_val, "m") @@ -151,15 +151,15 @@ def test_values_below_max_pass(self, data: st.DataObject, max_val: float) -> Non def test_equal_to_max_passes(self) -> None: """Values equal to max should pass.""" - x = u.Q(5.0, "m") - max_q = u.Q(5.0, "m") + x = u.Q(5, "m") + max_q = u.Q(5, "m") result = checks.leq(x, max_q) assert jnp.array_equal(result.value, x.value) def test_above_max_raises(self) -> None: """Values above max should raise an error.""" - x = u.Q(6.0, "m") - max_q = u.Q(5.0, "m") + x = u.Q(6, "m") + max_q = u.Q(5, "m") with pytest.raises( (eqx.EquinoxRuntimeError, ValueError), match="must be less than or equal to" ): @@ -167,8 +167,8 @@ def test_above_max_raises(self) -> None: def test_array_with_value_above_max_raises(self) -> None: """Arrays with any value above max should raise an error.""" - x = u.Q([1.0, 5.0, 6.0], "m") - max_q = u.Q(5.0, "m") + x = u.Q([1, 5, 6], "m") + max_q = u.Q(5, "m") with pytest.raises( (eqx.EquinoxRuntimeError, ValueError), match="must be less than or equal to" ): @@ -178,17 +178,12 @@ def test_array_with_value_above_max_raises(self) -> None: class TestGeq: """Tests for geq (greater than or equal) check.""" - @given( - data=st.data(), - min_val=float32s(min_value=0.0, max_value=10.0), - ) + @given(data=st.data(), min_val=float32s(min_value=0, max_value=10)) def test_values_above_min_pass(self, data: st.DataObject, min_val: float) -> None: """Values >= min should pass through unchanged.""" x = data.draw( ust.quantities( - "m", - shape=(), - elements=float32s(min_value=min_val, max_value=100.0), + "m", shape=(), elements=float32s(min_value=min_val, max_value=100) ) ) min_q = u.Q(min_val, "m") @@ -197,15 +192,15 @@ def test_values_above_min_pass(self, data: st.DataObject, min_val: float) -> Non def test_equal_to_min_passes(self) -> None: """Values equal to min should pass.""" - x = u.Q(5.0, "m") - min_q = u.Q(5.0, "m") + x = u.Q(5, "m") + min_q = u.Q(5, "m") result = checks.geq(x, min_q) assert jnp.array_equal(result.value, x.value) def test_below_min_raises(self) -> None: """Values below min should raise an error.""" - x = u.Q(4.0, "m") - min_q = u.Q(5.0, "m") + x = u.Q(4, "m") + min_q = u.Q(5, "m") with pytest.raises( (eqx.EquinoxRuntimeError, ValueError), match="must be greater than or equal to", @@ -214,8 +209,8 @@ def test_below_min_raises(self) -> None: def test_array_with_value_below_min_raises(self) -> None: """Arrays with any value below min should raise an error.""" - x = u.Q([4.0, 5.0, 6.0], "m") - min_q = u.Q(5.0, "m") + x = u.Q([4, 5, 6], "m") + min_q = u.Q(5, "m") with pytest.raises( (eqx.EquinoxRuntimeError, ValueError), match="must be greater than or equal to", diff --git a/tests/unit/charts/test_galilean_charts.py b/tests/unit/charts/test_galilean_charts.py index 1d900588..7fe00c72 100644 --- a/tests/unit/charts/test_galilean_charts.py +++ b/tests/unit/charts/test_galilean_charts.py @@ -155,10 +155,10 @@ def test_split_components_ct_key(self) -> None: """split_components returns time dict with 'ct' key.""" chart = cxc.GalileanCT() p = { - "ct": u.Q(1.0, "km"), - "x": u.Q(2.0, "km"), - "y": u.Q(3.0, "km"), - "z": u.Q(4.0, "km"), + "ct": u.Q(1, "km"), + "x": u.Q(2, "km"), + "y": u.Q(3, "km"), + "z": u.Q(4, "km"), } time_part, spatial_part = chart.split_components(p) assert "ct" in time_part @@ -168,10 +168,10 @@ def test_merge_components_roundtrip(self) -> None: """merge_components(split_components(p)) == p.""" chart = cxc.GalileanCT() p = { - "ct": u.Q(1.0, "km"), - "x": u.Q(2.0, "km"), - "y": u.Q(3.0, "km"), - "z": u.Q(4.0, "km"), + "ct": u.Q(1, "km"), + "x": u.Q(2, "km"), + "y": u.Q(3, "km"), + "z": u.Q(4, "km"), } time_part, spatial_part = chart.split_components(p) merged = chart.merge_components((time_part, spatial_part)) diff --git a/tests/unit/charts/test_guess_chart.py b/tests/unit/charts/test_guess_chart.py index 85683571..6cf875f7 100644 --- a/tests/unit/charts/test_guess_chart.py +++ b/tests/unit/charts/test_guess_chart.py @@ -69,7 +69,7 @@ def test_guess_chart_from_dict_returns_same_components( ) -> None: """guess_chart with dict input returns chart with same components.""" # Create a component dictionary with dummy values - d = dict.fromkeys(chart.components, 1.0) + d = dict.fromkeys(chart.components, 1) # Guess the chart from the dict guessed = cxc.guess_chart(d) @@ -90,8 +90,8 @@ def test_frozenset_dispatch(self) -> None: def test_dict_dispatch_returns_same_type(self) -> None: """The dict dispatch should return same chart type for same keys.""" - d1 = {"x": 1.0, "y": 2.0, "z": 3.0} - d2 = {"x": 5.0, "y": 6.0, "z": 7.0} + d1 = {"x": 1, "y": 2, "z": 3} + d2 = {"x": 5, "y": 6, "z": 7} result1 = cxc.guess_chart(d1) result2 = cxc.guess_chart(d2) assert type(result1) is type(result2) diff --git a/tests/unit/charts/test_jacobian_pt_map.py b/tests/unit/charts/test_jacobian_pt_map.py index f6d982f2..ffb4ea5a 100644 --- a/tests/unit/charts/test_jacobian_pt_map.py +++ b/tests/unit/charts/test_jacobian_pt_map.py @@ -22,7 +22,7 @@ import unxt as u import coordinax.charts as cxc -from coordinax.internal import QuantityMatrix +from coordinax.internal import QMatrix usys_si = u.unitsystems.si @@ -71,7 +71,7 @@ def test_importable_from_charts(self) -> None: class TestJacobianPtMapReturnType: - """Returns a 2-D QuantityMatrix with shape (n_to, n_from).""" + """Returns a 2-D QMatrix with shape (n_to, n_from).""" @pytest.mark.parametrize( ("from_chart", "to_chart", "at", "exp_shape"), @@ -101,9 +101,9 @@ class TestJacobianPtMapReturnType: ), ], ) - def test_returns_QuantityMatrix(self, from_chart, to_chart, at, exp_shape) -> None: + def test_returns_QMatrix(self, from_chart, to_chart, at, exp_shape) -> None: J = cxc.jac_pt_map(at, from_chart, to_chart) - assert isinstance(J, QuantityMatrix) + assert isinstance(J, QMatrix) assert J.ndim == 2 assert J.value.shape == exp_shape @@ -484,7 +484,7 @@ class TestJacobianPtMapCompositionProperty: r"""Property: J_{C2→C1}(p_{C2}) @ J_{C1→C2}(p_{C1}) = I. This is the chain rule: the Jacobian of the round-trip is the identity. - Uses QuantityMatrix matmul (quaxed) which tracks units through the product. + Uses QMatrix matmul (quaxed) which tracks units through the product. The result has all-dimensionless units and values equal to the nxn identity. """ @@ -675,7 +675,7 @@ def jitted(at): return cxc.jac_pt_map(at, cxc.cart3d, cxc.sph3d) J = jitted(at) - assert isinstance(J, QuantityMatrix) + assert isinstance(J, QMatrix) assert_allclose(J.value[0, 0], 1, atol=1e-6) # ∂r/∂x at (1,0,0) def test_jit_cart2d_to_polar2d(self) -> None: @@ -743,7 +743,7 @@ def test_curried_2d(self) -> None: at = {"x": u.Q(1, "m"), "y": u.Q(0, "m")} fn = cxc.jac_pt_map(cxc.cart2d, cxc.polar2d, usys=usys_si) J = fn(at) - assert isinstance(J, QuantityMatrix) + assert isinstance(J, QMatrix) assert J.value.shape == (2, 2) @@ -765,7 +765,7 @@ def test_array_input_returns_array(self) -> None: def test_int_array_input_is_promoted_and_supported(self) -> None: """Integer plain-array input is promoted and produces the correct Jacobian.""" at_int = jnp.array([1, 0, 0]) - at_float = jnp.array([1.0, 0.0, 0.0], dtype=float) + at_float = jnp.array([1, 0, 0], dtype=float) J_int = cxc.jac_pt_map(at_int, cxc.cart3d, cxc.sph3d, usys=usys_si) J_float = cxc.jac_pt_map(at_float, cxc.cart3d, cxc.sph3d, usys=usys_si) @@ -777,7 +777,7 @@ def test_int_array_input_is_promoted_and_supported(self) -> None: def test_bool_array_input_is_promoted_and_supported(self) -> None: """Boolean plain-array input is promoted and produces the correct Jacobian.""" at_bool = jnp.array([True, False, False], dtype=jnp.bool_) - at_float = jnp.array([1.0, 0.0, 0.0], dtype=float) + at_float = jnp.array([1, 0, 0], dtype=float) J_bool = cxc.jac_pt_map(at_bool, cxc.cart3d, cxc.sph3d, usys=usys_si) J_float = cxc.jac_pt_map(at_float, cxc.cart3d, cxc.sph3d, usys=usys_si) @@ -836,9 +836,9 @@ def test_generic_pair_int_arrays_are_promoted_and_supported(self) -> None: """Cart3D→Sph3D integer CDict values are promoted via Array dispatch.""" at_int = {"x": jnp.array(1), "y": jnp.array(0), "z": jnp.array(0)} at_float = { - "x": jnp.array(1.0, dtype=float), - "y": jnp.array(0.0, dtype=float), - "z": jnp.array(0.0, dtype=float), + "x": jnp.array(1, dtype=float), + "y": jnp.array(0, dtype=float), + "z": jnp.array(0, dtype=float), } J_int = cxc.jac_pt_map(at_int, cxc.cart3d, cxc.sph3d, usys=usys_si) diff --git a/tests/unit/charts/test_minkowski_charts.py b/tests/unit/charts/test_minkowski_charts.py index d4fc3538..cf4e40f5 100644 --- a/tests/unit/charts/test_minkowski_charts.py +++ b/tests/unit/charts/test_minkowski_charts.py @@ -28,7 +28,7 @@ def test_explicit_manifold(self) -> None: def test_wrong_manifold_raises(self) -> None: """Passing a non-MinkowskiManifold manifold raises ValueError.""" with pytest.raises((ValueError, TypeError)): - cxc.MinkowskiCT(M=cxm.EuclideanManifold(3)) + cxc.MinkowskiCT(M=cxm.R3) # ============================================================================= diff --git a/tests/unit/charts/test_predef_charts.py b/tests/unit/charts/test_predef_charts.py index 8e7b6d8b..63ee4f21 100644 --- a/tests/unit/charts/test_predef_charts.py +++ b/tests/unit/charts/test_predef_charts.py @@ -166,14 +166,14 @@ def _quantity_for_dimension(dim: str | None) -> u.AbstractQuantity: if dim is None: - return u.Q(1.0, "") - return u.Q(1.0, _DIM_TO_UNIT[dim]) + return u.Q(1, "") + return u.Q(1, _DIM_TO_UNIT[dim]) def _mismatched_quantity_for_dimension(dim: str | None) -> u.AbstractQuantity: if dim == "time": - return u.Q(1.0, "m") - return u.Q(1.0, "s") + return u.Q(1, "m") + return u.Q(1, "s") def _strict_dimension_component(chart: cxc.AbstractChart) -> str: @@ -319,7 +319,7 @@ def test_predef_chart_no_cartesian_raises(chart) -> None: @pytest.mark.parametrize("chart", [p[1] for p in _CHART_PARAMS], ids=_CHART_IDS) def test_predef_chart_check_data_valid_keys_default_dimensions_false(chart) -> None: """check_data defaults to key checks only when values=False.""" - data = {k: u.Q(1.0, "m") for k in chart.components} + data = {k: u.Q(1, "m") for k in chart.components} chart.check_data(data) # should not raise @@ -329,7 +329,7 @@ def test_predef_chart_check_data_wrong_keys_raises(chart) -> None: if chart.ndim == 0: pytest.skip("0D chart has no components to remove") # Drop the last component to trigger the mismatch - data = {k: u.Q(1.0, "m") for k in chart.components[:-1]} + data = {k: u.Q(1, "m") for k in chart.components[:-1]} with pytest.raises(ValueError, match="Data keys do not match"): chart.check_data(data) diff --git a/tests/unit/charts/test_product.py b/tests/unit/charts/test_product.py index a220c59e..aab7aed1 100644 --- a/tests/unit/charts/test_product.py +++ b/tests/unit/charts/test_product.py @@ -76,43 +76,26 @@ def test_split_components_extracts_by_prefix( self, phase_space: cxc.CartesianProductChart ) -> None: """split_components should extract keys by prefix and strip it.""" - p = {"q.x": 1.0, "q.y": 2.0, "q.z": 3.0, "p.x": 4.0, "p.y": 5.0, "p.z": 6.0} + p = {"q.x": 1, "q.y": 2, "q.z": 3, "p.x": 4, "p.y": 5, "p.z": 6} parts = phase_space.split_components(p) assert len(parts) == 2 - assert parts[0] == {"x": 1.0, "y": 2.0, "z": 3.0} - assert parts[1] == {"x": 4.0, "y": 5.0, "z": 6.0} + assert parts[0] == {"x": 1, "y": 2, "z": 3} + assert parts[1] == {"x": 4, "y": 5, "z": 6} def test_merge_components_reattaches_prefix( self, phase_space: cxc.CartesianProductChart ) -> None: """merge_components should re-add dot-delimited prefix.""" - parts = ( - {"x": 1.0, "y": 2.0, "z": 3.0}, - {"x": 4.0, "y": 5.0, "z": 6.0}, - ) + parts = ({"x": 1, "y": 2, "z": 3}, {"x": 4, "y": 5, "z": 6}) merged = phase_space.merge_components(parts) - expected = { - "q.x": 1.0, - "q.y": 2.0, - "q.z": 3.0, - "p.x": 4.0, - "p.y": 5.0, - "p.z": 6.0, - } + expected = {"q.x": 1, "q.y": 2, "q.z": 3, "p.x": 4, "p.y": 5, "p.z": 6} assert merged == expected def test_split_merge_roundtrip( self, phase_space: cxc.CartesianProductChart ) -> None: """Split followed by merge should recover original dict.""" - original = { - "q.x": 1.0, - "q.y": 2.0, - "q.z": 3.0, - "p.x": 4.0, - "p.y": 5.0, - "p.z": 6.0, - } + original = {"q.x": 1, "q.y": 2, "q.z": 3, "p.x": 4, "p.y": 5, "p.z": 6} parts = phase_space.split_components(original) recovered = phase_space.merge_components(parts) assert recovered == original diff --git a/tests/unit/charts/test_register_realize.py b/tests/unit/charts/test_register_realize.py index e94fe84e..b1b8c7f3 100644 --- a/tests/unit/charts/test_register_realize.py +++ b/tests/unit/charts/test_register_realize.py @@ -90,15 +90,15 @@ def test_namespaced_phase_space_transform(self) -> None: phase_cart = cxc.CartesianProductChart((cxc.cart3d, cxc.cart3d), ("q", "p")) phase_sph = cxc.CartesianProductChart((cxc.sph3d, cxc.sph3d), ("q", "p")) p = { - "q.x": u.Q(1.0, "m"), - "q.y": u.Q(0.0, "m"), - "q.z": u.Q(0.0, "m"), - "p.x": u.Q(0.0, "m"), - "p.y": u.Q(1.0, "m"), - "p.z": u.Q(0.0, "m"), + "q.x": u.Q(1, "m"), + "q.y": u.Q(0, "m"), + "q.z": u.Q(0, "m"), + "p.x": u.Q(0, "m"), + "p.y": u.Q(1, "m"), + "p.z": u.Q(0, "m"), } result = cxc.pt_map(p, phase_cart, phase_sph) - assert u.ustrip("m", result["q.r"]) == pytest.approx(1.0) - assert u.ustrip("rad", result["q.phi"]) == pytest.approx(0.0) - assert u.ustrip("m", result["p.r"]) == pytest.approx(1.0) + assert u.ustrip("m", result["q.r"]) == pytest.approx(1) + assert u.ustrip("rad", result["q.phi"]) == pytest.approx(0) + assert u.ustrip("m", result["p.r"]) == pytest.approx(1) assert u.ustrip("rad", result["p.phi"]) == pytest.approx(jnp.pi / 2) diff --git a/tests/unit/charts/test_utils.py b/tests/unit/charts/test_utils.py index 82b30e81..4aa57896 100644 --- a/tests/unit/charts/test_utils.py +++ b/tests/unit/charts/test_utils.py @@ -33,19 +33,19 @@ def test_uconvert_to_rad_non_angle_quantity_raises() -> None: def test_uconvert_to_rad_arraylike_without_usys_is_radians() -> None: """Plain numerics and JAX arrays are interpreted as radians by default.""" scalar = uconvert_to_rad(float(jnp.pi / 3), None) - vector = uconvert_to_rad(jnp.array([0.0, jnp.pi / 2, jnp.pi]), None) + vector = uconvert_to_rad(jnp.array([0, jnp.pi / 2, jnp.pi]), None) assert isinstance(scalar, float) assert scalar == pytest.approx(float(jnp.pi / 3)) - assert bool(jnp.allclose(vector, jnp.array([0.0, jnp.pi / 2, jnp.pi]))) + assert bool(jnp.allclose(vector, jnp.array([0, jnp.pi / 2, jnp.pi]))) def test_uconvert_to_rad_arraylike_with_usys_angle_unit() -> None: """Plain numerics and JAX arrays can be interpreted through usys['angle'].""" usys = u.unitsystem("m", "deg") - scalar = uconvert_to_rad(90.0, usys) - vector = uconvert_to_rad(jnp.array([0.0, 90.0, 180.0]), usys) + scalar = uconvert_to_rad(90, usys) + vector = uconvert_to_rad(jnp.array([0, 90, 180]), usys) assert scalar == pytest.approx(float(jnp.pi / 2)) - assert bool(jnp.allclose(vector, jnp.array([0.0, jnp.pi / 2, jnp.pi]))) + assert bool(jnp.allclose(vector, jnp.array([0, jnp.pi / 2, jnp.pi]))) diff --git a/tests/unit/distances/test_distance.py b/tests/unit/distances/test_distance.py index 5d91a285..dc9ee697 100644 --- a/tests/unit/distances/test_distance.py +++ b/tests/unit/distances/test_distance.py @@ -31,10 +31,10 @@ # --------------------------------------------------------------------------- # Non-negative float32 values bounded away from overflow (safe to double, add) -_bounded_f32 = st.floats(min_value=0.0, max_value=1e10, width=32) +_bounded_f32 = st.floats(min_value=0, max_value=1e10, width=32) # Non-negative float32 values in [0, 1] — used for algebraic property tests -_unit_f32 = st.floats(min_value=0.0, max_value=1.0, width=32) +_unit_f32 = st.floats(min_value=0, max_value=1, width=32) # --------------------------------------------------------------------------- # Named constants for float32 overflow boundary @@ -106,10 +106,7 @@ def test_shape_matches_requested( @pytest.mark.parametrize( ("value", "unit_str", "expected_shape"), - [ - (1, "kpc", ()), - ([1.0, 2.0, 3.0], "pc", (3,)), - ], + [(1, "kpc", ()), ([1, 2, 3], "pc", (3,))], ) def test_construct_from_python( self, value: object, unit_str: str, expected_shape: tuple[int, ...] @@ -121,7 +118,7 @@ def test_construct_from_python( def test_construct_from_jnp_array(self) -> None: """Distances can be constructed from a JAX array.""" - arr = jnp.array([0.0, 1.0, 2.0]) + arr = jnp.array([0, 1, 2]) d = cxd.Distance(arr, "kpc") assert isinstance(d, cxd.Distance) assert d.shape == (3,) @@ -131,7 +128,7 @@ def test_construct_from_jnp_array(self) -> None: def test_invalid_unit_raises(self) -> None: """Non-length units are rejected at construction time.""" with pytest.raises(ValueError, match="dimensions length"): - cxd.Distance(1.0, "rad") + cxd.Distance(1, "rad") def test_negative_raises_when_checked(self) -> None: """Negative values raise when check_negative=True.""" @@ -139,7 +136,7 @@ def test_negative_raises_when_checked(self) -> None: (eqx.EquinoxRuntimeError, ValueError), match="Distance must be non-negative", ): - cxd.Distance(-1.0, "kpc", check_negative=True) + cxd.Distance(-1, "kpc", check_negative=True) class TestDistanceConversion: @@ -147,10 +144,7 @@ class TestDistanceConversion: @pytest.mark.parametrize( ("value", "from_unit", "to_unit", "expected"), - [ - (1.0, "kpc", "pc", 1000.0), - (1000.0, "pc", "kpc", 1.0), - ], + [(1, "kpc", "pc", 1000), (1000, "pc", "kpc", 1)], ) def test_unit_conversion( self, value: float, from_unit: str, to_unit: str, expected: float @@ -183,7 +177,7 @@ def test_add_returns_distance(self, d: cxd.Distance) -> None: def test_sub_self_is_zero(self, d: cxd.Distance) -> None: result = d - d assert isinstance(result, cxd.Distance) - assert jnp.allclose(result.value, 0.0) + assert jnp.allclose(result.value, 0) @given(d=cxst.distances()) def test_neg_degrades_to_quantity(self, d: cxd.Distance) -> None: @@ -229,7 +223,7 @@ def test_add_overflow_produces_inf(self, d: cxd.Distance) -> None: @given(d=cxst.distances()) def test_mul_quantity_promotes(self, d: cxd.Distance) -> None: """Distance x dimensioned Quantity yields a plain Quantity, not a Distance.""" - result = d * u.Q(2.0, "s") + result = d * u.Q(2, "s") assert isinstance(result, u.AbstractQuantity) assert not isinstance(result, cxd.Distance) @@ -256,11 +250,11 @@ def test_add_associativity( def test_atan2_promotes_to_angle(self) -> None: """atan2(Distance, Distance) yields an angle-dimensioned Quantity.""" - result = qnp.atan2(cxd.Distance(1.0, "m"), cxd.Distance(3.0, "km")) + result = qnp.atan2(cxd.Distance(1, "m"), cxd.Distance(3, "km")) assert isinstance(result, u.AbstractQuantity) assert not isinstance(result, cxd.Distance) assert u.dimension_of(result) == u.dimension("angle") - assert jnp.allclose(result.ustrip("rad"), jnp.atan2(1.0, 3000.0)) + assert jnp.allclose(result.ustrip("rad"), jnp.atan2(1, 3000)) class TestDistanceConversionProperties: @@ -327,4 +321,4 @@ def test_grad_through_distance(self, d: cxd.Distance) -> None: """jax.grad differentiates through quaxed sum; d/dx sum(x) == 1.""" g = jax.grad(lambda x: qnp.sum(x).value)(d) assert isinstance(g, cxd.Distance) - assert jnp.allclose(g.value, 1.0, atol=1e-5) + assert jnp.allclose(g.value, 1, atol=1e-5) diff --git a/tests/unit/representations/test_change_basis.py b/tests/unit/representations/test_change_basis.py index 87a88fc6..e538fd7a 100644 --- a/tests/unit/representations/test_change_basis.py +++ b/tests/unit/representations/test_change_basis.py @@ -13,7 +13,7 @@ import coordinax.main as cx import coordinax.manifolds as cxm import coordinax.representations as cxr -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax.internal import QMatrix, UnitsMatrix from coordinax.representations._src.basis_change import _qm_triangular_solve @@ -124,15 +124,15 @@ def test_round_trip_spherical_non_cartesian(self): ) def test_round_trip_diagonal_metric(self): - metric = cxm.EuclideanMetric(3) + M = cxm.R3 v = {"r": u.Q(4, "m/s"), "theta": u.Q(0.5, "rad/s"), "phi": u.Q(0.25, "rad/s")} at = {"r": u.Q(3, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0, "rad")} v_phys = cxr.change_basis( - v, cxc.sph3d, metric, cxr.coord_basis, cxr.phys_basis, at=at + v, cxc.sph3d, M, cxr.coord_basis, cxr.phys_basis, at=at ) v_back = cxr.change_basis( - v_phys, cxc.sph3d, metric, cxr.phys_basis, cxr.coord_basis, at=at + v_phys, cxc.sph3d, M, cxr.phys_basis, cxr.coord_basis, at=at ) np.testing.assert_allclose( @@ -146,17 +146,19 @@ def test_round_trip_diagonal_metric(self): ) def test_round_trip_general_metric(self): - metric = cxm.InducedMetric( - cxm.TwoSphereIn3D(radius=u.Q(1, "km")), cxm.EuclideanMetric(3) + M = cxm.EmbeddedManifold( + intrinsic=cxm.S2, + ambient=cxm.R3, + embed_map=cxm.TwoSphereIn3D(radius=u.Q(1, "km")), ) v = {"theta": u.Q(1, "rad/s"), "phi": u.Q(2, "rad/s")} at = {"theta": u.Q(jnp.pi / 3, "rad"), "phi": u.Q(0, "rad")} v_phys = cxr.change_basis( - v, cxc.sph2, metric, cxr.coord_basis, cxr.phys_basis, at=at + v, cxc.sph2, M, cxr.coord_basis, cxr.phys_basis, at=at ) v_back = cxr.change_basis( - v_phys, cxc.sph2, metric, cxr.phys_basis, cxr.coord_basis, at=at + v_phys, cxc.sph2, M, cxr.phys_basis, cxr.coord_basis, at=at ) np.testing.assert_allclose( @@ -232,17 +234,165 @@ def single(v: dict[str, Any], at: dict[str, Any]) -> dict[str, Any]: np.testing.assert_allclose(batched["y"], 0) +class TestChangeBasisManifold: + """Tests for the manifold-based change_basis dispatches (Phase 3d). + + Covers the new overloads added in Phase 3d: + - Cartesian + any manifold → identity (precedence=1) + - AbstractChart + AbstractManifold, diagonal metric path (scale factors) + - AbstractChart + AbstractManifold, general Cholesky path + - Spherical3D + EuclideanManifold → delegates to chart-specific dispatch + """ + + def test_cartesian_with_manifold_coord_to_phys_is_identity(self): + """Dispatch 7: Cartesian + any manifold, coord→phys = identity.""" + v = {"x": u.Q(3.0, "m/s"), "y": u.Q(4.0, "m/s")} + at = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m")} + out = cxr.change_basis( + v, cxc.cart2d, cxm.R2, cxr.coord_basis, cxr.phys_basis, at=at + ) + np.testing.assert_allclose(u.ustrip("m/s", out["x"]), 3.0) + np.testing.assert_allclose(u.ustrip("m/s", out["y"]), 4.0) + + def test_cartesian_with_manifold_phys_to_coord_is_identity(self): + """Dispatch 7: Cartesian + any manifold, phys→coord = identity.""" + v = {"x": u.Q(3.0, "m/s"), "y": u.Q(4.0, "m/s")} + out = cxr.change_basis(v, cxc.cart2d, cxm.R2, cxr.phys_basis, cxr.coord_basis) + np.testing.assert_allclose(u.ustrip("m/s", out["x"]), 3.0) + np.testing.assert_allclose(u.ustrip("m/s", out["y"]), 4.0) + + def test_euclidean_sph3d_manifold_matches_no_metric(self): + """Dispatch 9 delegates to 8: EuclideanManifold+sph3d with no-metric.""" + v = {"r": u.Q(5, "m/s"), "theta": u.Q(1, "rad/s"), "phi": u.Q(2, "rad/s")} + at = {"r": u.Q(3, "m"), "theta": u.Q(0.5, "rad"), "phi": u.Q(0, "rad")} + + out_with = cxr.change_basis( + v, cxc.sph3d, cxm.R3, cxr.coord_basis, cxr.phys_basis, at=at + ) + out_without = cxr.change_basis( + v, cxc.sph3d, cxr.coord_basis, cxr.phys_basis, at=at + ) + + np.testing.assert_allclose( + u.ustrip("m/s", out_with["r"]), u.ustrip("m/s", out_without["r"]) + ) + np.testing.assert_allclose( + u.ustrip("m/s", out_with["theta"]), u.ustrip("m/s", out_without["theta"]) + ) + np.testing.assert_allclose( + u.ustrip("m/s", out_with["phi"]), u.ustrip("m/s", out_without["phi"]) + ) + + def test_diagonal_path_coord_to_phys_values(self): + """Diagonal path: verify scale-factor multiplication at theta=pi/2.""" + # h_r=1, h_theta=r=2, h_phi=r*sin(pi/2)=2 + v = {"r": u.Q(5, "m/s"), "theta": u.Q(1, "rad/s"), "phi": u.Q(2, "rad/s")} + at = {"r": u.Q(2, "m"), "theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(0, "rad")} + out = cxr.change_basis( + v, cxc.sph3d, cxm.R3, cxr.coord_basis, cxr.phys_basis, at=at + ) + np.testing.assert_allclose(u.ustrip("m/s", out["r"]), 5.0) + np.testing.assert_allclose(u.ustrip("m/s", out["theta"]), 2.0) # 1 * 2 + np.testing.assert_allclose(u.ustrip("m/s", out["phi"]), 4.0) # 2 * 2 + + def test_diagonal_path_phys_to_coord_values(self): + """Diagonal path: verify inverse scale-factor division at theta=pi/2.""" + # h_r=1, h_theta=2, h_phi=2 + v_phys = { + "r": u.Q(5.0, "m/s"), + "theta": u.Q(2.0, "m/s"), + "phi": u.Q(4.0, "m/s"), + } + at = {"r": u.Q(2, "m"), "theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(0, "rad")} + out = cxr.change_basis( + v_phys, cxc.sph3d, cxm.R3, cxr.phys_basis, cxr.coord_basis, at=at + ) + np.testing.assert_allclose(u.ustrip("m/s", out["r"]), 5.0) + np.testing.assert_allclose(u.ustrip("rad/s", out["theta"]), 1.0) # 2 / 2 + np.testing.assert_allclose(u.ustrip("rad/s", out["phi"]), 2.0) # 4 / 2 + + def test_cholesky_path_output_keys_and_units(self): + """Cholesky path (PullbackMetric): correct keys and speed dimension.""" + manifold = cxm.EmbeddedManifold( + intrinsic=cxm.S2, + ambient=cxm.R3, + embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), + ) + v = {"theta": u.Q(1.0, "rad/s"), "phi": u.Q(2.0, "rad/s")} + at = {"theta": u.Q(jnp.pi / 3, "rad"), "phi": u.Q(0.0, "rad")} + out = cxr.change_basis( + v, cxc.sph2, manifold, cxr.coord_basis, cxr.phys_basis, at=at + ) + assert set(out.keys()) == {"theta", "phi"} + assert u.dimension_of(out["theta"]) == u.dimension("speed") + assert u.dimension_of(out["phi"]) == u.dimension("speed") + + +class TestChangeBasisManifoldJAX: + """JAX transformation compatibility for manifold-based change_basis dispatches.""" + + def test_jit_diagonal_path(self): + v = {"r": u.Q(5.0, "m/s"), "theta": u.Q(1.0, "rad/s"), "phi": u.Q(1.0, "rad/s")} + at = { + "r": u.Q(2.0, "m"), + "theta": u.Q(jnp.pi / 2, "rad"), + "phi": u.Q(0.0, "rad"), + } + + @jax.jit + def run(v, at): + return cxr.change_basis( + v, cxc.sph3d, cxm.R3, cxr.coord_basis, cxr.phys_basis, at=at + ) + + out = run(v, at) + np.testing.assert_allclose(u.ustrip("m/s", out["r"]), 5.0) + np.testing.assert_allclose(u.ustrip("m/s", out["theta"]), 2.0) # h=2 + + def test_jit_cholesky_path(self): + manifold = cxm.EmbeddedManifold( + intrinsic=cxm.S2, + ambient=cxm.R3, + embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), + ) + v = {"theta": u.Q(1.0, "rad/s"), "phi": u.Q(2.0, "rad/s")} + at = {"theta": u.Q(jnp.pi / 3, "rad"), "phi": u.Q(0.0, "rad")} + + @jax.jit + def run(v, at): + return cxr.change_basis( + v, cxc.sph2, manifold, cxr.coord_basis, cxr.phys_basis, at=at + ) + + out = run(v, at) + assert set(out.keys()) == {"theta", "phi"} + + def test_jit_cartesian_with_manifold(self): + v = {"x": u.Q(3.0, "m/s"), "y": u.Q(4.0, "m/s")} + at = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "m")} + + @jax.jit + def run(v, at): + return cxr.change_basis( + v, cxc.cart2d, cxm.R2, cxr.coord_basis, cxr.phys_basis, at=at + ) + + out = run(v, at) + np.testing.assert_allclose(u.ustrip("m/s", out["x"]), 3.0) + np.testing.assert_allclose(u.ustrip("m/s", out["y"]), 4.0) + + class TestTriangularSolveBatching: """Regression tests for batched triangular solves used by basis conversion.""" def test_qm_triangular_solve_batched_rows_scaled_correctly(self): e_val = jnp.array([[[2, 1], [0, 4]], [[3, 2], [0, 5]]]) e_unit = UnitsMatrix(((u.unit("m"), u.unit("m")), (u.unit("m"), u.unit("m")))) - e = QuantityMatrix(e_val, unit=e_unit) + e = QMatrix(e_val, unit=e_unit) b_val = jnp.array([[5, 8], [7, 10]]) b_unit = UnitsMatrix((u.unit("m/s"), u.unit("m/s"))) - b = QuantityMatrix(b_val, unit=b_unit) + b = QMatrix(b_val, unit=b_unit) out = _qm_triangular_solve(e, b) diff --git a/tests/unit/representations/test_tangent_map.py b/tests/unit/representations/test_tangent_map.py index 45589d06..c27b342f 100644 --- a/tests/unit/representations/test_tangent_map.py +++ b/tests/unit/representations/test_tangent_map.py @@ -141,11 +141,7 @@ def test_representation_dispatch_with_phys_disp(self) -> None: def test_same_chart_identity_phys_basis(self) -> None: """Same-chart optimisation also holds for PhysicalBasis inputs.""" - v = { - "r": jnp.array(2.0), - "theta": jnp.array(-1.5), - "phi": jnp.array(0.25), - } + v = {"r": jnp.array(2.0), "theta": jnp.array(-1.5), "phi": jnp.array(0.25)} at = {"r": jnp.array(3.0), "theta": jnp.array(0.5), "phi": jnp.array(0.1)} result = cxr.tangent_map(v, cxc.sph3d, cxr.phys_basis, cxc.sph3d, at=at) @@ -274,17 +270,14 @@ def jitted_tangent_map(v, at): def test_vmap(self) -> None: """tangent_map can be vmap-ped over a batch of base points.""" vs = {"x": jnp.ones(3), "y": jnp.zeros(3)} - ats = { - "x": jnp.array([1.0, 2.0, 3.0]), - "y": jnp.zeros(3), - } + ats = {"x": jnp.array([1, 2, 3]), "y": jnp.zeros(3)} def single_map(v: dict[str, Any], at: dict[str, Any]) -> dict[str, Any]: return cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=at) batched = jax.vmap(single_map)(vs, ats) # At y=0, any x>0: dr/dx = x/r = 1, so dr = 1 always - np.testing.assert_allclose(batched["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(batched["r"], 1, atol=1e-6) class TestTangentMapSemanticPreservation: @@ -292,14 +285,14 @@ class TestTangentMapSemanticPreservation: def test_vel_rep(self) -> None: """tangent_map works with coord_vel representation.""" - v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} - at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + v = {"x": jnp.array(1), "y": jnp.array(0)} + at = {"x": jnp.array(1), "y": jnp.array(0)} result = cxr.tangent_map(v, cxc.cart2d, cxr.coord_vel, cxc.polar2d, at=at) - np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["r"], 1, atol=1e-6) def test_acc_rep(self) -> None: """tangent_map works with coord_acc representation.""" - v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} - at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + v = {"x": jnp.array(1), "y": jnp.array(0)} + at = {"x": jnp.array(1), "y": jnp.array(0)} result = cxr.tangent_map(v, cxc.cart2d, cxr.coord_acc, cxc.polar2d, at=at) - np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["r"], 1, atol=1e-6) diff --git a/tests/unit/transforms/conftest.py b/tests/unit/transforms/conftest.py index ce5157f6..b5c12cc2 100644 --- a/tests/unit/transforms/conftest.py +++ b/tests/unit/transforms/conftest.py @@ -10,7 +10,7 @@ import coordinax.frames as cxf import coordinax.main as cx import coordinax.transforms as cxfm -from coordinax.internal import QuantityMatrix +from coordinax.internal import QMatrix # =================================================================== # Transform fixtures @@ -65,7 +65,7 @@ def composed_op(translate_op, rotate_op): @pytest.fixture def array_3d(): """Bare JAX array [1, 0, 0].""" - return jnp.array([1.0, 0.0, 0.0]) + return jnp.array([1, 0, 0]) @pytest.fixture @@ -76,9 +76,9 @@ def quantity_3d(): @pytest.fixture def qmatrix_3d(): - """QuantityMatrix [1, 0, 0] with uniform km units.""" - return QuantityMatrix( - jnp.array([1.0, 0.0, 0.0]), + """QMatrix [1, 0, 0] with uniform km units.""" + return QMatrix( + jnp.array([1, 0, 0]), unit=(u.unit("km"), u.unit("km"), u.unit("km")), ) @@ -117,8 +117,8 @@ def coord_xfm_3d(): # Expected results after each transform applied to (1, 0, 0) km -EXPECTED_IDENTITY = (1.0, 0.0, 0.0) # no change -EXPECTED_ROTATE = (0.0, 1.0, 0.0) # 90° z-rotation -EXPECTED_REFLECT = (-1.0, 0.0, 0.0) # reflection across yz-plane -EXPECTED_TRANSLATE = (2.0, 0.0, 0.0) # +1 km in x -EXPECTED_COMPOSED = (0.0, 2.0, 0.0) # translate then rotate +EXPECTED_IDENTITY = (1, 0, 0) # no change +EXPECTED_ROTATE = (0, 1, 0) # 90° z-rotation +EXPECTED_REFLECT = (-1, 0, 0) # reflection across yz-plane +EXPECTED_TRANSLATE = (2, 0, 0) # +1 km in x +EXPECTED_COMPOSED = (0, 2, 0) # translate then rotate diff --git a/tests/unit/transforms/test_act.py b/tests/unit/transforms/test_act.py index 1dd99733..042a9fa6 100644 --- a/tests/unit/transforms/test_act.py +++ b/tests/unit/transforms/test_act.py @@ -1,7 +1,7 @@ """Red/Green TDD tests for ``coordinax.frames.act`` dispatches. Tests every combination of {Identity, Rotate, Translate, Composed} × -{Array, Quantity, QuantityMatrix, CDict, Vector, Point+Frame, Point+XfmFrame}. +{Array, Quantity, QMatrix, CDict, Vector, Point+Frame, Point+XfmFrame}. Each class tests: - correctness: known-value checks @@ -31,7 +31,7 @@ EXPECTED_ROTATE, EXPECTED_TRANSLATE, ) -from coordinax.internal import QuantityMatrix +from coordinax.internal import QMatrix ATOL = 1e-5 @@ -56,7 +56,7 @@ def _extract_xyz(result): z = float(u.ustrip("km", d["z"])) return (x, y, z) - if isinstance(result, QuantityMatrix): + if isinstance(result, QMatrix): x = float(u.ustrip("km", u.Q(result.value[0], result.unit[0]))) y = float(u.ustrip("km", u.Q(result.value[1], result.unit[1]))) z = float(u.ustrip("km", u.Q(result.value[2], result.unit[2]))) @@ -179,12 +179,12 @@ def test_rotate_with_chart_and_rep(self, rotate_op, quantity_3d): # =================================================================== -# Level 3: QuantityMatrix +# Level 3: QMatrix # =================================================================== -class TestActOnQuantityMatrix: - """Apply transforms to ``QuantityMatrix``. +class TestActOnQMatrix: + """Apply transforms to ``QMatrix``. This is expected to be RED initially — no dispatches exist. """ @@ -207,7 +207,7 @@ def test_composed(self, composed_op, qmatrix_3d): def test_returns_quantity_matrix(self, rotate_op, qmatrix_3d): result = cxfm.act(rotate_op, None, qmatrix_3d) - assert isinstance(result, QuantityMatrix) + assert isinstance(result, QMatrix) def test_rotate_roundtrip(self, rotate_op, qmatrix_3d): fwd = cxfm.act(rotate_op, None, qmatrix_3d) @@ -233,14 +233,13 @@ def test_rotate_with_chart_and_rep(self, rotate_op, qmatrix_3d): _assert_close(_extract_xyz(result), EXPECTED_ROTATE) def test_heterogeneous_units_identity(self, identity_op): - """QuantityMatrix with mixed units passes through Identity.""" - qm = QuantityMatrix( - jnp.array([1.0, 2.0, 3.0]), - unit=(u.unit("km"), u.unit("km"), u.unit("km")), + """QMatrix with mixed units passes through Identity.""" + qm = QMatrix( + jnp.array([1, 2, 3]), unit=(u.unit("km"), u.unit("km"), u.unit("km")) ) result = cxfm.act(identity_op, None, qm) - assert isinstance(result, QuantityMatrix) - _assert_close(_extract_xyz(result), (1.0, 2.0, 3.0)) + assert isinstance(result, QMatrix) + _assert_close(_extract_xyz(result), (1, 2, 3)) # =================================================================== @@ -463,7 +462,7 @@ def test_rotate_all_levels_agree( results.append(_extract_xyz(cxfm.act(rotate_op, None, array_3d))) # Level 2: Quantity results.append(_extract_xyz(cxfm.act(rotate_op, None, quantity_3d))) - # Level 3: QuantityMatrix + # Level 3: QMatrix results.append(_extract_xyz(cxfm.act(rotate_op, None, qmatrix_3d))) # Level 4: CDict results.append(_extract_xyz(cxfm.act(rotate_op, None, cdict_3d))) @@ -589,20 +588,12 @@ def rot90z(self): @pytest.fixture def at_sph(self): """Base point at the equator phi=0; Cartesian (1,0,0).""" - return { - "r": u.Q(1.0, "m"), - "theta": u.Q(jnp.pi / 2, "rad"), - "phi": u.Q(0.0, "rad"), - } + return {"r": u.Q(1, "m"), "theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(0, "rad")} @pytest.fixture def v_radial_sph(self): """Purely radial velocity in spherical coord-basis.""" - return { - "r": u.Q(1.0, "m/s"), - "theta": u.Q(0.0, "rad/s"), - "phi": u.Q(0.0, "rad/s"), - } + return {"r": u.Q(1, "m/s"), "theta": u.Q(0, "rad/s"), "phi": u.Q(0, "rad/s")} def test_cart_consistency(self, rot90z, at_sph, v_radial_sph): """cart(R*v at R*p) == R * cart(v at p).""" @@ -618,7 +609,7 @@ def test_cart_consistency(self, rot90z, at_sph, v_radial_sph): ) # Rotated base point in spherical (phi: 0 -> pi/2) at_sph_rot = { - "r": u.Q(1.0, "m"), + "r": u.Q(1, "m"), "theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(jnp.pi / 2, "rad"), } @@ -693,7 +684,7 @@ def test_jit(self, rot90z, at_sph, v_radial_sph): at=at_sph, ) )(v_radial_sph) - assert abs(float(result["r"].to_value("m/s")) - 1.0) < ATOL + assert abs(float(result["r"].to_value("m/s")) - 1) < ATOL # =================================================================== @@ -709,9 +700,9 @@ def test_cart3d_velocity_to_rotated_frame(self): rot = cxfm.Rotate.from_euler("z", u.Q(90, "deg")) rotated_frame = cxf.TransformedReferenceFrame(cxf.alice, rot) - point = cx.Point.from_([1.0, 0.0, 0.0], "m", cxf.alice) + point = cx.Point.from_([1, 0, 0], "m", cxf.alice) vel = cx.Tangent( - {"x": u.Q(1.0, "m/s"), "y": u.Q(0.0, "m/s"), "z": u.Q(0.0, "m/s")}, + {"x": u.Q(1, "m/s"), "y": u.Q(0, "m/s"), "z": u.Q(0, "m/s")}, cxc.cart3d, cxr.coord_basis, cxr.vel, @@ -723,20 +714,20 @@ def test_cart3d_velocity_to_rotated_frame(self): # Point (1,0,0) rotated 90° about z -> (0,1,0) _assert_close( ( - float(result.point.data["x"].to_value("m")), - float(result.point.data["y"].to_value("m")), - float(result.point.data["z"].to_value("m")), + float(result.point.data["x"].ustrip("m")), + float(result.point.data["y"].ustrip("m")), + float(result.point.data["z"].ustrip("m")), ), - (0.0, 1.0, 0.0), + (0, 1, 0), ) # Velocity (1,0,0) m/s rotated -> (0,1,0) m/s _assert_close( ( - float(result["velocity"].data["x"].to_value("m/s")), - float(result["velocity"].data["y"].to_value("m/s")), - float(result["velocity"].data["z"].to_value("m/s")), + float(result["velocity"].data["x"].ustrip("m/s")), + float(result["velocity"].data["y"].ustrip("m/s")), + float(result["velocity"].data["z"].ustrip("m/s")), ), - (0.0, 1.0, 0.0), + (0, 1, 0), ) def test_coordinate_to_frame_then_cconvert_sph(self): @@ -744,9 +735,9 @@ def test_coordinate_to_frame_then_cconvert_sph(self): rot = cxfm.Rotate.from_euler("z", u.Q(90, "deg")) rotated_frame = cxf.TransformedReferenceFrame(cxf.alice, rot) - point = cx.Point.from_([1.0, 0.0, 0.0], "m", cxf.alice) + point = cx.Point.from_([1, 0, 0], "m", cxf.alice) vel = cx.Tangent( - {"x": u.Q(1.0, "m/s"), "y": u.Q(0.0, "m/s"), "z": u.Q(0.0, "m/s")}, + {"x": u.Q(1, "m/s"), "y": u.Q(0, "m/s"), "z": u.Q(0, "m/s")}, cxc.cart3d, cxr.coord_basis, cxr.vel, @@ -756,10 +747,10 @@ def test_coordinate_to_frame_then_cconvert_sph(self): result = coord.to_frame(rotated_frame).cconvert(cxc.sph3d) # Point should land at (r=1, theta=pi/2, phi=pi/2) - assert abs(float(result.point.data["r"].to_value("m")) - 1.0) < ATOL + assert abs(float(result.point.data["r"].to_value("m")) - 1) < ATOL assert ( abs(float(result.point.data["theta"].to_value("rad")) - jnp.pi / 2) < ATOL ) assert abs(float(result.point.data["phi"].to_value("rad")) - jnp.pi / 2) < ATOL # Velocity should be purely radial (ṙ≈1, θ̇≈0, φ̇≈0) - assert abs(float(result["velocity"].data["r"].to_value("m/s")) - 1.0) < ATOL + assert abs(float(result["velocity"].data["r"].to_value("m/s")) - 1) < ATOL diff --git a/tests/unit/vectors/test_tangent.py b/tests/unit/vectors/test_tangent.py index dd6cc0ab..5f9d3ef1 100644 --- a/tests/unit/vectors/test_tangent.py +++ b/tests/unit/vectors/test_tangent.py @@ -545,15 +545,15 @@ def test_slice_preserves_chart_basis_semantic(self): def test_slice_values_correct(self): """Indexed values match the original batch element.""" data = { - "x": u.Q([1.0, 2.0], "m/s"), - "y": u.Q([3.0, 4.0], "m/s"), - "z": u.Q([5.0, 6.0], "m/s"), + "x": u.Q([1, 2], "m/s"), + "y": u.Q([3, 4], "m/s"), + "z": u.Q([5, 6], "m/s"), } v = cx.Tangent( data=data, chart=cxc.cart3d, basis=cxr.coord_basis, semantic=cxr.vel ) - assert v[1]["x"].value == jnp.array(2.0) - assert v[1]["y"].value == jnp.array(4.0) + assert v[1]["x"].value == jnp.array(2) + assert v[1]["y"].value == jnp.array(4) # ====================================================================== diff --git a/tests/usage/charts/test_jacobian.py b/tests/usage/charts/test_jacobian.py index 1effa0e1..3370f272 100644 --- a/tests/usage/charts/test_jacobian.py +++ b/tests/usage/charts/test_jacobian.py @@ -31,16 +31,14 @@ import unxt as u import coordinax.charts as cxc -from coordinax.internal import QuantityMatrix +from coordinax.internal import QMatrix # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _assert_jacobian_approx( - J1: QuantityMatrix, J2: QuantityMatrix, *, atol: float = 1e-5 -) -> None: +def _assert_jacobian_approx(J1: QMatrix, J2: QMatrix, *, atol: float = 1e-5) -> None: """Assert two Jacobians agree entry-wise (values only).""" np.testing.assert_allclose( np.asarray(J1.value), @@ -73,7 +71,7 @@ def test_curried_reuse_across_points(self) -> None: for x, y in [(1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]: at = {"x": u.Q(x, "m"), "y": u.Q(y, "m")} J = jac_fn(at) - assert isinstance(J, QuantityMatrix) + assert isinstance(J, QMatrix) assert J.value.shape == (2, 2) def test_none_partial_matches_direct(self) -> None: @@ -99,7 +97,7 @@ def test_jit_curried_cart3d_sph3d(self) -> None: at = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} jac_fn = jax.jit(cxc.jac_pt_map(cxc.cart3d, cxc.sph3d, usys=u.unitsystems.si)) J = jac_fn(at) - assert isinstance(J, QuantityMatrix) + assert isinstance(J, QMatrix) assert J.value.shape == (3, 3) # At (1,0,0): ∂r/∂x = 1 np.testing.assert_allclose(J.value[0, 0], 1.0, atol=1e-6) @@ -120,7 +118,7 @@ def jitted(at): return cxc.jac_pt_map(at, cxc.cart3d, cxc.sph3d) J = jitted(at) - assert isinstance(J, QuantityMatrix) + assert isinstance(J, QMatrix) np.testing.assert_allclose(J.value[0, 0], 1.0, atol=1e-6) @@ -139,9 +137,9 @@ def test_vmap_cart3d_sph3d(self) -> None: def single(x, y, z): return jac_fn({"x": u.Q(x, "m"), "y": u.Q(y, "m"), "z": u.Q(z, "m")}) - xs = jnp.array([1.0, 0.0, 3.0]) - ys = jnp.array([0.0, 1.0, 4.0]) - zs = jnp.array([0.0, 0.0, 0.0]) + xs = jnp.array([1, 0, 3]) + ys = jnp.array([0, 1, 4]) + zs = jnp.array([0, 0, 0]) Js = jax.vmap(single)(xs, ys, zs) assert Js.value.shape == (3, 3, 3) diff --git a/tests/usage/frames/test_act_usage.py b/tests/usage/frames/test_act_usage.py index 3b2dbfd6..ad27bd19 100644 --- a/tests/usage/frames/test_act_usage.py +++ b/tests/usage/frames/test_act_usage.py @@ -6,7 +6,7 @@ 1. ArrayLike (+usys) 2. Quantity -3. QuantityMatrix +3. QMatrix 4. CDict 5. Vector 6. Coordinate (with a concrete frame) @@ -29,7 +29,7 @@ import coordinax.main as cx import coordinax.representations as cxr import coordinax.transforms as cxfm -from coordinax.internal import QuantityMatrix +from coordinax.internal import QMatrix # =================================================================== # Helpers @@ -71,7 +71,7 @@ class TestIdentityUsage: """Identity transform leaves every object unchanged.""" def test_on_quantity(self): - q = u.Q([1.0, 2.0, 3.0], "km") + q = u.Q([1, 2, 3], "km") result = cxfm.act(cxfm.Identity(), None, q) assert result is q, "Identity should return the exact same object" @@ -100,51 +100,51 @@ def test_on_array(self, rot90z, usys): Must pass usys so the framework knows how to interpret the bare array as having metre units. """ - x = jnp.array([1.0, 0.0, 0.0]) + x = jnp.array([1, 0, 0]) result = cxfm.act(rot90z, None, x, cxc.cart3d, cxr.point, usys=usys) - _assert_close(result, [0.0, 1.0, 0.0]) + _assert_close(result, [0, 1, 0]) def test_on_quantity(self, rot90z): """Quantity [1,0,0] km → [0,1,0] km.""" - q = u.Q([1.0, 0.0, 0.0], "km") + q = u.Q([1, 0, 0], "km") result = cxfm.act(rot90z, None, q) - _assert_close(result.value, [0.0, 1.0, 0.0]) + _assert_close(result.value, [0, 1, 0]) assert result.unit == u.unit("km") def test_on_quantity_matrix(self, rot90z): - """QuantityMatrix (1,0,0) km → (0,1,0) km.""" - qm = QuantityMatrix( - jnp.array([1.0, 0.0, 0.0]), + """QMatrix (1,0,0) km → (0,1,0) km.""" + qm = QMatrix( + jnp.array([1, 0, 0]), unit=(u.unit("km"), u.unit("km"), u.unit("km")), ) result = cxfm.act(rot90z, None, qm) - assert isinstance(result, QuantityMatrix) - _assert_close(result.value, [0.0, 1.0, 0.0]) + assert isinstance(result, QMatrix) + _assert_close(result.value, [0, 1, 0]) def test_on_quantity_matrix_mixed_units(self, rot90z): - """QuantityMatrix with km,m,m: converted to common unit internally.""" - qm = QuantityMatrix( - jnp.array([1.0, 0.0, 0.0]), + """QMatrix with km,m,m: converted to common unit internally.""" + qm = QMatrix( + jnp.array([1, 0, 0]), unit=(u.unit("km"), u.unit("m"), u.unit("m")), ) result = cxfm.act(rot90z, None, qm) - assert isinstance(result, QuantityMatrix) + assert isinstance(result, QMatrix) # Internal conversion normalizes to common unit (km); x→0, y→1km, z→0 - _assert_close(result.value, [0.0, 1.0, 0.0], atol=1e-12) + _assert_close(result.value, [0, 1, 0], atol=1e-12) def test_on_cdict(self, rot90z): """CDict {x:1, y:0, z:0} km → {x:0, y:1, z:0} km.""" v = {"x": u.Q(1, "km"), "y": u.Q(0, "km"), "z": u.Q(0, "km")} result = cxfm.act(rot90z, None, v, cxc.cart3d, cxr.point) - _assert_close(result["y"].value, 1.0) - _assert_close(result["x"].value, 0.0) + _assert_close(result["y"].value, 1) + _assert_close(result["x"].value, 0) def test_on_vector(self, rot90z): """Vector(x=1km) → Vector(y=1km) under 90° z-rotation.""" vec = cx.Point.from_({"x": u.Q(1, "km"), "y": u.Q(0, "km"), "z": u.Q(0, "km")}) result = cxfm.act(rot90z, None, vec) - _assert_close(result.data["y"].value, 1.0) - _assert_close(result.data["x"].value, 0.0) + _assert_close(result.data["y"].value, 1) + _assert_close(result.data["x"].value, 0) def test_on_coordinate(self, rot90z): """Point at (1,0,0) km in Alice frame → (0,1,0) km.""" @@ -153,8 +153,8 @@ def test_on_coordinate(self, rot90z): cxf.alice, ) result = cxfm.act(rot90z, None, coord) - _assert_close(result.data["y"].value, 1.0) - _assert_close(result.data["x"].value, 0.0) + _assert_close(result.data["y"].value, 1) + _assert_close(result.data["x"].value, 0) # =================================================================== @@ -167,41 +167,40 @@ class TestTranslateUsage: def test_on_array(self, shift_1_2_3, usys): """Array [0,0,0] + shift [1,2,3] km → [1000,2000,3000] (in metres).""" - x = jnp.array([0.0, 0.0, 0.0]) + x = jnp.array([0, 0, 0]) result = cxfm.act(shift_1_2_3, None, x, cxc.cart3d, cxr.point, usys=usys) - _assert_close(result, [1000.0, 2000.0, 3000.0]) + _assert_close(result, [1000, 2000, 3000]) def test_on_quantity(self, shift_1_2_3): """Quantity [0,0,0] km + shift → [1,2,3] km.""" - q = u.Q([0.0, 0.0, 0.0], "km") + q = u.Q([0, 0, 0], "km") result = cxfm.act(shift_1_2_3, None, q) - _assert_close(result.value, [1.0, 2.0, 3.0]) + _assert_close(result.value, [1, 2, 3]) def test_on_quantity_matrix(self, shift_1_2_3): - """QuantityMatrix [0,0,0] km + shift → [1,2,3] km.""" - qm = QuantityMatrix( - jnp.array([0.0, 0.0, 0.0]), - unit=(u.unit("km"), u.unit("km"), u.unit("km")), + """QMatrix [0,0,0] km + shift → [1,2,3] km.""" + qm = QMatrix( + jnp.array([0, 0, 0]), unit=(u.unit("km"), u.unit("km"), u.unit("km")) ) result = cxfm.act(shift_1_2_3, None, qm) - assert isinstance(result, QuantityMatrix) - _assert_close(result.value, [1.0, 2.0, 3.0]) + assert isinstance(result, QMatrix) + _assert_close(result.value, [1, 2, 3]) def test_on_cdict(self, shift_1_2_3): """CDict {x:0, y:0, z:0} km → {x:1, y:2, z:3} km.""" v = {"x": u.Q(0, "km"), "y": u.Q(0, "km"), "z": u.Q(0, "km")} result = cxfm.act(shift_1_2_3, None, v, cxc.cart3d, cxr.point) - _assert_close(result["x"].value, 1.0) - _assert_close(result["y"].value, 2.0) - _assert_close(result["z"].value, 3.0) + _assert_close(result["x"].value, 1) + _assert_close(result["y"].value, 2) + _assert_close(result["z"].value, 3) def test_on_vector(self, shift_1_2_3): """Vector at origin + (1,2,3) km shift.""" vec = cx.Point.from_({"x": u.Q(0, "km"), "y": u.Q(0, "km"), "z": u.Q(0, "km")}) result = cxfm.act(shift_1_2_3, None, vec) - _assert_close(result.data["x"].value, 1.0) - _assert_close(result.data["y"].value, 2.0) - _assert_close(result.data["z"].value, 3.0) + _assert_close(result.data["x"].value, 1) + _assert_close(result.data["y"].value, 2) + _assert_close(result.data["z"].value, 3) def test_on_coordinate(self, shift_1_2_3): """Point at origin in Alice → translated to (1,2,3) km.""" @@ -210,9 +209,9 @@ def test_on_coordinate(self, shift_1_2_3): cxf.alice, ) result = cxfm.act(shift_1_2_3, None, coord) - _assert_close(result.data["x"].value, 1.0) - _assert_close(result.data["y"].value, 2.0) - _assert_close(result.data["z"].value, 3.0) + _assert_close(result.data["x"].value, 1) + _assert_close(result.data["y"].value, 2) + _assert_close(result.data["z"].value, 3) # =================================================================== @@ -230,35 +229,35 @@ def pipe(self, shift_1_2_3, rot90z): def test_on_quantity(self, pipe): """[0,0,0] km → translate → [1,2,3] km → rotate 90°z → [-2,1,3] km.""" - q = u.Q([0.0, 0.0, 0.0], "km") + q = u.Q([0, 0, 0], "km") result = cxfm.act(pipe, None, q) - _assert_close(result.value, [-2.0, 1.0, 3.0]) + _assert_close(result.value, [-2, 1, 3]) def test_on_quantity_matrix(self, pipe): - """QuantityMatrix through composed pipeline.""" - qm = QuantityMatrix( - jnp.array([0.0, 0.0, 0.0]), + """QMatrix through composed pipeline.""" + qm = QMatrix( + jnp.array([0, 0, 0]), unit=(u.unit("km"), u.unit("km"), u.unit("km")), ) result = cxfm.act(pipe, None, qm) - assert isinstance(result, QuantityMatrix) - _assert_close(result.value, [-2.0, 1.0, 3.0]) + assert isinstance(result, QMatrix) + _assert_close(result.value, [-2, 1, 3]) def test_on_cdict(self, pipe): """CDict through composed pipeline.""" v = {"x": u.Q(0, "km"), "y": u.Q(0, "km"), "z": u.Q(0, "km")} result = cxfm.act(pipe, None, v, cxc.cart3d, cxr.point) - _assert_close(result["x"].value, -2.0) - _assert_close(result["y"].value, 1.0) - _assert_close(result["z"].value, 3.0) + _assert_close(result["x"].value, -2) + _assert_close(result["y"].value, 1) + _assert_close(result["z"].value, 3) def test_on_vector(self, pipe): """Vector through composed pipeline.""" vec = cx.Point.from_({"x": u.Q(0, "km"), "y": u.Q(0, "km"), "z": u.Q(0, "km")}) result = cxfm.act(pipe, None, vec) - _assert_close(result.data["x"].value, -2.0) - _assert_close(result.data["y"].value, 1.0) - _assert_close(result.data["z"].value, 3.0) + _assert_close(result.data["x"].value, -2) + _assert_close(result.data["y"].value, 1) + _assert_close(result.data["z"].value, 3) def test_on_coordinate(self, pipe): """Point through composed pipeline.""" @@ -267,9 +266,9 @@ def test_on_coordinate(self, pipe): cxf.alice, ) result = cxfm.act(pipe, None, coord) - _assert_close(result.data["x"].value, -2.0) - _assert_close(result.data["y"].value, 1.0) - _assert_close(result.data["z"].value, 3.0) + _assert_close(result.data["x"].value, -2) + _assert_close(result.data["y"].value, 1) + _assert_close(result.data["z"].value, 3) # =================================================================== @@ -282,14 +281,14 @@ class TestRoundtripUsage: def test_rotate_roundtrip(self, rot90z): """Rotate then inverse-rotate a Quantity recovers original.""" - q = u.Q([3.0, -1.0, 2.0], "km") + q = u.Q([3, -1, 2], "km") fwd = cxfm.act(rot90z, None, q) back = cxfm.act(rot90z.inverse, None, fwd) _assert_close(back.value, q.value) def test_translate_roundtrip(self, shift_1_2_3): """Translate then inverse-translate recovers original.""" - q = u.Q([5.0, 7.0, -3.0], "km") + q = u.Q([5, 7, -3], "km") fwd = cxfm.act(shift_1_2_3, None, q) back = cxfm.act(shift_1_2_3.inverse, None, fwd) _assert_close(back.value, q.value) @@ -297,7 +296,7 @@ def test_translate_roundtrip(self, shift_1_2_3): def test_composed_roundtrip(self, rot90z, shift_1_2_3): """Composed pipeline and its inverse recover original.""" pipe = cxfm.Composed((shift_1_2_3, rot90z)) - q = u.Q([2.0, 4.0, 6.0], "km") + q = u.Q([2, 4, 6], "km") fwd = cxfm.act(pipe, None, q) back = cxfm.act(pipe.inverse, None, fwd) _assert_close(back.value, q.value) @@ -307,19 +306,19 @@ def test_roundtrip_on_vector(self, rot90z): vec = cx.Point.from_({"x": u.Q(3, "km"), "y": u.Q(-1, "km"), "z": u.Q(2, "km")}) fwd = cxfm.act(rot90z, None, vec) back = cxfm.act(rot90z.inverse, None, fwd) - _assert_close(back.data["x"].value, 3.0) - _assert_close(back.data["y"].value, -1.0) - _assert_close(back.data["z"].value, 2.0) + _assert_close(back.data["x"].value, 3) + _assert_close(back.data["y"].value, -1) + _assert_close(back.data["z"].value, 2) def test_roundtrip_on_quantity_matrix(self, rot90z): - """Rotate then inverse-rotate a QuantityMatrix recovers original.""" - qm = QuantityMatrix( - jnp.array([3.0, -1.0, 2.0]), + """Rotate then inverse-rotate a QMatrix recovers original.""" + qm = QMatrix( + jnp.array([3, -1, 2]), unit=(u.unit("km"), u.unit("km"), u.unit("km")), ) fwd = cxfm.act(rot90z, None, qm) back = cxfm.act(rot90z.inverse, None, fwd) - _assert_close(back.value, [3.0, -1.0, 2.0]) + _assert_close(back.value, [3, -1, 2]) # =================================================================== @@ -331,13 +330,13 @@ class TestCallSyntaxUsage: """Transforms are callable: op(x) == act(op, None, x).""" def test_rotate_call(self, rot90z): - q = u.Q([1.0, 0.0, 0.0], "km") + q = u.Q([1, 0, 0], "km") via_act = cxfm.act(rot90z, None, q) via_call = rot90z(q) _assert_close(via_call.value, via_act.value) def test_translate_call(self, shift_1_2_3): - q = u.Q([0.0, 0.0, 0.0], "km") + q = u.Q([0, 0, 0], "km") via_act = cxfm.act(shift_1_2_3, None, q) via_call = shift_1_2_3(q) _assert_close(via_call.value, via_act.value) From 752bab7f833e55e4b9af7a43d68c9b3a273a443d Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:28:39 -0400 Subject: [PATCH 06/15] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(coordinax.a?= =?UTF-8?q?stro):=20adapt=20to=20metric=20API=20renames=20and=20docstring?= =?UTF-8?q?=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove Examples section headers from docstrings in distance_modulus, frame_transforms, parallax, register_constructors, and register_converters. Update tests to use new metric and frame API conventions. --- .../coordinax/astro/_src/distance_modulus.py | 14 --------- .../coordinax/astro/_src/frame_transforms.py | 6 ---- .../src/coordinax/astro/_src/parallax.py | 12 ------- .../astro/_src/register_constructors.py | 6 ---- .../astro/_src/register_converters.py | 4 --- .../tests/hypothesis/test_distance_moduli.py | 2 +- .../tests/hypothesis/test_parallaxes.py | 10 +++--- .../tests/integration/test_plum.py | 4 +-- .../unit/distances/test_distance_modulus.py | 16 +++------- .../tests/unit/distances/test_parallax.py | 5 ++- .../tests/unit/test_frame_transforms.py | 31 ++++++------------- 11 files changed, 24 insertions(+), 86 deletions(-) diff --git a/packages/coordinax.astro/src/coordinax/astro/_src/distance_modulus.py b/packages/coordinax.astro/src/coordinax/astro/_src/distance_modulus.py index 647ea660..dc754c58 100644 --- a/packages/coordinax.astro/src/coordinax/astro/_src/distance_modulus.py +++ b/packages/coordinax.astro/src/coordinax/astro/_src/distance_modulus.py @@ -56,8 +56,6 @@ def from_( ) -> DistanceModulus: """Construct a distance. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import DistanceModulus @@ -74,8 +72,6 @@ def from_( ) -> DistanceModulus: """Compute distance modulus from distance modulus. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import DistanceModulus @@ -96,8 +92,6 @@ def from_( def from_(cls: type[DistanceModulus], dm: u.Q["mag"], /, **kw: Any) -> DistanceModulus: """Compute parallax from parallax. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import DistanceModulus @@ -119,8 +113,6 @@ def from_( ) -> DistanceModulus: """Compute distance modulus from distance. - Examples - -------- >>> import coordinax.distances as cxd >>> from coordinax.astro import DistanceModulus @@ -139,8 +131,6 @@ def from_( ) -> DistanceModulus: """Compute distance modulus from distance. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import DistanceModulus @@ -157,8 +147,6 @@ def from_( def from_(cls: type[DistanceModulus], p: u.Q["angle"], /, **kw: Any) -> DistanceModulus: """Compute distance modulus from parallax. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import DistanceModulus @@ -181,8 +169,6 @@ def from_( ) -> cxd.Distance: """Compute distance from distance modulus. - Examples - -------- >>> import coordinax.distances as cxd >>> from coordinax.astro import DistanceModulus diff --git a/packages/coordinax.astro/src/coordinax/astro/_src/frame_transforms.py b/packages/coordinax.astro/src/coordinax/astro/_src/frame_transforms.py index 6927b00a..81b72d89 100644 --- a/packages/coordinax.astro/src/coordinax/astro/_src/frame_transforms.py +++ b/packages/coordinax.astro/src/coordinax/astro/_src/frame_transforms.py @@ -26,8 +26,6 @@ def frame_transition( ) -> cxfm.Composed: """Compute frame transformations with ICRS as the intermediary. - Examples - -------- >>> import plum >>> import unxt as u >>> import coordinax.frames as cxf @@ -69,8 +67,6 @@ def frame_transition( def frame_transition(from_frame: ICRS, to_frame: ICRS, /) -> cxfm.Identity: """Return an identity operator for the ICRS->ICRS transformation. - Examples - -------- >>> import coordinax.frames as cxf >>> import coordinax.astro as cxastro @@ -92,8 +88,6 @@ def frame_transition( ) -> cxfm.Composed: """Return a sequence of operators for the Galactocentric frame self transformation. - Examples - -------- >>> import unxt as u >>> import coordinax.frames as cxf >>> import coordinax.astro as cxastro diff --git a/packages/coordinax.astro/src/coordinax/astro/_src/parallax.py b/packages/coordinax.astro/src/coordinax/astro/_src/parallax.py index e6419754..2b150960 100644 --- a/packages/coordinax.astro/src/coordinax/astro/_src/parallax.py +++ b/packages/coordinax.astro/src/coordinax/astro/_src/parallax.py @@ -83,8 +83,6 @@ def __check_init__(self) -> None: def from_(cls: type[Parallax], value: ArrayLike, unit: Any, /, **kw: Any) -> Parallax: """Construct a distance. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import Parallax @@ -99,8 +97,6 @@ def from_(cls: type[Parallax], value: ArrayLike, unit: Any, /, **kw: Any) -> Par def from_(cls: type[Parallax], p: Parallax, /, **kw: Any) -> Parallax: """Compute parallax from parallax. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import Parallax @@ -121,8 +117,6 @@ def from_(cls: type[Parallax], p: Parallax, /, **kw: Any) -> Parallax: def from_(cls: type[Parallax], p: u.Q["angle"], /, **kw: Any) -> Parallax: """Compute parallax from parallax. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import Parallax @@ -144,8 +138,6 @@ def from_( ) -> Parallax: """Compute parallax from distance. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import Parallax @@ -166,8 +158,6 @@ def from_( def from_(cls: type[Parallax], dm: u.Q["mag"], /, **kw: Any) -> Parallax: """Convert distance modulus to parallax. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import Parallax @@ -186,8 +176,6 @@ def from_(cls: type[Parallax], dm: u.Q["mag"], /, **kw: Any) -> Parallax: def from_(cls: type[cxd.Distance], p: Parallax, /, **kw: Any) -> cxd.Distance: """Compute distance from parallax. - Examples - -------- >>> import coordinax.distances as cxd >>> from coordinax.astro import Parallax diff --git a/packages/coordinax.astro/src/coordinax/astro/_src/register_constructors.py b/packages/coordinax.astro/src/coordinax/astro/_src/register_constructors.py index aefb6d6e..4e3b963b 100644 --- a/packages/coordinax.astro/src/coordinax/astro/_src/register_constructors.py +++ b/packages/coordinax.astro/src/coordinax/astro/_src/register_constructors.py @@ -13,10 +13,7 @@ def from_(cls: type[DistanceModulus], p: Parallax, /, **kw: Any) -> DistanceModulus: """Compute distance modulus from parallax. - Examples - -------- >>> from coordinax.astro import DistanceModulus, Parallax - >>> p = Parallax(1, "mas") >>> DistanceModulus.from_(p) DistanceModulus(10., 'mag') @@ -31,11 +28,8 @@ def from_(cls: type[DistanceModulus], p: Parallax, /, **kw: Any) -> DistanceModu def from_(cls: type[Parallax], dm: DistanceModulus, /, **kw: Any) -> Parallax: """Convert distance modulus to parallax. - Examples - -------- >>> import unxt as u >>> from coordinax.astro import DistanceModulus, Parallax - >>> dm = DistanceModulus(10, "mag") >>> Parallax.from_(dm).uconvert("mas").round(2) Parallax(1., 'mas') diff --git a/packages/coordinax.astro/src/coordinax/astro/_src/register_converters.py b/packages/coordinax.astro/src/coordinax/astro/_src/register_converters.py index eee8c840..77ddd729 100644 --- a/packages/coordinax.astro/src/coordinax/astro/_src/register_converters.py +++ b/packages/coordinax.astro/src/coordinax/astro/_src/register_converters.py @@ -14,8 +14,6 @@ def convert_quantity_to_distmod(q: u.AbstractQuantity, /) -> DistanceModulus: """Convert any quantity to a DistanceModulus. - Examples - -------- >>> from plum import convert >>> import unxt as u >>> from coordinax.astro import DistanceModulus @@ -40,8 +38,6 @@ def convert_quantity_to_distmod(q: u.AbstractQuantity, /) -> DistanceModulus: def convert_quantity_to_parallax(q: u.AbstractQuantity, /) -> Parallax: """Convert any quantity to a Parallax. - Examples - -------- >>> from plum import convert >>> import unxt as u >>> from coordinax.astro import Parallax diff --git a/packages/coordinax.astro/tests/hypothesis/test_distance_moduli.py b/packages/coordinax.astro/tests/hypothesis/test_distance_moduli.py index b794d876..26e88763 100644 --- a/packages/coordinax.astro/tests/hypothesis/test_distance_moduli.py +++ b/packages/coordinax.astro/tests/hypothesis/test_distance_moduli.py @@ -40,7 +40,7 @@ def test_distance_modulus_2d(dm: cxastro.DistanceModulus) -> None: def test_distance_modulus_with_custom_elements(dm: cxastro.DistanceModulus) -> None: """Test distance modulus with custom elements range.""" assert isinstance(dm, cxastro.DistanceModulus) - assert 0.0 <= dm.value <= 30.0 + assert 0 <= dm.value <= 30 assert dm.unit == "mag" diff --git a/packages/coordinax.astro/tests/hypothesis/test_parallaxes.py b/packages/coordinax.astro/tests/hypothesis/test_parallaxes.py index ac3c730b..79f51237 100644 --- a/packages/coordinax.astro/tests/hypothesis/test_parallaxes.py +++ b/packages/coordinax.astro/tests/hypothesis/test_parallaxes.py @@ -61,19 +61,17 @@ def test_parallax_with_strategy_check_negative(plx: cxastro.Parallax) -> None: @given( - plx=cxastrost.parallaxes( - elements=st.floats(min_value=1.0, max_value=100.0, width=32) - ) + plx=cxastrost.parallaxes(elements=st.floats(min_value=1, max_value=100, width=32)) ) def test_parallax_with_custom_elements(plx: cxastro.Parallax) -> None: """Test parallax with custom elements range.""" assert isinstance(plx, cxastro.Parallax) - assert 1.0 <= plx.value <= 100.0 + assert 1 <= plx.value <= 100 @given( plx=cxastrost.parallaxes( - check_negative=True, elements=st.floats(min_value=0.0, max_value=10.0, width=32) + check_negative=True, elements=st.floats(min_value=0, max_value=10, width=32) ) ) def test_parallax_check_negative_with_elements(plx: cxastro.Parallax) -> None: @@ -81,7 +79,7 @@ def test_parallax_check_negative_with_elements(plx: cxastro.Parallax) -> None: assert isinstance(plx, cxastro.Parallax) # When check_negative=True and elements provided, min_value should be adjusted assert plx.value >= 0 - assert plx.value <= 10.0 + assert plx.value <= 10 class TestParallaxFromType: diff --git a/packages/coordinax.astro/tests/integration/test_plum.py b/packages/coordinax.astro/tests/integration/test_plum.py index 067e8dc1..1c40bada 100644 --- a/packages/coordinax.astro/tests/integration/test_plum.py +++ b/packages/coordinax.astro/tests/integration/test_plum.py @@ -16,7 +16,7 @@ def test_promotion_rule(a): """Test the promotion rule for angles.""" # Quantities - q = u.Q(1.0, "rad") + q = u.Q(1, "rad") # Explicit promotion test a_p, q_p = plum.promote(a, q) @@ -102,7 +102,7 @@ def test_distance_to_dm_to_distance(self, d: cxd.Distance) -> None: d_back_pc = u.ustrip("pc", d_back) assert jnp.allclose(d_pc, d_back_pc, rtol=1e-4) - @given(dm=cxastrost.distance_moduli(elements={"min_value": 1.0, "max_value": 25.0})) + @given(dm=cxastrost.distance_moduli(elements={"min_value": 1, "max_value": 25})) @settings(deadline=None) def test_dm_to_distance_to_dm(self, dm: cxastro.DistanceModulus) -> None: """DM -> Distance -> DM roundtrip is consistent.""" diff --git a/packages/coordinax.astro/tests/unit/distances/test_distance_modulus.py b/packages/coordinax.astro/tests/unit/distances/test_distance_modulus.py index b4a99c07..4c56f957 100644 --- a/packages/coordinax.astro/tests/unit/distances/test_distance_modulus.py +++ b/packages/coordinax.astro/tests/unit/distances/test_distance_modulus.py @@ -59,7 +59,7 @@ def test_can_be_negative(self, dm: cxastro.DistanceModulus) -> None: def test_invalid_unit_raises(self) -> None: """DistanceModulus with non-mag unit raises ValueError.""" with pytest.raises(ValueError, match="magnitude"): - cxastro.DistanceModulus(15.0, "kpc") + cxastro.DistanceModulus(15, "kpc") class TestDistanceModulusArithmetic: @@ -76,7 +76,7 @@ def test_sub_distance_moduli(self, dm: cxastro.DistanceModulus) -> None: """DistanceModulus - DistanceModulus returns DistanceModulus with zero.""" result = dm - dm assert isinstance(result, cxastro.DistanceModulus) - assert jnp.allclose(result.value, 0.0) + assert jnp.allclose(result.value, 0) @given(dm=cxastrost.distance_moduli()) def test_scalar_mul(self, dm: cxastro.DistanceModulus) -> None: @@ -89,9 +89,7 @@ def test_scalar_mul(self, dm: cxastro.DistanceModulus) -> None: class TestDistanceModulusConversionProperties: """Tests for DistanceModulus conversion properties.""" - @given( - dm=cxastrost.distance_moduli(elements={"min_value": -5.0, "max_value": 25.0}) - ) + @given(dm=cxastrost.distance_moduli(elements={"min_value": -5, "max_value": 25})) @settings(deadline=None) def test_distance_property(self, dm: cxastro.DistanceModulus) -> None: """.distance property returns a Distance.""" @@ -147,18 +145,14 @@ def test_convert_to_quantity(self, dm: cxastro.DistanceModulus) -> None: assert q.unit is dm.unit assert q.value is dm.value - @given( - dm=cxastrost.distance_moduli(elements={"min_value": -5.0, "max_value": 25.0}) - ) + @given(dm=cxastrost.distance_moduli(elements={"min_value": -5, "max_value": 25})) @settings(deadline=None) def test_convert_to_distance(self, dm: cxastro.DistanceModulus) -> None: """Can convert DistanceModulus to Distance.""" d = plum.convert(dm, cxd.Distance) assert isinstance(d, cxd.Distance) - @given( - dm=cxastrost.distance_moduli(elements={"min_value": -5.0, "max_value": 25.0}) - ) + @given(dm=cxastrost.distance_moduli(elements={"min_value": -5, "max_value": 25})) @settings(deadline=None) def test_convert_to_parallax(self, dm: cxastro.DistanceModulus) -> None: """Can convert DistanceModulus to Parallax.""" diff --git a/packages/coordinax.astro/tests/unit/distances/test_parallax.py b/packages/coordinax.astro/tests/unit/distances/test_parallax.py index 29b3e3be..9b33e484 100644 --- a/packages/coordinax.astro/tests/unit/distances/test_parallax.py +++ b/packages/coordinax.astro/tests/unit/distances/test_parallax.py @@ -55,10 +55,9 @@ def test_scalar_default(self, plx: cxastro.Parallax) -> None: def test_negative_raises(self) -> None: """Parallax with negative value raises when check_negative=True.""" with pytest.raises( - (eqx.EquinoxRuntimeError, ValueError), - match="Parallax must be non-negative", + (eqx.EquinoxRuntimeError, ValueError), match="Parallax must be non-negative" ): - cxastro.Parallax(-1.0, "mas", check_negative=True) + cxastro.Parallax(-1, "mas", check_negative=True) @given(plx=cxastrost.parallaxes()) def test_has_value_and_unit(self, plx: cxastro.Parallax) -> None: diff --git a/packages/coordinax.astro/tests/unit/test_frame_transforms.py b/packages/coordinax.astro/tests/unit/test_frame_transforms.py index 4b69f635..b5db6ca2 100644 --- a/packages/coordinax.astro/tests/unit/test_frame_transforms.py +++ b/packages/coordinax.astro/tests/unit/test_frame_transforms.py @@ -90,37 +90,28 @@ def _astropy_gcf_to_icrs_xyz_pc(xyz_pc: Iterable[float], frame: cxastro.Galactoc @pytest.mark.parametrize( "xyz_pc", - [ - (0.0, 0.0, 0.0), - (100.0, -20.0, 50.0), - (-5000.0, 3200.0, 1200.0), - ], + [(0, 0, 0), (100, -20, 50), (-5000, 3200, 1200)], ) def test_icrs_to_galactocentric_matches_astropy_positions(xyz_pc) -> None: """ICRS->Galactocentric position transforms match Astropy.""" gcf = cxastro.Galactocentric() op = cxf.frame_transition(cxastro.ICRS(), gcf) - got = _to_np(cxfm.act(op, None, u.Q(jnp.asarray(xyz_pc), "pc")), "pc") + got = cxfm.act(op, None, u.Q(jnp.asarray(xyz_pc), "pc")).ustrip("pc") expected = _astropy_icrs_to_gcf_xyz_pc(xyz_pc, gcf) np.testing.assert_allclose(got, expected, rtol=0.0, atol=1e-6) @pytest.mark.parametrize( - "xyz_pc", - [ - (-8122.0, 0.0, 21.0), - (-7800.0, 600.0, -200.0), - (-9200.0, -500.0, 300.0), - ], + "xyz_pc", [(-8122, 0, 21), (-7800, 600, -200), (-9200, -500, 300)] ) def test_galactocentric_to_icrs_matches_astropy_positions(xyz_pc) -> None: """Galactocentric->ICRS position transforms match Astropy.""" gcf = cxastro.Galactocentric() op = cxf.frame_transition(gcf, cxastro.ICRS()) - got = _to_np(cxfm.act(op, None, u.Q(jnp.asarray(xyz_pc), "pc")), "pc") + got = cxfm.act(op, None, u.Q(jnp.asarray(xyz_pc), "pc")).ustrip("pc") expected = _astropy_gcf_to_icrs_xyz_pc(xyz_pc, gcf) np.testing.assert_allclose(got, expected, rtol=0.0, atol=1e-6) @@ -134,7 +125,7 @@ def test_icrs_galactocentric_transitions_are_inverse_for_positions() -> None: fwd = cxf.frame_transition(icrs, gcf) bwd = cxf.frame_transition(gcf, icrs) - q = u.Q(jnp.asarray([450.0, -100.0, 220.0]), "pc") + q = u.Q(jnp.asarray([450, -100, 220]), "pc") back = cxfm.act(bwd, None, cxfm.act(fwd, None, q)) np.testing.assert_allclose(_to_np(back, "pc"), _to_np(q, "pc"), rtol=0.0, atol=1e-6) @@ -165,7 +156,7 @@ def test_icrs_gcf_icrs_roundtrip(self, q: u.AbstractQuantity) -> None: back = cxfm.act(bwd, None, cxfm.act(fwd, None, q)) np.testing.assert_allclose( - _to_np(back, "pc"), _to_np(q, "pc"), rtol=0.0, atol=1e-6 + _to_np(back, "pc"), _to_np(q, "pc"), rtol=0, atol=1e-6 ) @given( @@ -185,9 +176,7 @@ def test_gcf_icrs_gcf_roundtrip(self, q: u.AbstractQuantity) -> None: bwd = cxf.frame_transition(icrs, gcf) back = cxfm.act(bwd, None, cxfm.act(fwd, None, q)) - np.testing.assert_allclose( - _to_np(back, "pc"), _to_np(q, "pc"), rtol=0.0, atol=1e-6 - ) + np.testing.assert_allclose(back.ustrip("pc"), q.ustrip("pc"), rtol=s, atol=1e-6) @given( q=ust.quantities( @@ -216,7 +205,7 @@ def test_inverse_is_frame_transition_in_reverse( via_bwd = cxfm.act(bwd, None, q_gcf) np.testing.assert_allclose( - _to_np(via_inverse, "pc"), _to_np(via_bwd, "pc"), rtol=0.0, atol=1e-6 + via_inverse.ustrip("pc"), via_bwd.ustrip("pc"), rtol=0.0, atol=1e-6 ) @given( @@ -234,7 +223,7 @@ def test_icrs_to_gcf_matches_astropy_on_random_positions( gcf = cxastro.Galactocentric() op = cxf.frame_transition(cxastro.ICRS(), gcf) - xyz = _to_np(q, "pc") - got = _to_np(cxfm.act(op, None, q), "pc") + xyz = q.ustrip("pc") + got = cxfm.act(op, None, q).ustrip("pc") expected = _astropy_icrs_to_gcf_xyz_pc((xyz[0], xyz[1], xyz[2]), gcf) np.testing.assert_allclose(got, expected, rtol=0.0, atol=1e-6) From 7aa4441ad3d5fbcb9b1cb6081de11d7ae1d03cfe Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:29:24 -0400 Subject: [PATCH 07/15] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(coordinax.c?= =?UTF-8?q?urveframes):=20adapt=20to=20updated=20frame=20and=20manifold=20?= =?UTF-8?q?API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove unused register_frames dispatches and update tests to match the current frame and manifold API conventions. --- .../src/coordinax/curveframes/_src/base.py | 2 +- .../src/coordinax/curveframes/_src/bishop.py | 4 +- .../curveframes/_src/frenetserret.py | 1 + .../curveframes/_src/register_frames.py | 24 --- .../tests/test_bishop.py | 142 +++++++++--------- .../tests/test_bishop_frame.py | 48 +++--- .../tests/test_frenet_serret.py | 126 ++++++++-------- .../tests/test_frenet_serret_frame.py | 54 ++++--- 8 files changed, 186 insertions(+), 215 deletions(-) diff --git a/packages/coordinax.curveframes/src/coordinax/curveframes/_src/base.py b/packages/coordinax.curveframes/src/coordinax/curveframes/_src/base.py index d7be0365..c9860e3a 100644 --- a/packages/coordinax.curveframes/src/coordinax/curveframes/_src/base.py +++ b/packages/coordinax.curveframes/src/coordinax/curveframes/_src/base.py @@ -300,7 +300,7 @@ def tangent(self, tau: u.AbstractQuantity, /) -> Any: Q([-0., 1., 0.], '') """ - R = self._rotation_matrix(tau) + R = self._rotation_matrix(tau.astype(float)) return u.Q(R[0], "") # --------------------------------------------------------------- diff --git a/packages/coordinax.curveframes/src/coordinax/curveframes/_src/bishop.py b/packages/coordinax.curveframes/src/coordinax/curveframes/_src/bishop.py index ffb0dbc0..e514d45a 100644 --- a/packages/coordinax.curveframes/src/coordinax/curveframes/_src/bishop.py +++ b/packages/coordinax.curveframes/src/coordinax/curveframes/_src/bishop.py @@ -340,10 +340,10 @@ def from_curve( def tangent_fn(tau: u.AbstractQuantity) -> u.AbstractQuantity: r"""Compute unit tangent $\mathbf{T} = \gamma'/\|\gamma'\|$.""" - return _normalize(dcurve(tau)) + return _normalize(dcurve(tau.astype(float))) # Compute initial tangent and normal at the reference parameter. - T0 = tangent_fn(tau_0) + T0 = tangent_fn(tau_0) # dimensionless unit vector T0_val = T0.value # dimensionless plain array if initial_normal is not None: diff --git a/packages/coordinax.curveframes/src/coordinax/curveframes/_src/frenetserret.py b/packages/coordinax.curveframes/src/coordinax/curveframes/_src/frenetserret.py index b5cb9c22..f260efb9 100644 --- a/packages/coordinax.curveframes/src/coordinax/curveframes/_src/frenetserret.py +++ b/packages/coordinax.curveframes/src/coordinax/curveframes/_src/frenetserret.py @@ -281,6 +281,7 @@ def rotation_matrix_fn(tau: u.AbstractQuantity) -> Any: 3. Cross product: $\mathbf{B} = \mathbf{T} \times \mathbf{N}$. 4. Stack rows into a $3 \times 3$ matrix. """ + tau = tau.astype(float) dp = dcurve(tau) d2p = d2curve(tau) diff --git a/packages/coordinax.curveframes/src/coordinax/curveframes/_src/register_frames.py b/packages/coordinax.curveframes/src/coordinax/curveframes/_src/register_frames.py index 19aae73f..1cfd18f5 100644 --- a/packages/coordinax.curveframes/src/coordinax/curveframes/_src/register_frames.py +++ b/packages/coordinax.curveframes/src/coordinax/curveframes/_src/register_frames.py @@ -98,18 +98,6 @@ def frame_transition( # noqa: F811 (\mathcal{B} \to \mathcal{A}) $$ - Parameters - ---------- - from_frame : AbstractParallelTransportFrame - The source curve-attached frame. - to_frame : AbstractReferenceFrame - The target (ambient) frame. - - Returns - ------- - AbstractTransform - The composed frame-transition operator. - Examples -------- >>> import jax.numpy as jnp @@ -153,18 +141,6 @@ def frame_transition( # noqa: F811 general to-curve-frame and from-curve-frame dispatches when both arguments are ``AbstractParallelTransportFrame``. - Parameters - ---------- - from_frame : AbstractParallelTransportFrame - The source curve-attached frame. - to_frame : AbstractParallelTransportFrame - The target curve-attached frame. - - Returns - ------- - AbstractTransform - The composed frame-transition operator. - Examples -------- >>> import jax.numpy as jnp diff --git a/packages/coordinax.curveframes/tests/test_bishop.py b/packages/coordinax.curveframes/tests/test_bishop.py index 6dee7eb8..1914b048 100644 --- a/packages/coordinax.curveframes/tests/test_bishop.py +++ b/packages/coordinax.curveframes/tests/test_bishop.py @@ -12,13 +12,13 @@ # ── Fixtures ────────────────────────────────────────────────────────── -def _circle_curve(tau: u.Quantity) -> u.Quantity: +def _circle_curve(tau: u.Q) -> u.Q: """Unit circle in the x-y plane, period = 2*pi seconds.""" t = tau.ustrip("s") return u.Q(jnp.stack([jnp.cos(t), jnp.sin(t), jnp.zeros_like(t)]), "km") -def _straight_line(tau: u.Quantity) -> u.Quantity: +def _straight_line(tau: u.Q) -> u.Q: """Straight line along x-axis (kappa=0 everywhere). Frenet-Serret frame is singular on this curve, but Bishop is not. @@ -27,19 +27,19 @@ def _straight_line(tau: u.Quantity) -> u.Quantity: return u.Q(jnp.stack([t, jnp.zeros_like(t), jnp.zeros_like(t)]), "km") -def _helix_curve(tau: u.Quantity) -> u.Quantity: +def _helix_curve(tau: u.Q) -> u.Q: """Helix with pitch along z-axis.""" t = tau.ustrip("s") return u.Q(jnp.stack([jnp.cos(t), jnp.sin(t), 0.3 * t]), "km") -def _circle_curve_yr(tau: u.Quantity) -> u.Quantity: +def _circle_curve_yr(tau: u.Q) -> u.Q: """Circle in x-y plane with tau in years.""" omega = u.Q(2 * jnp.pi, "rad/yr") phase = (omega * tau).uconvert("rad").ustrip("rad") - x = u.Q(5.0, "km") * jnp.cos(phase) - y = u.Q(5.0, "km") * jnp.sin(phase) - z = u.Q(0.0, "km") * jnp.ones_like(phase) + x = u.Q(5, "km") * jnp.cos(phase) + y = u.Q(5, "km") * jnp.sin(phase) + z = u.Q(0, "km") * jnp.ones_like(phase) return qnp.stack([x, y, z], axis=-1) @@ -70,8 +70,8 @@ class TestBishopTransformLocation: """The location field should be the curve itself.""" def test_location_at_zero(self, circle_bishop: cxfc.BishopTransform): - loc = circle_bishop.location(u.Q(0.0, "s")) - expected = u.Q(jnp.array([1.0, 0.0, 0.0]), "km") + loc = circle_bishop.location(u.Q(0, "s")) + expected = u.Q(jnp.array([1, 0, 0]), "km") assert jnp.allclose(loc.value, expected.value, atol=1e-6) def test_location_is_curve(self, circle_bishop: cxfc.BishopTransform): @@ -86,25 +86,25 @@ class TestBishopTransformTangent: def test_tangent_at_zero(self, circle_bishop: cxfc.BishopTransform): """At tau=0 on a unit circle, T = (0, 1, 0).""" - T = circle_bishop.tangent(u.Q(0.0, "s")) - expected = jnp.array([0.0, 1.0, 0.0]) + T = circle_bishop.tangent(u.Q(0, "s")) + expected = jnp.array([0, 1, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_tangent_at_pi_over_2(self, circle_bishop: cxfc.BishopTransform): """At tau=pi/2, T = (-1, 0, 0).""" T = circle_bishop.tangent(u.Q(jnp.pi / 2, "s")) - expected = jnp.array([-1.0, 0.0, 0.0]) + expected = jnp.array([-1, 0, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_tangent_is_unit_vector(self, circle_bishop: cxfc.BishopTransform): T = circle_bishop.tangent(u.Q(1.23, "s")) norm = jnp.sqrt(jnp.sum(T.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-5) + assert jnp.allclose(norm, 1, atol=1e-5) def test_tangent_straight_line(self, line_bishop: cxfc.BishopTransform): """Tangent of a straight line along x is always (1,0,0).""" - T = line_bishop.tangent(u.Q(5.0, "s")) - expected = jnp.array([1.0, 0.0, 0.0]) + T = line_bishop.tangent(u.Q(5, "s")) + expected = jnp.array([1, 0, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) @@ -117,44 +117,44 @@ class TestBishopTransformNormals: def test_normal1_is_unit_vector(self, circle_bishop: cxfc.BishopTransform): U1 = circle_bishop.normal1(u.Q(0.7, "s")) norm = jnp.sqrt(jnp.sum(U1.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-5) + assert jnp.allclose(norm, 1, atol=1e-5) def test_normal2_is_unit_vector(self, circle_bishop: cxfc.BishopTransform): U2 = circle_bishop.normal2(u.Q(0.7, "s")) norm = jnp.sqrt(jnp.sum(U2.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-5) + assert jnp.allclose(norm, 1, atol=1e-5) def test_normals_perpendicular_to_tangent( self, circle_bishop: cxfc.BishopTransform ): - tau = u.Q(1.0, "s") + tau = u.Q(1, "s") T = circle_bishop.tangent(tau).value U1 = circle_bishop.normal1(tau).value U2 = circle_bishop.normal2(tau).value - assert jnp.allclose(jnp.dot(T, U1), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(T, U2), 0.0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U1), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U2), 0, atol=1e-4) def test_normals_perpendicular_to_each_other( self, circle_bishop: cxfc.BishopTransform ): - tau = u.Q(1.0, "s") + tau = u.Q(1, "s") U1 = circle_bishop.normal1(tau).value U2 = circle_bishop.normal2(tau).value - assert jnp.allclose(jnp.dot(U1, U2), 0.0, atol=1e-4) + assert jnp.allclose(jnp.dot(U1, U2), 0, atol=1e-4) def test_straight_line_normals_exist(self, line_bishop: cxfc.BishopTransform): """Bishop normals are defined even on a straight line (kappa=0).""" - U1 = line_bishop.normal1(u.Q(1.0, "s")) - U2 = line_bishop.normal2(u.Q(1.0, "s")) + U1 = line_bishop.normal1(u.Q(1, "s")) + U2 = line_bishop.normal2(u.Q(1, "s")) norm1 = jnp.sqrt(jnp.sum(U1.value**2)) norm2 = jnp.sqrt(jnp.sum(U2.value**2)) - assert jnp.allclose(norm1, 1.0, atol=1e-4) - assert jnp.allclose(norm2, 1.0, atol=1e-4) + assert jnp.allclose(norm1, 1, atol=1e-4) + assert jnp.allclose(norm2, 1, atol=1e-4) def test_straight_line_normals_constant(self, line_bishop: cxfc.BishopTransform): """On a straight line, parallel transport keeps U1, U2 constant.""" - U1_0 = line_bishop.normal1(u.Q(0.0, "s")).value - U1_5 = line_bishop.normal1(u.Q(5.0, "s")).value + U1_0 = line_bishop.normal1(u.Q(0, "s")).value + U1_5 = line_bishop.normal1(u.Q(5, "s")).value assert jnp.allclose(U1_0, U1_5, atol=1e-4) @@ -164,17 +164,17 @@ def test_straight_line_normals_constant(self, line_bishop: cxfc.BishopTransform) class TestBishopTransformOrthonormality: """T, U1, U2 should form an orthonormal right-handed triad.""" - @pytest.mark.parametrize("tau_val", [0.0, 0.5, 1.0, 2.5, jnp.pi]) + @pytest.mark.parametrize("tau_val", [0, 0.5, 1, 2.5, jnp.pi]) def test_orthogonality(self, circle_bishop: cxfc.BishopTransform, tau_val: float): tau = u.Q(tau_val, "s") T = circle_bishop.tangent(tau).value U1 = circle_bishop.normal1(tau).value U2 = circle_bishop.normal2(tau).value - assert jnp.allclose(jnp.dot(T, U1), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(T, U2), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(U1, U2), 0.0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U1), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U2), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(U1, U2), 0, atol=1e-4) - @pytest.mark.parametrize("tau_val", [0.0, 1.0, jnp.pi]) + @pytest.mark.parametrize("tau_val", [0, 1, jnp.pi]) def test_right_handed(self, circle_bishop: cxfc.BishopTransform, tau_val: float): """U2 should equal T x U1 (right-handed frame).""" tau = u.Q(tau_val, "s") @@ -191,18 +191,18 @@ class TestBishopTransformOpaqueUnits: """Test with a curve whose internal unit (yr) differs from caller's.""" def test_tangent_at_zero(self, circle_yr_bishop: cxfc.BishopTransform): - T = circle_yr_bishop.tangent(u.Q(0.0, "yr")) - expected = jnp.array([0.0, 1.0, 0.0]) + T = circle_yr_bishop.tangent(u.Q(0, "yr")) + expected = jnp.array([0, 1, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_normals_orthogonal_at_zero(self, circle_yr_bishop: cxfc.BishopTransform): - tau = u.Q(0.0, "yr") + tau = u.Q(0, "yr") T = circle_yr_bishop.tangent(tau).value U1 = circle_yr_bishop.normal1(tau).value U2 = circle_yr_bishop.normal2(tau).value - assert jnp.allclose(jnp.dot(T, U1), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(T, U2), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(U1, U2), 0.0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U1), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U2), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(U1, U2), 0, atol=1e-4) # ── tau_0 parameter ────────────────────────────────────────────────── @@ -212,16 +212,16 @@ class TestBishopTransformTau0: """The tau_0 field sets the reference parameter.""" def test_default_tau_0(self, circle_bishop: cxfc.BishopTransform): - """Default tau_0 is Q(0.0, tau_unit).""" - assert jnp.allclose(circle_bishop.tau_0.value, 0.0) + """Default tau_0 is Q(0, tau_unit).""" + assert jnp.allclose(circle_bishop.tau_0.value, 0) def test_custom_tau_0(self): """Custom tau_0 shifts the origin of parallel transport.""" - bt = cxfc.BishopTransform.from_curve(_circle_curve, tau_0=u.Q(1.0, "s")) + bt = cxfc.BishopTransform.from_curve(_circle_curve, tau_0=u.Q(1, "s")) # The tangent is still computed correctly - T = bt.tangent(u.Q(1.0, "s")) + T = bt.tangent(u.Q(1, "s")) norm = jnp.sqrt(jnp.sum(T.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-5) + assert jnp.allclose(norm, 1, atol=1e-5) def test_initial_normal_field(self): """initial_normal is stored for reconstruction.""" @@ -238,21 +238,21 @@ class TestBishopTransformJAX: """Verify compatibility with jit and vmap.""" def test_jit_tangent(self, circle_bishop: cxfc.BishopTransform): - T = jax.jit(circle_bishop.tangent)(u.Q(0.0, "s")) - expected = jnp.array([0.0, 1.0, 0.0]) + T = jax.jit(circle_bishop.tangent)(u.Q(0, "s")) + expected = jnp.array([0, 1, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_jit_normal1(self, circle_bishop: cxfc.BishopTransform): """normal1 (ODE-based) works under jit.""" U1 = jax.jit(circle_bishop.normal1)(u.Q(0.5, "s")) norm = jnp.sqrt(jnp.sum(U1.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-4) + assert jnp.allclose(norm, 1, atol=1e-4) def test_vmap(self, circle_bishop: cxfc.BishopTransform): taus = u.Q(jnp.linspace(0, 2 * jnp.pi, 8), "s") Ts = jax.vmap(circle_bishop.tangent)(taus) norms = jnp.sqrt(jnp.sum(Ts.value**2, axis=-1)) - assert jnp.allclose(norms, 1.0, atol=1e-5) + assert jnp.allclose(norms, 1, atol=1e-5) # ── from_ constructor ──────────────────────────────────────────────── @@ -263,8 +263,8 @@ class TestBishopTransformFrom: def test_from_dispatches_to_from_curve(self): bt = cxfc.BishopTransform.from_(_circle_curve) - loc = bt.location(u.Q(0.0, "s")) - expected = u.Q(jnp.array([1.0, 0.0, 0.0]), "km") + loc = bt.location(u.Q(0, "s")) + expected = u.Q(jnp.array([1, 0, 0]), "km") assert jnp.allclose(loc.value, expected.value, atol=1e-6) @@ -278,7 +278,7 @@ def test_inverse_is_bishop_transform(self, circle_bishop: cxfc.BishopTransform): inv = circle_bishop.inverse assert isinstance(inv, cxfc.BishopTransform) - @pytest.mark.parametrize("tau_val", [0.0, 0.5, 1.0, 2.5, jnp.pi]) + @pytest.mark.parametrize("tau_val", [0, 0.5, 1, 2.5, jnp.pi]) def test_inverse_orthonormality( self, circle_bishop: cxfc.BishopTransform, tau_val: float ): @@ -287,20 +287,20 @@ def test_inverse_orthonormality( T = inv.tangent(tau).value U1 = inv.normal1(tau).value U2 = inv.normal2(tau).value - assert jnp.allclose(jnp.dot(T, U1), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(T, U2), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(U1, U2), 0.0, atol=1e-4) - assert jnp.allclose(jnp.linalg.norm(T), 1.0, atol=1e-4) - assert jnp.allclose(jnp.linalg.norm(U1), 1.0, atol=1e-4) - assert jnp.allclose(jnp.linalg.norm(U2), 1.0, atol=1e-4) - - @pytest.mark.parametrize("tau_val", [0.0, 1.0, jnp.pi]) + assert jnp.allclose(jnp.dot(T, U1), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U2), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(U1, U2), 0, atol=1e-4) + assert jnp.allclose(jnp.linalg.norm(T), 1, atol=1e-4) + assert jnp.allclose(jnp.linalg.norm(U1), 1, atol=1e-4) + assert jnp.allclose(jnp.linalg.norm(U2), 1, atol=1e-4) + + @pytest.mark.parametrize("tau_val", [0, 1, jnp.pi]) def test_roundtrip_forward_inverse( self, circle_bishop: cxfc.BishopTransform, tau_val: float ): """Apply forward then inverse: recovers original point.""" tau = u.Q(tau_val, "s") - p = u.Q(jnp.array([2.0, 3.0, 4.0]), "km") + p = u.Q(jnp.array([2, 3, 4]), "km") # Forward: p' = R @ (p - gamma) g = circle_bishop.location(tau) @@ -327,7 +327,7 @@ def test_roundtrip_forward_inverse( assert jnp.allclose(p_rec.value, p.value, atol=1e-3) - @pytest.mark.parametrize("tau_val", [0.0, 1.0, jnp.pi]) + @pytest.mark.parametrize("tau_val", [0, 1, jnp.pi]) def test_double_inverse(self, circle_bishop: cxfc.BishopTransform, tau_val: float): """inverse.inverse should recover the original frame fields.""" tau = u.Q(tau_val, "s") @@ -341,14 +341,14 @@ def test_double_inverse(self, circle_bishop: cxfc.BishopTransform, tau_val: floa def test_inverse_jit(self, circle_bishop: cxfc.BishopTransform): inv = circle_bishop.inverse - loc = jax.jit(inv.location)(u.Q(0.0, "s")) + loc = jax.jit(inv.location)(u.Q(0, "s")) # Location is defined, check it runs assert loc.shape == (3,) def test_inverse_opaque_units(self, circle_yr_bishop: cxfc.BishopTransform): """Inverse should work with opaque-unit curves too.""" inv = circle_yr_bishop.inverse - tau = u.Q(0.0, "yr") + tau = u.Q(0, "yr") loc = inv.location(tau) assert loc.shape == (3,) @@ -360,23 +360,23 @@ class TestBishopTransformHelix: """Bishop frame on a helix — tests 3D behaviour.""" def test_tangent_is_unit(self, helix_bishop: cxfc.BishopTransform): - T = helix_bishop.tangent(u.Q(1.0, "s")) + T = helix_bishop.tangent(u.Q(1, "s")) norm = jnp.sqrt(jnp.sum(T.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-5) + assert jnp.allclose(norm, 1, atol=1e-5) def test_orthonormality(self, helix_bishop: cxfc.BishopTransform): - tau = u.Q(1.0, "s") + tau = u.Q(1, "s") T = helix_bishop.tangent(tau).value U1 = helix_bishop.normal1(tau).value U2 = helix_bishop.normal2(tau).value - assert jnp.allclose(jnp.dot(T, U1), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(T, U2), 0.0, atol=1e-4) - assert jnp.allclose(jnp.dot(U1, U2), 0.0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U1), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(T, U2), 0, atol=1e-4) + assert jnp.allclose(jnp.dot(U1, U2), 0, atol=1e-4) def test_roundtrip(self, helix_bishop: cxfc.BishopTransform): """Forward then inverse roundtrip on helix.""" - tau = u.Q(1.0, "s") - p = u.Q(jnp.array([2.0, -1.0, 3.0]), "km") + tau = u.Q(1, "s") + p = u.Q(jnp.array([2, -1, 3]), "km") g = helix_bishop.location(tau) T = helix_bishop.tangent(tau) diff --git a/packages/coordinax.curveframes/tests/test_bishop_frame.py b/packages/coordinax.curveframes/tests/test_bishop_frame.py index 76f91e76..8fe34bf1 100644 --- a/packages/coordinax.curveframes/tests/test_bishop_frame.py +++ b/packages/coordinax.curveframes/tests/test_bishop_frame.py @@ -123,15 +123,15 @@ def test_act_forward_at_tau_zero(self, circle_bishop_transform) -> None: p = gamma(0) => delta=0 => result = (0,0,0) """ - tau = u.Q(0.0, "s") - p = u.Q(jnp.array([1.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p = u.Q(jnp.array([1, 0, 0]), "km") result = cxfm.act(circle_bishop_transform, tau, p) - np.testing.assert_allclose(_as_array(result, "km"), [0.0, 0.0, 0.0], atol=1e-5) + np.testing.assert_allclose(_as_array(result, "km"), [0, 0, 0], atol=1e-5) def test_act_inverse_roundtrip(self, circle_bishop_transform) -> None: """Forward then inverse recovers original point.""" tau = u.Q(0.5, "s") - p = u.Q(jnp.array([3.0, -1.0, 2.0]), "km") + p = u.Q(jnp.array([3, -1, 2]), "km") p_curve = cxfm.act(circle_bishop_transform, tau, p) p_back = cxfm.act(circle_bishop_transform.inverse, tau, p_curve) @@ -142,10 +142,10 @@ def test_act_inverse_roundtrip(self, circle_bishop_transform) -> None: def test_act_at_different_tau_values(self, circle_bishop_transform) -> None: """Different tau values give different results for same point.""" - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + p = u.Q(jnp.array([2, 0, 0]), "km") - r1 = cxfm.act(circle_bishop_transform, u.Q(0.0, "s"), p) - r2 = cxfm.act(circle_bishop_transform, u.Q(1.0, "s"), p) + r1 = cxfm.act(circle_bishop_transform, u.Q(0, "s"), p) + r2 = cxfm.act(circle_bishop_transform, u.Q(1, "s"), p) assert not np.allclose(_as_array(r1, "km"), _as_array(r2, "km"), atol=1e-3) @@ -168,7 +168,7 @@ def test_transition_from_bishop_frame(self, circle_bishop_frame) -> None: def test_roundtrip_alice_to_bishop_and_back(self, circle_bishop_frame) -> None: """Alice -> Bishop -> Alice is identity.""" tau = u.Q(0.5, "s") - p = u.Q(jnp.array([3.0, -1.0, 2.0]), "km") + p = u.Q(jnp.array([3, -1, 2]), "km") op_fwd = cxf.frame_transition(cxf.Alice(), circle_bishop_frame) op_bwd = cxf.frame_transition(circle_bishop_frame, cxf.Alice()) @@ -183,8 +183,8 @@ def test_roundtrip_alice_to_bishop_and_back(self, circle_bishop_frame) -> None: def test_alice_bishop_alex_chain(self) -> None: """Alice -> Bishop(tau) -> Alex chain and reverse.""" b_frame = cxfc.BishopFrame.from_curve(cxf.Alice(), _circle_curve) - tau = u.Q(0.0, "s") - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p = u.Q(jnp.array([2, 0, 0]), "km") # Alice -> Bishop op_a_to_b = cxf.frame_transition(cxf.Alice(), b_frame) @@ -214,7 +214,7 @@ def test_full_chain_alice_bishop_alex_roundtrip(self) -> None: """Alice -> Bishop -> Alex -> Bishop -> Alice recovers original.""" b_frame = cxfc.BishopFrame.from_curve(cxf.Alice(), _circle_curve) tau = u.Q(0.3, "s") - p = u.Q(jnp.array([5.0, -2.0, 1.0]), "km") + p = u.Q(jnp.array([5, -2, 1]), "km") op1 = cxf.frame_transition(cxf.Alice(), b_frame) op2 = cxf.frame_transition(b_frame, cxf.Alex()) @@ -230,8 +230,8 @@ def test_full_chain_alice_bishop_alex_roundtrip(self) -> None: def test_straight_line_frame_transition(self, line_bishop_frame) -> None: """Frame transition works on a straight line (kappa=0).""" - tau = u.Q(1.0, "s") - p = u.Q(jnp.array([2.0, 1.0, 0.0]), "km") + tau = u.Q(1, "s") + p = u.Q(jnp.array([2, 1, 0]), "km") op = cxf.frame_transition(cxf.Alice(), line_bishop_frame) p_bishop = cxfm.act(op, tau, p) @@ -253,8 +253,8 @@ class TestBishopFrameJAX: def test_act_jit(self, circle_bishop_transform) -> None: """Act with BishopTransform is JIT-compatible.""" - tau = u.Q(0.0, "s") - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p = u.Q(jnp.array([2, 0, 0]), "km") result_eager = cxfm.act(circle_bishop_transform, tau, p) result_jit = jax.jit(lambda t, x: cxfm.act(circle_bishop_transform, t, x))( @@ -269,8 +269,8 @@ def test_act_jit(self, circle_bishop_transform) -> None: def test_act_vmap_over_tau(self, circle_bishop_transform) -> None: """Act can be vmapped over the tau parameter.""" - taus = u.Q(jnp.linspace(0.0, 2.0, 5), "s") - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + taus = u.Q(jnp.linspace(0, 2, 5), "s") + p = u.Q(jnp.array([2, 0, 0]), "km") results = jax.vmap(lambda t: cxfm.act(circle_bishop_transform, t, p))(taus) @@ -286,26 +286,26 @@ class TestBishopActiveSemantics: def test_forward_moves_point_to_curve_frame(self, circle_bishop_transform) -> None: """Point at gamma(0) maps to (0,0,0) in the curve frame.""" - tau = u.Q(0.0, "s") - p_on_curve = u.Q(jnp.array([1.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p_on_curve = u.Q(jnp.array([1, 0, 0]), "km") result = cxfm.act(circle_bishop_transform, tau, p_on_curve) - np.testing.assert_allclose(_as_array(result, "km"), [0.0, 0.0, 0.0], atol=1e-5) + np.testing.assert_allclose(_as_array(result, "km"), [0, 0, 0], atol=1e-5) def test_inverse_moves_point_back_to_ambient(self, circle_bishop_transform) -> None: """Origin of curve frame at tau=0 maps back to gamma(0).""" - tau = u.Q(0.0, "s") - p_origin = u.Q(jnp.array([0.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p_origin = u.Q(jnp.array([0, 0, 0]), "km") result = cxfm.act(circle_bishop_transform.inverse, tau, p_origin) - np.testing.assert_allclose(_as_array(result, "km"), [1.0, 0.0, 0.0], atol=1e-3) + np.testing.assert_allclose(_as_array(result, "km"), [1, 0, 0], atol=1e-3) def test_frame_transition_matches_direct_transform( self, circle_bishop_frame, circle_bishop_transform ) -> None: """frame_transition(Alice, bishop_frame) matches direct act.""" tau = u.Q(0.5, "s") - p = u.Q(jnp.array([2.0, 1.0, 0.0]), "km") + p = u.Q(jnp.array([2, 1, 0]), "km") op = cxf.frame_transition(cxf.Alice(), circle_bishop_frame) result_ft = cxfm.act(op, tau, p) diff --git a/packages/coordinax.curveframes/tests/test_frenet_serret.py b/packages/coordinax.curveframes/tests/test_frenet_serret.py index ff83c45b..37fd8d73 100644 --- a/packages/coordinax.curveframes/tests/test_frenet_serret.py +++ b/packages/coordinax.curveframes/tests/test_frenet_serret.py @@ -12,13 +12,13 @@ # ── Fixtures ────────────────────────────────────────────────────────── -def _circle_curve(tau: u.Quantity) -> u.Quantity: +def _circle_curve(tau: u.Q) -> u.Q: """Unit circle in the x-y plane, period = 2*pi seconds.""" t = tau.ustrip("s") return u.Q(jnp.stack([jnp.cos(t), jnp.sin(t), jnp.zeros_like(t)]), "km") -def _circle_curve_yr(tau: u.Quantity) -> u.Quantity: +def _circle_curve_yr(tau: u.Q) -> u.Q: """Circle in x-y plane with angular speed omega = 2*pi rad/yr. This curve internally converts tau to radians, so the "natural" tau-unit @@ -26,9 +26,9 @@ def _circle_curve_yr(tau: u.Quantity) -> u.Quantity: """ omega = u.Q(2 * jnp.pi, "rad/yr") phase = (omega * tau).uconvert("rad").ustrip("rad") - x = u.Q(5.0, "km") * jnp.cos(phase) - y = u.Q(5.0, "km") * jnp.sin(phase) - z = u.Q(0.0, "km") * jnp.ones_like(phase) + x = u.Q(5, "km") * jnp.cos(phase) + y = u.Q(5, "km") * jnp.sin(phase) + z = u.Q(0, "km") * jnp.ones_like(phase) return qnp.stack([x, y, z], axis=-1) @@ -49,8 +49,8 @@ class TestFrenetSerretTransformLocation: """The location field should be the curve itself.""" def test_location_at_zero(self, circle_fs: cxfc.FrenetSerretTransform): - loc = circle_fs.location(u.Q(0.0, "s")) - expected = u.Q(jnp.array([1.0, 0.0, 0.0]), "km") + loc = circle_fs.location(u.Q(0, "s")) + expected = u.Q(jnp.array([1, 0, 0]), "km") assert jnp.allclose(loc.value, expected.value, atol=1e-6) @@ -62,24 +62,24 @@ class TestFrenetSerretTransformTangent: def test_tangent_at_zero(self, circle_fs: cxfc.FrenetSerretTransform): """At tau=0 on a unit circle, T = (0, 1, 0).""" - T = circle_fs.tangent(u.Q(0.0, "s")) - expected = jnp.array([0.0, 1.0, 0.0]) + T = circle_fs.tangent(u.Q(0, "s")) + expected = jnp.array([0, 1, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_tangent_at_pi_over_2(self, circle_fs: cxfc.FrenetSerretTransform): """At tau=pi/2, T = (-1, 0, 0).""" T = circle_fs.tangent(u.Q(jnp.pi / 2, "s")) - expected = jnp.array([-1.0, 0.0, 0.0]) + expected = jnp.array([-1, 0, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_tangent_is_unit_vector(self, circle_fs: cxfc.FrenetSerretTransform): T = circle_fs.tangent(u.Q(1.23, "s")) norm = jnp.sqrt(jnp.sum(T.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-5) + assert jnp.allclose(norm, 1, atol=1e-5) def test_tangent_has_speed_units(self, circle_fs: cxfc.FrenetSerretTransform): """The raw (un-normalised) derivative should carry km/s units.""" - T = circle_fs.tangent(u.Q(0.0, "s")) + T = circle_fs.tangent(u.Q(0, "s")) # After normalization: dimensionless assert T.unit == u.unit("") @@ -92,20 +92,20 @@ class TestFrenetSerretTransformNormal: def test_normal_at_zero(self, circle_fs: cxfc.FrenetSerretTransform): """At tau=0 on a unit circle, N = (-1, 0, 0) (points inward).""" - N = circle_fs.normal(u.Q(0.0, "s")) - expected = jnp.array([-1.0, 0.0, 0.0]) + N = circle_fs.normal(u.Q(0, "s")) + expected = jnp.array([-1, 0, 0]) assert jnp.allclose(N.value, expected, atol=1e-5) def test_normal_at_pi_over_2(self, circle_fs: cxfc.FrenetSerretTransform): """At tau=pi/2, N = (0, -1, 0).""" N = circle_fs.normal(u.Q(jnp.pi / 2, "s")) - expected = jnp.array([0.0, -1.0, 0.0]) + expected = jnp.array([0, -1, 0]) assert jnp.allclose(N.value, expected, atol=1e-5) def test_normal_is_unit_vector(self, circle_fs: cxfc.FrenetSerretTransform): N = circle_fs.normal(u.Q(0.7, "s")) norm = jnp.sqrt(jnp.sum(N.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-5) + assert jnp.allclose(norm, 1, atol=1e-5) # ── Binormal ────────────────────────────────────────────────────────── @@ -116,14 +116,14 @@ class TestFrenetSerretTransformBinormal: def test_binormal_at_zero(self, circle_fs: cxfc.FrenetSerretTransform): """For a circle in the x-y plane, B = (0, 0, 1) everywhere.""" - B = circle_fs.binormal(u.Q(0.0, "s")) - expected = jnp.array([0.0, 0.0, 1.0]) + B = circle_fs.binormal(u.Q(0, "s")) + expected = jnp.array([0, 0, 1]) assert jnp.allclose(B.value, expected, atol=1e-5) def test_binormal_is_unit_vector(self, circle_fs: cxfc.FrenetSerretTransform): - B = circle_fs.binormal(u.Q(2.0, "s")) + B = circle_fs.binormal(u.Q(2, "s")) norm = jnp.sqrt(jnp.sum(B.value**2)) - assert jnp.allclose(norm, 1.0, atol=1e-5) + assert jnp.allclose(norm, 1, atol=1e-5) # ── Orthonormality ──────────────────────────────────────────────────── @@ -132,17 +132,17 @@ def test_binormal_is_unit_vector(self, circle_fs: cxfc.FrenetSerretTransform): class TestFrenetSerretTransformOrthonormality: """T, N, B should form an orthonormal right-handed triad.""" - @pytest.mark.parametrize("tau_val", [0.0, 0.5, 1.0, 2.5, jnp.pi]) + @pytest.mark.parametrize("tau_val", [0, 0.5, 1, 2.5, jnp.pi]) def test_orthogonality(self, circle_fs: cxfc.FrenetSerretTransform, tau_val: float): tau = u.Q(tau_val, "s") T = circle_fs.tangent(tau).value N = circle_fs.normal(tau).value B = circle_fs.binormal(tau).value - assert jnp.allclose(jnp.dot(T, N), 0.0, atol=1e-5) - assert jnp.allclose(jnp.dot(T, B), 0.0, atol=1e-5) - assert jnp.allclose(jnp.dot(N, B), 0.0, atol=1e-5) + assert jnp.allclose(jnp.dot(T, N), 0, atol=1e-5) + assert jnp.allclose(jnp.dot(T, B), 0, atol=1e-5) + assert jnp.allclose(jnp.dot(N, B), 0, atol=1e-5) - @pytest.mark.parametrize("tau_val", [0.0, 1.0, jnp.pi]) + @pytest.mark.parametrize("tau_val", [0, 1, jnp.pi]) def test_right_handed(self, circle_fs: cxfc.FrenetSerretTransform, tau_val: float): """B should equal T x N (right-handed frame).""" tau = u.Q(tau_val, "s") @@ -160,18 +160,18 @@ class TestFrenetSerretTransformOpaqueUnits: def test_tangent_at_zero(self, circle_yr_fs: cxfc.FrenetSerretTransform): """Tangent at tau=0 should still be (0, 1, 0) direction.""" - T = circle_yr_fs.tangent(u.Q(0.0, "yr")) - expected = jnp.array([0.0, 1.0, 0.0]) + T = circle_yr_fs.tangent(u.Q(0, "yr")) + expected = jnp.array([0, 1, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_normal_at_zero(self, circle_yr_fs: cxfc.FrenetSerretTransform): - N = circle_yr_fs.normal(u.Q(0.0, "yr")) - expected = jnp.array([-1.0, 0.0, 0.0]) + N = circle_yr_fs.normal(u.Q(0, "yr")) + expected = jnp.array([-1, 0, 0]) assert jnp.allclose(N.value, expected, atol=1e-5) def test_binormal_at_zero(self, circle_yr_fs: cxfc.FrenetSerretTransform): - B = circle_yr_fs.binormal(u.Q(0.0, "yr")) - expected = jnp.array([0.0, 0.0, 1.0]) + B = circle_yr_fs.binormal(u.Q(0, "yr")) + expected = jnp.array([0, 0, 1]) assert jnp.allclose(B.value, expected, atol=1e-5) @@ -182,8 +182,8 @@ class TestFrenetSerretTransformJAX: """Verify compatibility with jit and vmap.""" def test_jit(self, circle_fs: cxfc.FrenetSerretTransform): - T = jax.jit(circle_fs.tangent)(u.Q(0.0, "s")) - expected = jnp.array([0.0, 1.0, 0.0]) + T = jax.jit(circle_fs.tangent)(u.Q(0, "s")) + expected = jnp.array([0, 1, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_vmap(self, circle_fs: cxfc.FrenetSerretTransform): @@ -191,7 +191,7 @@ def test_vmap(self, circle_fs: cxfc.FrenetSerretTransform): Ts = jax.vmap(circle_fs.tangent)(taus) # All tangent vectors should be unit length norms = jnp.sqrt(jnp.sum(Ts.value**2, axis=-1)) - assert jnp.allclose(norms, 1.0, atol=1e-5) + assert jnp.allclose(norms, 1, atol=1e-5) # ── from_ constructor ──────────────────────────────────────────────── @@ -202,8 +202,8 @@ class TestFrenetSerretTransformFrom: def test_from_dispatches_to_from_curve(self): fs = cxfc.FrenetSerretTransform.from_(_circle_curve) - loc = fs.location(u.Q(0.0, "s")) - expected = u.Q(jnp.array([1.0, 0.0, 0.0]), "km") + loc = fs.location(u.Q(0, "s")) + expected = u.Q(jnp.array([1, 0, 0]), "km") assert jnp.allclose(loc.value, expected.value, atol=1e-6) @@ -223,29 +223,29 @@ class TestFrenetSerretTransformInverse: def test_inverse_location_at_zero(self, circle_fs: cxfc.FrenetSerretTransform): inv = circle_fs.inverse - loc = inv.location(u.Q(0.0, "s")) - expected = jnp.array([0.0, 1.0, 0.0]) + loc = inv.location(u.Q(0, "s")) + expected = jnp.array([0, 1, 0]) assert jnp.allclose(loc.value, expected, atol=1e-5) def test_inverse_tangent_at_zero(self, circle_fs: cxfc.FrenetSerretTransform): inv = circle_fs.inverse - T = inv.tangent(u.Q(0.0, "s")) - expected = jnp.array([0.0, -1.0, 0.0]) + T = inv.tangent(u.Q(0, "s")) + expected = jnp.array([0, -1, 0]) assert jnp.allclose(T.value, expected, atol=1e-5) def test_inverse_normal_at_zero(self, circle_fs: cxfc.FrenetSerretTransform): inv = circle_fs.inverse - N = inv.normal(u.Q(0.0, "s")) - expected = jnp.array([1.0, 0.0, 0.0]) + N = inv.normal(u.Q(0, "s")) + expected = jnp.array([1, 0, 0]) assert jnp.allclose(N.value, expected, atol=1e-5) def test_inverse_binormal_at_zero(self, circle_fs: cxfc.FrenetSerretTransform): inv = circle_fs.inverse - B = inv.binormal(u.Q(0.0, "s")) - expected = jnp.array([0.0, 0.0, 1.0]) + B = inv.binormal(u.Q(0, "s")) + expected = jnp.array([0, 0, 1]) assert jnp.allclose(B.value, expected, atol=1e-5) - @pytest.mark.parametrize("tau_val", [0.0, 0.5, 1.0, 2.5, jnp.pi]) + @pytest.mark.parametrize("tau_val", [0, 0.5, 1, 2.5, jnp.pi]) def test_inverse_orthonormality( self, circle_fs: cxfc.FrenetSerretTransform, tau_val: float ): @@ -254,20 +254,20 @@ def test_inverse_orthonormality( T = inv.tangent(tau).value N = inv.normal(tau).value B = inv.binormal(tau).value - assert jnp.allclose(jnp.dot(T, N), 0.0, atol=1e-5) - assert jnp.allclose(jnp.dot(T, B), 0.0, atol=1e-5) - assert jnp.allclose(jnp.dot(N, B), 0.0, atol=1e-5) - assert jnp.allclose(jnp.linalg.norm(T), 1.0, atol=1e-5) - assert jnp.allclose(jnp.linalg.norm(N), 1.0, atol=1e-5) - assert jnp.allclose(jnp.linalg.norm(B), 1.0, atol=1e-5) - - @pytest.mark.parametrize("tau_val", [0.0, 1.0, jnp.pi]) + assert jnp.allclose(jnp.dot(T, N), 0, atol=1e-5) + assert jnp.allclose(jnp.dot(T, B), 0, atol=1e-5) + assert jnp.allclose(jnp.dot(N, B), 0, atol=1e-5) + assert jnp.allclose(jnp.linalg.norm(T), 1, atol=1e-5) + assert jnp.allclose(jnp.linalg.norm(N), 1, atol=1e-5) + assert jnp.allclose(jnp.linalg.norm(B), 1, atol=1e-5) + + @pytest.mark.parametrize("tau_val", [0, 1, jnp.pi]) def test_roundtrip_forward_inverse( self, circle_fs: cxfc.FrenetSerretTransform, tau_val: float ): """Apply forward then inverse: R_inv @ (R @ (p - g) - g_inv) == p.""" tau = u.Q(tau_val, "s") - p = u.Q(jnp.array([2.0, 3.0, 4.0]), "km") + p = u.Q(jnp.array([2, 3, 4]), "km") # Forward: p' = R @ (p - gamma) g = circle_fs.location(tau) @@ -285,16 +285,12 @@ def test_roundtrip_forward_inverse( Bi = inv.binormal(tau) diff_inv = p_fwd - g_inv p_rec = qnp.stack( - [ - qnp.sum(Ti * diff_inv), - qnp.sum(Ni * diff_inv), - qnp.sum(Bi * diff_inv), - ] + [qnp.sum(Ti * diff_inv), qnp.sum(Ni * diff_inv), qnp.sum(Bi * diff_inv)] ) assert jnp.allclose(p_rec.value, p.value, atol=1e-4) - @pytest.mark.parametrize("tau_val", [0.0, 1.0, jnp.pi]) + @pytest.mark.parametrize("tau_val", [0, 1, jnp.pi]) def test_double_inverse( self, circle_fs: cxfc.FrenetSerretTransform, tau_val: float ): @@ -310,16 +306,16 @@ def test_double_inverse( def test_inverse_jit(self, circle_fs: cxfc.FrenetSerretTransform): inv = circle_fs.inverse - loc = jax.jit(inv.location)(u.Q(0.0, "s")) - expected = jnp.array([0.0, 1.0, 0.0]) - assert jnp.allclose(loc.value, expected, atol=1e-5) + loc = jax.jit(inv.location)(u.Q(0, "s")) + exp = jnp.array([0, 1, 0]) + assert jnp.allclose(loc.value, exp, atol=1e-5) def test_inverse_opaque_units(self, circle_yr_fs: cxfc.FrenetSerretTransform): """Inverse should work with opaque-unit curves too.""" inv = circle_yr_fs.inverse - tau = u.Q(0.0, "yr") + tau = u.Q(0, "yr") # For the yr-circle at tau=0: gamma=(5,0,0)km, T=(0,1,0), N=(-1,0,0) # inv_location = -[T·g, N·g, B·g] = -[0, -5, 0] = (0, 5, 0) loc = inv.location(tau) - expected = jnp.array([0.0, 5.0, 0.0]) + expected = jnp.array([0, 5, 0]) assert jnp.allclose(loc.value, expected, atol=1e-3) diff --git a/packages/coordinax.curveframes/tests/test_frenet_serret_frame.py b/packages/coordinax.curveframes/tests/test_frenet_serret_frame.py index 7ec8187f..dda77b18 100644 --- a/packages/coordinax.curveframes/tests/test_frenet_serret_frame.py +++ b/packages/coordinax.curveframes/tests/test_frenet_serret_frame.py @@ -112,10 +112,10 @@ def test_act_forward_at_tau_zero(self, circle_fs_transform) -> None: At tau=0: gamma=(1,0,0) km, T=(0,1,0), N=(-1,0,0), B=(0,0,1) R @ (p - gamma) where p=(1,0,0) km => R @ (0,0,0) = (0,0,0) km. """ - tau = u.Q(0.0, "s") - p = u.Q(jnp.array([1.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p = u.Q(jnp.array([1, 0, 0]), "km") result = cxfm.act(circle_fs_transform, tau, p) - np.testing.assert_allclose(_as_array(result, "km"), [0.0, 0.0, 0.0], atol=1e-6) + np.testing.assert_allclose(_as_array(result, "km"), [0, 0, 0], atol=1e-6) def test_act_forward_off_curve(self, circle_fs_transform) -> None: """Forward transform at tau=0 on a point offset from the curve. @@ -124,15 +124,15 @@ def test_act_forward_off_curve(self, circle_fs_transform) -> None: p=(2,0,0) km => delta=(1,0,0) R @ delta = [T·delta, N·delta, B·delta] = [0, -1, 0] km """ - tau = u.Q(0.0, "s") - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p = u.Q(jnp.array([2, 0, 0]), "km") result = cxfm.act(circle_fs_transform, tau, p) - np.testing.assert_allclose(_as_array(result, "km"), [0.0, -1.0, 0.0], atol=1e-6) + np.testing.assert_allclose(_as_array(result, "km"), [0, -1, 0], atol=1e-6) def test_act_inverse_roundtrip(self, circle_fs_transform) -> None: """Forward then inverse recovers original point.""" tau = u.Q(0.5, "s") - p = u.Q(jnp.array([3.0, -1.0, 2.0]), "km") + p = u.Q(jnp.array([3, -1, 2]), "km") p_curve = cxfm.act(circle_fs_transform, tau, p) p_back = cxfm.act(circle_fs_transform.inverse, tau, p_curve) @@ -143,10 +143,10 @@ def test_act_inverse_roundtrip(self, circle_fs_transform) -> None: def test_act_at_different_tau_values(self, circle_fs_transform) -> None: """Different tau values give different results for same point.""" - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + p = u.Q(jnp.array([2, 0, 0]), "km") - r1 = cxfm.act(circle_fs_transform, u.Q(0.0, "s"), p) - r2 = cxfm.act(circle_fs_transform, u.Q(1.0, "s"), p) + r1 = cxfm.act(circle_fs_transform, u.Q(0, "s"), p) + r2 = cxfm.act(circle_fs_transform, u.Q(1, "s"), p) # These should be different since the frame rotates assert not np.allclose(_as_array(r1, "km"), _as_array(r2, "km"), atol=1e-3) @@ -172,7 +172,7 @@ def test_transition_from_fs_frame(self, circle_fs_frame) -> None: def test_roundtrip_alice_to_fs_and_back(self, circle_fs_frame) -> None: """Alice -> FS -> Alice is identity for any point.""" tau = u.Q(0.5, "s") - p = u.Q(jnp.array([3.0, -1.0, 2.0]), "km") + p = u.Q(jnp.array([3, -1, 2]), "km") op_fwd = cxf.frame_transition(cxf.Alice(), circle_fs_frame) op_bwd = cxf.frame_transition(circle_fs_frame, cxf.Alice()) @@ -193,8 +193,8 @@ def test_alice_fs_alex_chain(self) -> None: Alice-to-Alex transition. """ fs_frame = cxfc.FrenetSerretFrame.from_curve(cxf.Alice(), _circle_curve) - tau = u.Q(0.0, "s") - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p = u.Q(jnp.array([2, 0, 0]), "km") # Alice -> FS op_a_to_fs = cxf.frame_transition(cxf.Alice(), fs_frame) @@ -224,7 +224,7 @@ def test_full_chain_alice_fs_alex_roundtrip(self) -> None: """Alice -> FS -> Alex -> FS -> Alice recovers original point.""" fs_frame = cxfc.FrenetSerretFrame.from_curve(cxf.Alice(), _circle_curve) tau = u.Q(0.3, "s") - p = u.Q(jnp.array([5.0, -2.0, 1.0]), "km") + p = u.Q(jnp.array([5, -2, 1]), "km") # Alice -> FS -> Alex op1 = cxf.frame_transition(cxf.Alice(), fs_frame) @@ -250,8 +250,8 @@ class TestFrenetSerretFrameJAX: def test_act_jit(self, circle_fs_transform) -> None: """Act with FrenetSerretTransform is JIT-compatible.""" - tau = u.Q(0.0, "s") - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p = u.Q(jnp.array([2, 0, 0]), "km") result_eager = cxfm.act(circle_fs_transform, tau, p) result_jit = jax.jit(lambda t, x: cxfm.act(circle_fs_transform, t, x))(tau, p) @@ -262,8 +262,8 @@ def test_act_jit(self, circle_fs_transform) -> None: def test_act_vmap_over_tau(self, circle_fs_transform) -> None: """Act can be vmapped over the tau parameter.""" - taus = u.Q(jnp.linspace(0.0, 2.0, 5), "s") - p = u.Q(jnp.array([2.0, 0.0, 0.0]), "km") + taus = u.Q(jnp.linspace(0, 2, 5), "s") + p = u.Q(jnp.array([2, 0, 0]), "km") results = jax.vmap(lambda t: cxfm.act(circle_fs_transform, t, p))(taus) @@ -284,29 +284,29 @@ def test_forward_moves_point_to_curve_frame(self, circle_fs_transform) -> None: At tau=0, the curve is at (1,0,0) km with T=(0,1,0), N=(-1,0,0). The point at the curve origin should map to (0,0,0) in the frame. """ - tau = u.Q(0.0, "s") - p_on_curve = u.Q(jnp.array([1.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p_on_curve = u.Q(jnp.array([1, 0, 0]), "km") result = cxfm.act(circle_fs_transform, tau, p_on_curve) - np.testing.assert_allclose(_as_array(result, "km"), [0.0, 0.0, 0.0], atol=1e-6) + np.testing.assert_allclose(_as_array(result, "km"), [0, 0, 0], atol=1e-6) def test_inverse_moves_point_back_to_ambient(self, circle_fs_transform) -> None: """Inverse transform moves a curve-frame point back to ambient. Origin of curve frame at tau=0 should map back to gamma(0)=(1,0,0). """ - tau = u.Q(0.0, "s") - p_origin = u.Q(jnp.array([0.0, 0.0, 0.0]), "km") + tau = u.Q(0, "s") + p_origin = u.Q(jnp.array([0, 0, 0]), "km") result = cxfm.act(circle_fs_transform.inverse, tau, p_origin) - np.testing.assert_allclose(_as_array(result, "km"), [1.0, 0.0, 0.0], atol=1e-6) + np.testing.assert_allclose(_as_array(result, "km"), [1, 0, 0], atol=1e-6) def test_frame_transition_matches_direct_transform( self, circle_fs_frame, circle_fs_transform ) -> None: """frame_transition(Alice, fs_frame) applies the same as the xop.""" tau = u.Q(0.5, "s") - p = u.Q(jnp.array([2.0, 1.0, -1.0]), "km") + p = u.Q(jnp.array([2, 1, -1]), "km") via_transition = cxfm.act( cxf.frame_transition(cxf.Alice(), circle_fs_frame), tau, p @@ -314,7 +314,5 @@ def test_frame_transition_matches_direct_transform( via_direct = cxfm.act(circle_fs_transform, tau, p) np.testing.assert_allclose( - _as_array(via_transition, "km"), - _as_array(via_direct, "km"), - atol=1e-10, + _as_array(via_transition, "km"), _as_array(via_direct, "km"), atol=1e-10 ) From 1263f30e3e1bdaabddd3cd5e461449ce47636ed7 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:33:17 -0400 Subject: [PATCH 08/15] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(coordinax.h?= =?UTF-8?q?ypothesis):=20adapt=20tests=20to=20updated=20chart=20and=20mani?= =?UTF-8?q?fold=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update strategy tests and chart/vector/distance tests to reflect the current API conventions introduced by the metric refactor. --- .../docs/testing-guide.md | 2 +- .../tests/test_composite_dispatch.py | 38 ++++++++----------- .../tests/unit/charts/test_cdicts.py | 16 ++++---- .../unit/charts/test_chart_init_kwargs.py | 6 +-- .../tests/unit/charts/test_performance.py | 3 +- .../tests/unit/distances/test_distances.py | 10 ++--- .../tests/unit/vectors/test_vectors.py | 10 ++--- 7 files changed, 38 insertions(+), 47 deletions(-) diff --git a/packages/coordinax.hypothesis/docs/testing-guide.md b/packages/coordinax.hypothesis/docs/testing-guide.md index 75ff2894..645c1f91 100644 --- a/packages/coordinax.hypothesis/docs/testing-guide.md +++ b/packages/coordinax.hypothesis/docs/testing-guide.md @@ -140,7 +140,7 @@ def test_angle_arithmetic(angle1, angle2): """Angles support arithmetic operations.""" # Can add/subtract angles diff = angle1 - angle2 - assert isinstance(diff, u.Quantity) + assert isinstance(diff, u.Q) ``` ### Using Assumptions diff --git a/packages/coordinax.hypothesis/tests/test_composite_dispatch.py b/packages/coordinax.hypothesis/tests/test_composite_dispatch.py index de489658..86b3bd73 100644 --- a/packages/coordinax.hypothesis/tests/test_composite_dispatch.py +++ b/packages/coordinax.hypothesis/tests/test_composite_dispatch.py @@ -71,8 +71,8 @@ def bounded_value(draw, x: int): @dispatch @st.composite def bounded_value(draw, x: float): # noqa: F811 - """Strategy: draw a float in [x, x+1.0].""" - return draw(st.floats(min_value=x, max_value=x + 1.0, allow_nan=False)) + """Strategy: draw a float in [x, x+1].""" + return draw(st.floats(min_value=x, max_value=x + 1, allow_nan=False)) @given(bounded_value(5)) @@ -143,10 +143,10 @@ def test_dispatch_two_int_args(v): assert 0 <= v <= 100 -@given(interval_value(0.0, 1.0)) +@given(interval_value(0, 1)) def test_dispatch_two_float_args(v): assert isinstance(v, float) - assert 0.0 <= v <= 1.0 + assert 0 <= v <= 1 # ───────────────────────────────────────────────────────────────────────────── @@ -158,7 +158,7 @@ def test_dispatch_two_float_args(v): @st.composite def typed_number(draw, x: Number): """Fallback: any number → float in [0, 1].""" - return draw(st.floats(min_value=0.0, max_value=1.0, allow_nan=False)) + return draw(st.floats(min_value=0, max_value=1, allow_nan=False)) @dispatch @@ -166,11 +166,7 @@ def typed_number(draw, x: Number): def typed_number(draw, x: Real): # noqa: F811 """Specialisation for real numbers: float near x.""" return draw( - st.floats( - min_value=float(x) - 1.0, - max_value=float(x) + 1.0, - allow_nan=False, - ) + st.floats(min_value=float(x) - 1, max_value=float(x) + 1, allow_nan=False) ) @@ -193,7 +189,7 @@ def test_real_overload_for_float(v): """float is a Real but not an int → Real overload is selected.""" assert isinstance(v, float) x = 3.14 - assert float(x) - 1.0 <= v <= float(x) + 1.0 + assert float(x) - 1 <= v <= float(x) + 1 # ───────────────────────────────────────────────────────────────────────────── @@ -213,8 +209,8 @@ def sorted_pair(draw, n: int): @dispatch @st.composite def sorted_pair(draw, x: float): # noqa: F811 - """Draw (a, b) with 0.0 ≤ a ≤ b ≤ x (floats).""" - a = draw(st.floats(min_value=0.0, max_value=x, allow_nan=False)) + """Draw (a, b) with 0 ≤ a ≤ b ≤ x (floats).""" + a = draw(st.floats(min_value=0, max_value=x, allow_nan=False)) b = draw(st.floats(min_value=a, max_value=x, allow_nan=False)) return (a, b) @@ -227,12 +223,12 @@ def test_sorted_int_pair(pair): assert 0 <= b <= 100 -@given(sorted_pair(1.0)) +@given(sorted_pair(1)) def test_sorted_float_pair(pair): a, b = pair assert a <= b - assert 0.0 <= a <= 1.0 - assert 0.0 <= b <= 1.0 + assert 0 <= a <= 1 + assert 0 <= b <= 1 # ───────────────────────────────────────────────────────────────────────────── @@ -253,9 +249,7 @@ def multi_bounded(draw, *xs: int): @st.composite def multi_bounded(draw, *xs: float): # noqa: F811 """Draw one float in [x, x+1] for each seed value x.""" - return [ - draw(st.floats(min_value=x, max_value=x + 1.0, allow_nan=False)) for x in xs - ] + return [draw(st.floats(min_value=x, max_value=x + 1, allow_nan=False)) for x in xs] @given(multi_bounded(0, 10, 20)) @@ -273,7 +267,7 @@ def test_varargs_float_overload(values): """All-float varargs resolves to the float overload.""" assert len(values) == 2 assert all(isinstance(v, float) for v in values) - assert 0.0 <= values[0] <= 1.0 + assert 0 <= values[0] <= 1 assert 0.5 <= values[1] <= 1.5 @@ -284,7 +278,7 @@ def test_varargs_mixed_type_raises(): type. Use the per-element pattern (Section 7) for heterogeneous lists. """ with pytest.raises(NotFoundLookupError): - multi_bounded(1, 2.0) # int + float → no overload matches + multi_bounded(1, 2.5) # int + float → no overload matches # ───────────────────────────────────────────────────────────────────────────── @@ -302,7 +296,7 @@ def _element_strategy(x: int) -> st.SearchStrategy: @dispatch def _element_strategy(x: float) -> st.SearchStrategy: - return st.floats(min_value=x, max_value=x + 1.0, allow_nan=False) + return st.floats(min_value=x, max_value=x + 1, allow_nan=False) @st.composite diff --git a/packages/coordinax.hypothesis/tests/unit/charts/test_cdicts.py b/packages/coordinax.hypothesis/tests/unit/charts/test_cdicts.py index 9f9d7b94..4a9724d6 100644 --- a/packages/coordinax.hypothesis/tests/unit/charts/test_cdicts.py +++ b/packages/coordinax.hypothesis/tests/unit/charts/test_cdicts.py @@ -47,7 +47,7 @@ class TestCDictValueControl: @given( p=cxst.cdicts( - cxc.cart3d, elements=st.floats(min_value=1.0, max_value=100.0, width=32) + cxc.cart3d, elements=st.floats(min_value=1, max_value=100, width=32) ) ) def test_first_octant_via_elements(self, p): @@ -58,7 +58,7 @@ def test_first_octant_via_elements(self, p): @given( p=cxst.cdicts( - cxc.cart3d, elements=st.floats(min_value=-100.0, max_value=-1.0, width=32) + cxc.cart3d, elements=st.floats(min_value=-100, max_value=-1, width=32) ) ) def test_negative_octant_via_elements(self, p): @@ -70,16 +70,14 @@ def test_negative_octant_via_elements(self, p): @given( p=cxst.cdicts( cxc.cart2d, - elements=st.floats( - min_value=-10.0, max_value=10.0, allow_nan=False, width=32 - ), + elements=st.floats(min_value=-10, max_value=10, allow_nan=False, width=32), ) ) def test_bounded_range(self, p): """elements= with explicit bounds keeps all component magnitudes in range.""" for key in ("x", "y"): val = float(p[key].value) - assert -10.0 <= val <= 10.0 + assert -10 <= val <= 10 @given(data=st.data()) def test_second_quadrant_per_component(self, data): @@ -90,12 +88,12 @@ def test_second_quadrant_per_component(self, data): p_x = data.draw( cxst.cdicts( cxc.cart2d, - elements=st.floats(min_value=-100.0, max_value=-1.0, width=32), + elements=st.floats(min_value=-100, max_value=-1, width=32), ) ) p_y = data.draw( cxst.cdicts( - cxc.cart2d, elements=st.floats(min_value=1.0, max_value=100.0, width=32) + cxc.cart2d, elements=st.floats(min_value=1, max_value=100, width=32) ) ) @@ -104,7 +102,7 @@ def test_second_quadrant_per_component(self, data): @given( p=cxst.cdicts( - cxc.sph3d, elements=st.floats(min_value=1.0, max_value=100.0, width=32) + cxc.sph3d, elements=st.floats(min_value=1, max_value=100, width=32) ) ) def test_spherical_positive_elements(self, p): diff --git a/packages/coordinax.hypothesis/tests/unit/charts/test_chart_init_kwargs.py b/packages/coordinax.hypothesis/tests/unit/charts/test_chart_init_kwargs.py index 16a50bd2..31dd1b02 100644 --- a/packages/coordinax.hypothesis/tests/unit/charts/test_chart_init_kwargs.py +++ b/packages/coordinax.hypothesis/tests/unit/charts/test_chart_init_kwargs.py @@ -1,11 +1,11 @@ """Tests for the chart_init_kwargs strategy.""" import hypothesis.strategies as st -from hypothesis import example, given - import unxt as u +from hypothesis import example, given import coordinax.charts as cxc + import coordinax.hypothesis.main as cxst @@ -37,7 +37,7 @@ def _chart_class_with_kwargs(draw): @example(pair=(cxc.Cart3D, {})) @example(pair=(cxc.Cylindrical3D, {})) @example(pair=(cxc.Spherical3D, {})) -@example(pair=(cxc.ProlateSpheroidal3D, {"Delta": u.StaticQuantity(1.0, "kpc")})) +@example(pair=(cxc.ProlateSpheroidal3D, {"Delta": u.StaticQuantity(1, "kpc")})) @example(pair=(cxc.PoincarePolar6D, {})) @example(pair=(cxc.CartND, {})) def test_chart_init_kwargs_instantiates(pair) -> None: diff --git a/packages/coordinax.hypothesis/tests/unit/charts/test_performance.py b/packages/coordinax.hypothesis/tests/unit/charts/test_performance.py index 06578570..630880dc 100644 --- a/packages/coordinax.hypothesis/tests/unit/charts/test_performance.py +++ b/packages/coordinax.hypothesis/tests/unit/charts/test_performance.py @@ -5,6 +5,7 @@ import pytest import coordinax.charts as cxc + import coordinax.hypothesis.main as cxst @@ -60,7 +61,7 @@ def test_drawing_charts_is_fast(): elapsed = time.perf_counter() - start # Drawing 10 examples should complete in < 1 second - assert elapsed < 1.0, f"Drawing 10 examples took {elapsed:.3f}s" + assert elapsed < 1, f"Drawing 10 examples took {elapsed:.3f}s" # Verify we got valid charts assert all(isinstance(chart, cxc.AbstractChart) for chart in examples) diff --git a/packages/coordinax.hypothesis/tests/unit/distances/test_distances.py b/packages/coordinax.hypothesis/tests/unit/distances/test_distances.py index 2be9791b..c8749ae2 100644 --- a/packages/coordinax.hypothesis/tests/unit/distances/test_distances.py +++ b/packages/coordinax.hypothesis/tests/unit/distances/test_distances.py @@ -53,18 +53,16 @@ def test_distance_with_strategy_check_negative(dist: cxd.Distance) -> None: # check_negative varies, so we can't assert about the sign -@given( - dist=cxst.distances(elements=st.floats(min_value=1.0, max_value=100.0, width=32)) -) +@given(dist=cxst.distances(elements=st.floats(min_value=1, max_value=100, width=32))) def test_distance_with_custom_elements(dist: cxd.Distance) -> None: """Test distance with custom elements range.""" assert isinstance(dist, cxd.Distance) - assert 1.0 <= dist.value <= 100.0 + assert 1 <= dist.value <= 100 @given( dist=cxst.distances( - check_negative=True, elements=st.floats(min_value=0.0, max_value=10.0, width=32) + check_negative=True, elements=st.floats(min_value=0, max_value=10, width=32) ) ) def test_distance_check_negative_with_elements(dist: cxd.Distance) -> None: @@ -72,7 +70,7 @@ def test_distance_check_negative_with_elements(dist: cxd.Distance) -> None: assert isinstance(dist, cxd.Distance) # When check_negative=True and elements provided, min_value should be adjusted assert dist.value >= 0 - assert dist.value <= 50.0 + assert dist.value <= 50 class TestDistanceFromType: diff --git a/packages/coordinax.hypothesis/tests/unit/vectors/test_vectors.py b/packages/coordinax.hypothesis/tests/unit/vectors/test_vectors.py index f6825df9..3decf4bf 100644 --- a/packages/coordinax.hypothesis/tests/unit/vectors/test_vectors.py +++ b/packages/coordinax.hypothesis/tests/unit/vectors/test_vectors.py @@ -16,9 +16,9 @@ ) # Shared float32 element strategies (width=32 matches the default JAX dtype). -_F32_POS = st.floats(min_value=1.0, max_value=100.0, width=32) -_F32_NEG = st.floats(min_value=-100.0, max_value=-1.0, width=32) -_F32_BOUNDED = st.floats(min_value=-10.0, max_value=10.0, allow_nan=False, width=32) +_F32_POS = st.floats(min_value=1, max_value=100, width=32) +_F32_NEG = st.floats(min_value=-100, max_value=-1, width=32) +_F32_BOUNDED = st.floats(min_value=-10, max_value=10, allow_nan=False, width=32) @given(vec=vector_strategy()) @@ -128,5 +128,5 @@ def test_second_quadrant_per_component(self, data: st.DataObject) -> None: @given(vec=vector_strategy(cxc.cart2d, cxr.point, elements=_F32_BOUNDED)) def test_bounded_range(self, vec: cxv.Point) -> None: """elements= with explicit bounds keeps all component magnitudes in range.""" - assert -10.0 <= vec.data["x"].ustrip("m") <= 10.0 - assert -10.0 <= vec.data["y"].ustrip("m") <= 10.0 + assert -10 <= vec.data["x"].ustrip("m") <= 10 + assert -10 <= vec.data["y"].ustrip("m") <= 10 From 19ef6fd4b9af02ccdd94205201a1d0a0fa89d620 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:34:41 -0400 Subject: [PATCH 09/15] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(coordinax.i?= =?UTF-8?q?nterop.astropy):=20adapt=20to=20QMatrix=20rename=20and=20docstr?= =?UTF-8?q?ing=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace QuantityMatrix with QMatrix in qmatrix.py and remove Examples section headers from angles, distances, frames, ptmap, vec_constructors, and vec_converters. Tests updated to match. --- .../coordinax/interop/astropy/_src/angles.py | 3 -- .../interop/astropy/_src/distances.py | 13 +---- .../coordinax/interop/astropy/_src/frames.py | 6 --- .../coordinax/interop/astropy/_src/ptmap.py | 8 --- .../coordinax/interop/astropy/_src/qmatrix.py | 18 +++---- .../interop/astropy/_src/vec_constructors.py | 12 ----- .../interop/astropy/_src/vec_converters.py | 50 ++++++------------- .../tests/test_angles.py | 8 +-- .../tests/test_distances.py | 18 +++---- .../tests/test_ptmap_cdict.py | 12 ++--- .../tests/test_qmatrix.py | 16 +++--- .../tests/test_vectors.py | 2 +- 12 files changed, 49 insertions(+), 117 deletions(-) diff --git a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/angles.py b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/angles.py index 1dc34f57..6819cfb1 100644 --- a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/angles.py +++ b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/angles.py @@ -15,12 +15,9 @@ def convert_astropy_angle_to_cx_angle(q: AstropyAngle, /) -> cxa.Angle: """Convert a `astropy.coordinates.Angle` to a `coordinax.angles.Angle`. - Examples - -------- >>> import astropy.coordinates as apyc >>> import plum >>> import coordinax.angles as cxa - >>> plum.convert(apyc.Angle(1.0, "rad"), cxa.Angle) Angle(1., 'rad') diff --git a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/distances.py b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/distances.py index 5322830a..7b4099ee 100644 --- a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/distances.py +++ b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/distances.py @@ -18,12 +18,9 @@ def convert_astropy_quantity_to_unxt_distance(q: apyu.Quantity, /) -> cxd.Distance: """Convert a `astropy.units.Quantity` to a `coordinax.distances.Distance`. - Examples - -------- >>> import astropy.units as apyu >>> import plum >>> import coordinax.distances as cxd - >>> plum.convert(apyu.Quantity(1.0, "cm"), cxd.Distance) Distance(1., 'cm') @@ -39,12 +36,9 @@ def convert_astropy_quantity_to_unxt_distance(q: apyu.Quantity, /) -> cxd.Distan def convert_astropy_quantity_to_unxt_parallax(q: apyu.Quantity, /) -> cxastro.Parallax: """Convert a `astropy.units.Quantity` to a `coordinax.astro.Parallax`. - Examples - -------- >>> import astropy.units as apyu >>> import plum >>> import coordinax.astro as cxastro - >>> plum.convert(apyu.Quantity(1.0, "radian"), cxastro.Parallax) Parallax(1., 'rad') @@ -57,17 +51,12 @@ def convert_astropy_quantity_to_unxt_parallax(q: apyu.Quantity, /) -> cxastro.Pa @conversion_method(type_from=apyu.Quantity, type_to=cxastro.DistanceModulus) -def convert_astropy_quantity_to_unxt_distmod( - q: apyu.Quantity, / -) -> cxastro.DistanceModulus: +def convert_astropy_q_to_unxt_distmod(q: apyu.Quantity, /) -> cxastro.DistanceModulus: """Convert a `astropy.units.Quantity` to a `coordinax.astro.DistanceModulus`. - Examples - -------- >>> import astropy.units as apyu >>> import plum >>> import coordinax.astro as cxastro - >>> plum.convert(apyu.Quantity(1.0, "mag"), cxastro.DistanceModulus) DistanceModulus(1., 'mag') diff --git a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/frames.py b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/frames.py index d7207969..fb38091b 100644 --- a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/frames.py +++ b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/frames.py @@ -83,8 +83,6 @@ def coordinax_icrs_to_astropy_icrs(frame: cxastro.ICRS, /) -> apyc.ICRS: coordinax and Astropy implementations have no frame-specific parameters, so the conversion is straightforward. - Examples - -------- >>> import coordinax.astro as cxa >>> import astropy.coordinates as apyc >>> import plum @@ -104,8 +102,6 @@ def coordinax_icrs_to_astropy_icrs(frame: cxastro.ICRS, /) -> apyc.ICRS: def from_(cls: type[cxastro.ICRS], obj: apyc.ICRS, /) -> cxastro.ICRS: """Construct from a `astropy.coordinates.ICRS`. - Examples - -------- >>> import coordinax.astro as cxastro >>> from plum import convert >>> import astropy.coordinates as apyc @@ -232,8 +228,6 @@ def from_( ) -> cxastro.Galactocentric: """Construct from a `astropy.coordinates.Galactocentric`. - Examples - -------- >>> import astropy.coordinates as apyc >>> import coordinax.frames as cxf diff --git a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/ptmap.py b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/ptmap.py index e2eb7ac6..6778364a 100644 --- a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/ptmap.py +++ b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/ptmap.py @@ -22,8 +22,6 @@ def convert_cx_cdict_to_astropy_cartrep(p: CDict, /) -> apyc.CartesianRepresentation: """Convert a CDict to an astropy CartesianRepresentation. - Examples - -------- >>> import astropy.coordinates as apyc >>> import coordinax.charts as cxc >>> import unxt as u @@ -70,8 +68,6 @@ def cdict(r: apyc.CartesianRepresentation) -> CDict: def convert_cx_cdict_to_astropy_cylrep(p: CDict, /) -> apyc.CylindricalRepresentation: """Convert a CDict to an astropy CylindricalRepresentation. - Examples - -------- >>> import astropy.coordinates as apyc >>> import coordinax.charts as cxc >>> import unxt as u @@ -120,8 +116,6 @@ def convert_cx_cdict_to_astropy_physsphrep( ) -> apyc.PhysicsSphericalRepresentation: """Convert a CDict to an astropy PhysicsSphericalRepresentation. - Examples - -------- >>> import astropy.coordinates as apyc >>> import coordinax.charts as cxc >>> import unxt as u @@ -169,8 +163,6 @@ def cdict(r: apyc.PhysicsSphericalRepresentation) -> CDict: def convert_cx_cdict_to_astropy_sphrep(p: CDict, /) -> apyc.SphericalRepresentation: """Convert a CDict to an astropy SphericalRepresentation. - Examples - -------- >>> import astropy.coordinates as apyc >>> import coordinax.charts as cxc >>> import unxt as u diff --git a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/qmatrix.py b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/qmatrix.py index 913726b4..93ae295b 100644 --- a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/qmatrix.py +++ b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/qmatrix.py @@ -7,7 +7,7 @@ import astropy.units as apyu -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax.internal import QMatrix, UnitsMatrix def _structured_unit_to_tuple(obj: apyu.StructuredUnit) -> tuple: @@ -29,8 +29,6 @@ def _structured_unit_to_tuple(obj: apyu.StructuredUnit) -> tuple: def unitsmatrix_to_structured_unit(obj: UnitsMatrix, /) -> apyu.StructuredUnit: """Convert a ``UnitsMatrix`` to an ``astropy.units.StructuredUnit``. - Examples - -------- >>> import plum >>> import astropy.units as apyu >>> from coordinax.internal import UnitsMatrix @@ -55,8 +53,6 @@ def unitsmatrix_to_structured_unit(obj: UnitsMatrix, /) -> apyu.StructuredUnit: def structured_unit_to_unitsmatrix(obj: apyu.StructuredUnit, /) -> UnitsMatrix: """Convert an ``astropy.units.StructuredUnit`` to a ``UnitsMatrix``. - Examples - -------- >>> import plum >>> import astropy.units as apyu >>> from coordinax.internal import UnitsMatrix @@ -74,18 +70,16 @@ def structured_unit_to_unitsmatrix(obj: apyu.StructuredUnit, /) -> UnitsMatrix: return UnitsMatrix(_structured_unit_to_tuple(obj)) -@plum.conversion_method(QuantityMatrix, apyu.Quantity) -def convert_qmatrix_to_astropy_quantity(q: QuantityMatrix, /) -> apyu.Quantity: - """Convert a `coordinax.internal.QuantityMatrix` to an `astropy.units.Quantity`. +@plum.conversion_method(QMatrix, apyu.Quantity) +def convert_qmatrix_to_astropy_quantity(q: QMatrix, /) -> apyu.Quantity: + """Convert a `coordinax.internal.QMatrix` to an `astropy.units.Quantity`. - Examples - -------- >>> import jax.numpy as jnp >>> import astropy.units as apyu >>> import plum - >>> from coordinax.internal import QuantityMatrix + >>> from coordinax.internal import QMatrix - >>> qmat = QuantityMatrix(jnp.array([1.0, 2.0]), unit=("km", "s")) + >>> qmat = QMatrix(jnp.array([1.0, 2.0]), unit=("km", "s")) >>> plum.convert(qmat, apyu.Quantity) diff --git a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/vec_constructors.py b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/vec_constructors.py index cd2608d3..8e0aedef 100644 --- a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/vec_constructors.py +++ b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/vec_constructors.py @@ -22,8 +22,6 @@ def from_astropy_cartesian_representation( ) -> cxv.Point: """Construct Point from Astropy CartesianRepresentation. - Examples - -------- >>> import coordinax.vectors as cxv >>> from astropy.coordinates import CartesianRepresentation @@ -42,8 +40,6 @@ def from_astropy_cylindrical_representation( ) -> cxv.Point: """Construct Point from Astropy CylindricalRepresentation. - Examples - -------- >>> import astropy.units as apyu >>> import coordinax.vectors as cxv >>> from astropy.coordinates import CylindricalRepresentation @@ -67,8 +63,6 @@ def from_astropy_physics_spherical_representation( ) -> cxv.Point: """Construct Point from Astropy PhysicsSphericalRepresentation. - Examples - -------- >>> import coordinax.vectors as cxv >>> from astropy.coordinates import PhysicsSphericalRepresentation >>> import astropy.units as apyu @@ -92,8 +86,6 @@ def from_astropy_spherical_representation( ) -> cxv.Point: """Construct Point from Astropy SphericalRepresentation. - Examples - -------- >>> import coordinax.vectors as cxv >>> from astropy.coordinates import SphericalRepresentation >>> import astropy.units as apyu @@ -119,8 +111,6 @@ def from_astropy_spherical_representation( def from_(cls: type[cxv.Point], obj: apyc.BaseCoordinateFrame, /) -> cxv.Point: """Construct Point from Astropy frame with data. - Examples - -------- >>> import astropy.units as apyu >>> import astropy.coordinates as apyc >>> import coordinax.vectors as cxv @@ -164,8 +154,6 @@ def convert_astropy_frame_with_data_to_cx_point( ) -> cxv.Point: """Convert an Astropy frame with data to a Coordinax Point. - Examples - -------- >>> import astropy.units as apyu >>> import astropy.coordinates as apyc >>> import plum diff --git a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/vec_converters.py b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/vec_converters.py index f91c2971..86610356 100644 --- a/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/vec_converters.py +++ b/packages/coordinax.interop.astropy/src/coordinax/interop/astropy/_src/vec_converters.py @@ -40,8 +40,6 @@ def check_semantics(obj: V, /, need: type[cxr.AbstractSemanticKind]) -> V: def vec_to_q(obj: cxv.Point, /) -> Shaped[apyu.Quantity, "*batch 3"]: """`coordinax.Point` -> `astropy.units.Quantity`. - Examples - -------- >>> import unxt as u >>> import coordinax.vectors as cxv >>> from plum import convert @@ -75,14 +73,12 @@ def vec_to_q(obj: cxv.Point, /) -> Shaped[apyu.Quantity, "*batch 3"]: def convert_vector_to_astropy(obj: cxv.Point, /) -> apyc.BaseRepresentation: """Convert a `coordinax.Point` to a `astropy.coordinates.BaseRepresentation`. - The specific Astropy representation type (e.g., Cartesian vs. Cylindrical) is - determined by the chart of the input point. The point's role must be + The specific Astropy representation type (e.g., Cartesian vs. Cylindrical) + is determined by the chart of the input point. The point's role must be compatible with a position/location (e.g., Point), and the chart must be one of the supported types (e.g., Cart3D, Cyl3D, Sph3D, etc.), or else a ValueError will be raised. - Examples - -------- >>> import unxt as u >>> import coordinax.vectors as cxv >>> from plum import convert @@ -132,15 +128,12 @@ def convert_vector_to_astropy(obj: cxv.Point, /) -> apyc.BaseRepresentation: def cart3_to_apycart3(obj: cxv.Point, /) -> apyc.CartesianRepresentation: """`coordinax.Point` -> `astropy.CartesianRepresentation`. - This conversion is only used for Cartesian-charted points or Point-role points. - For Pos/Vel/Acc roles with non-Cartesian charts, convert to the corresponding - Astropy representation type first (e.g., Cyl → CylindricalRepresentation, - then use Astropy's `.represent_as()` if needed). + This conversion is only used for Cartesian-charted points or Point-role + points. For Pos/Vel/Acc roles with non-Cartesian charts, convert to the + corresponding Astropy representation type first (e.g., Cyl → + CylindricalRepresentation, then use Astropy's `.represent_as()` if needed). - Examples - -------- >>> import coordinax.vectors as cxv - >>> point = cxv.Point.from_([1, 2, 3], "km") >>> convert(point, apyc.CartesianRepresentation) apyc.CartesianRepresentation: def apycart3_to_cart3(obj: apyc.CartesianRepresentation, /) -> cxv.Point: """`astropy.CartesianRepresentation` -> `coordinax.Cart3D`. - Examples - -------- >>> import coordinax.vectors as cxv >>> from astropy.coordinates import CartesianRepresentation - >>> vec = CartesianRepresentation(1, 2, 3, unit="km") >>> print(convert(vec, cxv.Point)) apyc.CylindricalRepresentation: """`Cyl3D` -> `astropy.CylindricalRepresentation`. - Examples - -------- >>> import unxt as u >>> import coordinax.vectors as cxv @@ -225,13 +213,12 @@ def cyl_to_apycyl( def apycyl_to_cyl(obj: apyc.CylindricalRepresentation, /) -> cxv.Point: """`astropy.CylindricalRepresentation` -> `coordinax.Cylindrical3D`. - Examples - -------- - >>> import astropy.units as u + >>> import astropy.units as apyu + >>> import astropy.coordinates as apyc >>> import coordinax.vectors as cxv - >>> from astropy.coordinates import CylindricalRepresentation - >>> cyl = CylindricalRepresentation(rho=1 * u.km, phi=2 * u.deg, z=30 * u.m) + >>> cyl = apyc.CylindricalRepresentation(rho=1 * apyu.km, phi=2 * apyu.deg, + ... z=30 * apyu.m) >>> print(convert(cyl, cxv.Point)) @@ -250,11 +237,8 @@ def sph_to_apysph( ) -> apyc.PhysicsSphericalRepresentation: """`coordinax.Point` -> `astropy.PhysicsSphericalRepresentation`. - Examples - -------- >>> import unxt as u >>> import coordinax.vectors as cxv - >>> vec = cxv.Point.from_({"r": u.Q(1,"m"), "theta": u.Q(2,"deg"), ... "phi": u.Q(3,"deg")}, cxc.sph3d) >>> convert(vec, apyc.PhysicsSphericalRepresentation) @@ -279,8 +263,6 @@ def sph_to_apysph( def apysph_to_sph(obj: apyc.PhysicsSphericalRepresentation, /) -> cxv.Point: """`astropy.PhysicsSphericalRepresentation` -> `coordinax.Spherical3D`. - Examples - -------- >>> import astropy.units as u >>> import coordinax.vectors as cxv >>> from astropy.coordinates import PhysicsSphericalRepresentation @@ -305,8 +287,6 @@ def lonlatsph_to_apysph( ) -> apyc.SphericalRepresentation: """`coordinax.LonLatSpherical3D` -> `astropy.SphericalRepresentation`. - Examples - -------- >>> import unxt as u >>> import coordinax.vectors as cxv @@ -334,14 +314,12 @@ def lonlatsph_to_apysph( def apysph_to_lonlatsph(obj: apyc.SphericalRepresentation, /) -> cxv.Point: """`astropy.SphericalRepresentation` -> `coordinax.LonLatSpherical3D`. - Examples - -------- - >>> import astropy.units as u + >>> import astropy.units as apyu + >>> import astropy.coordinates as apyc >>> import coordinax.vectors as cxv - >>> from astropy.coordinates import SphericalRepresentation - >>> sph = SphericalRepresentation(lon=2 * u.deg, lat=3 * u.deg, - ... distance=1 * u.km) + >>> sph = apyc.SphericalRepresentation(lon=2 * apyu.deg, lat=3 * apyu.deg, + ... distance=1 * apyu.km) >>> print(convert(sph, cxv.Point)) diff --git a/packages/coordinax.interop.astropy/tests/test_angles.py b/packages/coordinax.interop.astropy/tests/test_angles.py index 21768691..25082fdc 100644 --- a/packages/coordinax.interop.astropy/tests/test_angles.py +++ b/packages/coordinax.interop.astropy/tests/test_angles.py @@ -16,11 +16,11 @@ def test_astropy_angle_to_cx_angle() -> None: """Test converting AstropyAngle to cxa.Angle.""" - apy = AstropyAngle(1.0, "rad") + apy = AstropyAngle(1, "rad") angle = convert(apy, cxa.Angle) assert isinstance(angle, cxa.Angle) - assert angle.value == pytest.approx(1.0) + assert angle.value == pytest.approx(1) assert str(angle.unit) == "rad" @@ -41,11 +41,11 @@ def test_astropy_angle_to_cx_angle_hypothesis(unit: str) -> None: def test_cx_angle_to_astropy_angle() -> None: """Test converting cxa.Angle to AstropyAngle.""" - angle = cxa.Angle(1.0, "rad") + angle = cxa.Angle(1, "rad") apy = convert(angle, AstropyAngle) assert isinstance(apy, AstropyAngle) - assert apy.value == pytest.approx(1.0) + assert apy.value == pytest.approx(1) assert str(apy.unit) == "rad" diff --git a/packages/coordinax.interop.astropy/tests/test_distances.py b/packages/coordinax.interop.astropy/tests/test_distances.py index 42cdeb2d..092a84d5 100644 --- a/packages/coordinax.interop.astropy/tests/test_distances.py +++ b/packages/coordinax.interop.astropy/tests/test_distances.py @@ -37,7 +37,7 @@ def float32s( @given(unit=ust.units("length")) def test_astropy_quantity_to_distance(unit: str) -> None: """Test converting Astropy Quantity to Distance.""" - apyq = apyu.Quantity(42.0, unit) + apyq = apyu.Quantity(42, unit) dist = convert(apyq, Distance) assert isinstance(dist, Distance) @@ -47,7 +47,7 @@ def test_astropy_quantity_to_distance(unit: str) -> None: @given( dist=ust.quantities( - ust.units("length"), elements=float32s(1.0, 1e6), quantity_cls=Distance + ust.units("length"), elements=float32s(1, 1e6), quantity_cls=Distance ) ) def test_distance_to_astropy_quantity(dist: Distance) -> None: @@ -61,7 +61,7 @@ def test_distance_to_astropy_quantity(dist: Distance) -> None: @given( dist=ust.quantities( - ust.units("length"), elements=float32s(1.0, 1e6), quantity_cls=Distance + ust.units("length"), elements=float32s(1, 1e6), quantity_cls=Distance ) ) def test_distance_roundtrip(dist: Distance) -> None: @@ -90,7 +90,7 @@ def test_astropy_quantity_to_parallax(unit: str) -> None: @given( plx=ust.quantities( - ust.units("angle"), elements=float32s(0.0625, 1.0), quantity_cls=Parallax + ust.units("angle"), elements=float32s(0.0625, 1), quantity_cls=Parallax ) ) def test_parallax_to_astropy_quantity(plx: Parallax) -> None: @@ -104,7 +104,7 @@ def test_parallax_to_astropy_quantity(plx: Parallax) -> None: @given( plx=ust.quantities( - ust.units("angle"), elements=float32s(0.0625, 1.0), quantity_cls=Parallax + ust.units("angle"), elements=float32s(0.0625, 1), quantity_cls=Parallax ) ) def test_parallax_roundtrip(plx: Parallax) -> None: @@ -122,17 +122,17 @@ def test_parallax_roundtrip(plx: Parallax) -> None: def test_astropy_quantity_to_distancemodulus() -> None: """Test converting AstropyQuantity to DistanceModulus.""" - q = AstropyQuantity(5.0, "mag") + q = AstropyQuantity(5, "mag") dm = convert(q, DistanceModulus) assert isinstance(dm, DistanceModulus) - assert dm.value == pytest.approx(5.0) + assert dm.value == pytest.approx(5) assert str(dm.unit) == "mag" def test_distancemodulus_to_astropy_quantity() -> None: """Test converting DistanceModulus to AstropyQuantity.""" - dm = DistanceModulus(5.0, "mag") + dm = DistanceModulus(5, "mag") apyq = convert(dm, apyu.Quantity) assert isinstance(apyq, apyu.Quantity) @@ -142,7 +142,7 @@ def test_distancemodulus_to_astropy_quantity() -> None: def test_distancemodulus_roundtrip() -> None: """Test roundtrip conversion for DistanceModulus.""" - dm = DistanceModulus(5.0, "mag") + dm = DistanceModulus(5, "mag") apyq = convert(dm, apyu.Quantity) dm_back = convert(apyq, DistanceModulus) diff --git a/packages/coordinax.interop.astropy/tests/test_ptmap_cdict.py b/packages/coordinax.interop.astropy/tests/test_ptmap_cdict.py index 7a481b70..17d361f6 100644 --- a/packages/coordinax.interop.astropy/tests/test_ptmap_cdict.py +++ b/packages/coordinax.interop.astropy/tests/test_ptmap_cdict.py @@ -55,12 +55,12 @@ def make_strat( ) -_pos_km = make_strat("km", 0.5, 100.0) -_any_km = make_strat("km", -100.0, 100.0) -_phi_rad = make_strat("rad", -3.0, 3.0) +_pos_km = make_strat("km", 0.5, 100) +_any_km = make_strat("km", -100, 100) +_phi_rad = make_strat("rad", -3, 3) # 0.1 avoids the polar singularity; 3.04 avoids theta≈π (south pole) _theta_rad = make_strat("rad", 0.1, 3.04) -_lon_rad = make_strat("rad", -3.0, 3.0) +_lon_rad = make_strat("rad", -3, 3) _lat_rad = make_strat("rad", -1.5, 1.5) @@ -85,12 +85,12 @@ def _approx_angle_equal( apy_val = float(apy.to(apy.unit).value) # Convert to radians for the modular comparison - scale = math.pi / 180.0 if apy.unit == "deg" else 1.0 + scale = math.pi / 180 if apy.unit == "deg" else 1 diff_rad = (got_val - apy_val) * scale # Reduce to (-π, π] diff_rad = (diff_rad + math.pi) % (2 * math.pi) - math.pi - assert abs(diff_rad) == pytest.approx(0.0, abs=abs_tol) + assert abs(diff_rad) == pytest.approx(0, abs=abs_tol) # --------------------------------------------------------------------------- diff --git a/packages/coordinax.interop.astropy/tests/test_qmatrix.py b/packages/coordinax.interop.astropy/tests/test_qmatrix.py index e15b07a9..2b652060 100644 --- a/packages/coordinax.interop.astropy/tests/test_qmatrix.py +++ b/packages/coordinax.interop.astropy/tests/test_qmatrix.py @@ -8,7 +8,7 @@ import unxt as u -from coordinax.internal import QuantityMatrix, UnitsMatrix +from coordinax.internal import QMatrix, UnitsMatrix class TestUnitsMatrixToStructuredUnit: @@ -81,23 +81,23 @@ def test_roundtrip_via_unitsmatrix(self) -> None: assert result == su -class TestQuantityMatrixToAstropyQuantity: - """Tests for QuantityMatrix → apyu.Quantity conversion.""" +class TestQMatrixToAstropyQuantity: + """Tests for QMatrix → apyu.Quantity conversion.""" def test_1d(self) -> None: - qmat = QuantityMatrix(jnp.array([1.0, 2.0]), unit=("km", "s")) + qmat = QMatrix(jnp.array([1, 2]), unit=("km", "s")) result = plum.convert(qmat, apyu.Quantity) assert isinstance(result, apyu.Quantity) def test_1d_unit(self) -> None: - qmat = QuantityMatrix(jnp.array([1.0, 2.0]), unit=("km", "s")) + qmat = QMatrix(jnp.array([1, 2]), unit=("km", "s")) result = plum.convert(qmat, apyu.Quantity) assert result.unit == apyu.StructuredUnit(("km", "s")) def test_1d_values(self) -> None: - qmat = QuantityMatrix(jnp.array([3.0, 4.0]), unit=("m", "kg")) + qmat = QMatrix(jnp.array([3, 4]), unit=("m", "kg")) result = plum.convert(qmat, apyu.Quantity) arr = np.array(result) - assert float(arr["f0"]) == pytest.approx(3.0) - assert float(arr["f1"]) == pytest.approx(4.0) + assert float(arr["f0"]) == pytest.approx(3) + assert float(arr["f1"]) == pytest.approx(4) diff --git a/packages/coordinax.interop.astropy/tests/test_vectors.py b/packages/coordinax.interop.astropy/tests/test_vectors.py index 9324f2f2..5f7e3b8d 100644 --- a/packages/coordinax.interop.astropy/tests/test_vectors.py +++ b/packages/coordinax.interop.astropy/tests/test_vectors.py @@ -47,7 +47,7 @@ "nu": u.Q([0.1, 0.2, 0.3, 0.4], "kpc2"), "phi": u.Q([0, 1, 2, 3], "rad"), }, - cxc.ProlateSpheroidal3D(Delta=u.StaticQuantity(1.0, "kpc")), + cxc.ProlateSpheroidal3D(Delta=u.StaticQuantity(1, "kpc")), ) apyprolatesph = None # No corresponding Astropy representation From 8dc53d356d30e9ec156e5c085d625507a192f2e5 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:35:29 -0400 Subject: [PATCH 10/15] =?UTF-8?q?=F0=9F=93=9D=20docs:=20update=20spec=20an?= =?UTF-8?q?d=20guides=20for=20metric=20refactor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit spec.md is updated to reflect the new metric module design (AbstractMetricField, metric_matrix dispatch, DiagonalMetric/DenseMetric matrix types). API and guide pages for manifolds, charts, and representations are updated to use the new class names. README and tutorials are adjusted for minor API changes. --- README.md | 2 +- docs/api/internal.md | 8 +- docs/api/manifolds.md | 2 +- docs/guides/charts.md | 4 +- docs/guides/manifolds.md | 31 +-- docs/guides/representations.md | 5 +- docs/spec.md | 412 ++++++++++++++++++-------------- docs/tutorials/point_objects.md | 7 +- 8 files changed, 263 insertions(+), 208 deletions(-) diff --git a/README.md b/README.md index a6809acd..6bbaff85 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ Use the built-in $S^2$ two-sphere and its round metric to measure angles between >>> cxm.S2 HyperSphericalManifold(ndim=2) >>> cxm.S2.metric -HyperSphericalMetric(ndim=2) +RoundMetric(ndim=2) >>> # At the equator, measure the angle between northward and eastward tangents >>> at = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")} diff --git a/docs/api/internal.md b/docs/api/internal.md index feb68633..56afef90 100644 --- a/docs/api/internal.md +++ b/docs/api/internal.md @@ -25,9 +25,9 @@ These utilities are primarily useful when implementing downstream transforms, Ja ```python import jax.numpy as jnp import unxt as u -from coordinax.internal import QuantityMatrix +from coordinax.internal import QMatrix -J = QuantityMatrix( +J = QMatrix( value=jnp.eye(3), unit=( (u.unit("m/m"), u.unit("m/rad"), u.unit("m/rad")), @@ -37,7 +37,7 @@ J = QuantityMatrix( ) ``` -`QuantityMatrix` supports both 1-D and 2-D cases. This makes it suitable for heterogeneous vectors as well as Jacobians and metric tensors whose entries do not all share the same unit. +`QMatrix` supports both 1-D and 2-D cases. This makes it suitable for heterogeneous vectors as well as Jacobians and metric tensors whose entries do not all share the same unit. ## Packing Helpers @@ -66,7 +66,7 @@ Use `pack_uniform_unit` when all components should be expressed in a shared unit ### Heterogeneous Unit Containers -- `QuantityMatrix`: N-D quantity container with per-element units; currently supports 1-D vectors and 2-D matrices +- `QMatrix`: N-D quantity container with per-element units; currently supports 1-D vectors and 2-D matrices - `UnitsMatrix`: immutable nested tuple of units with tuple-style indexing and shape metadata ### Packing Utilities diff --git a/docs/api/manifolds.md b/docs/api/manifolds.md index 7b186aa1..3dae08f5 100644 --- a/docs/api/manifolds.md +++ b/docs/api/manifolds.md @@ -39,7 +39,7 @@ M3 = cxm.guess_manifold(cxc.sph2) at = {"x": u.Q(0, "km"), "y": u.Q(0, "km"), "z": u.Q(0, "km")} uvec = {"x": u.Q(1, "km"), "y": u.Q(0, "km"), "z": u.Q(0, "km")} vvec = {"x": u.Q(0, "km"), "y": u.Q(1, "km"), "z": u.Q(0, "km")} -ang = M.angle_between(cxc.cart3d, uvec, vvec, at=at) +ang = cxm.angle_between(cxc.cart3d, uvec, vvec, at=at) ``` ## Functional API diff --git a/docs/guides/charts.md b/docs/guides/charts.md index 05b2a93e..10b574c2 100644 --- a/docs/guides/charts.md +++ b/docs/guides/charts.md @@ -104,7 +104,7 @@ Product-chart transitions are factorwise. ### Direct call — quantity-valued dictionary input -Passing a component dictionary with `unxt.Quantity` values returns a `QuantityMatrix` whose element `[j, i]` carries the unit `output_unit_j / input_unit_i`: +Passing a component dictionary with `unxt.Quantity` values returns a `QMatrix` whose element `[j, i]` carries the unit `output_unit_j / input_unit_i`: ```{code-block} python >>> import coordinax.charts as cxc @@ -113,7 +113,7 @@ Passing a component dictionary with `unxt.Quantity` values returns a `QuantityMa >>> at = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} >>> J = cxc.jac_pt_map(at, cxc.cart3d, cxc.sph3d) >>> J -QuantityMatrix( +QMatrix( [[ 1., 0., 0.], [-0., -0., -1.], [ 0., 1., 0.]], diff --git a/docs/guides/manifolds.md b/docs/guides/manifolds.md index 8f6410da..8e44b2f4 100644 --- a/docs/guides/manifolds.md +++ b/docs/guides/manifolds.md @@ -47,12 +47,12 @@ False >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm ->>> E2 = cxm.EuclideanManifold(2) ->>> E2.default_chart() +>>> R2 = cxm.EuclideanManifold(2) +>>> R2.default_chart() Cart2D(M=Rn(2)) ->>> E2.has_chart(cxc.cart2d) +>>> R2.has_chart(cxc.cart2d) True ->>> E2.has_chart(cxc.polar2d) +>>> R2.has_chart(cxc.polar2d) True ``` @@ -64,7 +64,7 @@ True >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm ->>> S2 = cxm.HyperSphericalManifold() +>>> S2 = cxm.HyperSphericalManifold(2) >>> S2.default_chart() SphericalTwoSphere(M=Sn(2)) >>> S2.has_chart(cxc.sph2) @@ -97,7 +97,6 @@ Use `cxc.pt_map` (or `cxm.pt_map`) to convert a point between two charts on the >>> import coordinax.manifolds as cxm >>> import unxt as u ->>> M = cxm.EuclideanManifold(2) >>> p = {"x": u.Q(1, "km"), "y": u.Q(1, "km")} >>> p_pol = cxc.pt_map(p, cxc.cart2d, cxc.polar2d) >>> sorted(p_pol) @@ -108,7 +107,7 @@ Use `cxc.pt_map` (or `cxm.pt_map`) to convert a point between two charts on the Use `scale_factors` when you want the diagonal entries of the metric matrix in a chart. -This returns the metric diagonal $g_{ii}$, not the basis lengths $\sqrt{g_{ii}}$. The result is a 1-D `QuantityMatrix` because different coordinate directions can carry different units. +This returns the metric diagonal $g_{ii}$, not the basis lengths $\sqrt{g_{ii}}$. The result is a 1-D `QMatrix` because different coordinate directions can carry different units. ```{code-block} python >>> import coordinax.charts as cxc @@ -116,14 +115,13 @@ This returns the metric diagonal $g_{ii}$, not the basis lengths $\sqrt{g_{ii}}$ >>> import quaxed.numpy as jnp >>> import unxt as u ->>> M = cxm.EuclideanManifold(3) >>> at = { ... "r": u.Q(2, "km"), ... "theta": u.Angle(jnp.pi / 2, "rad"), ... "phi": u.Angle(0, "rad"), ... } ->>> gdiag = M.scale_factors(cxc.sph3d, at=at) +>>> gdiag = cxm.scale_factors(cxc.sph3d, at=at) >>> gdiag.shape (3,) >>> jnp.allclose(gdiag.value, jnp.array([1.0, 4.0, 4.0])) @@ -132,7 +130,7 @@ Array(True, dtype=bool) '(, km2 / rad2, km2 / rad2)' ``` -For generic metrics, `scale_factors` follows the metric matrix path and returns the diagonal. For `EuclideanMetric`, coordinax uses a more efficient specialization that avoids forming the full metric matrix. +For generic metrics, `scale_factors` follows the metric matrix path and returns the diagonal. For `FlatMetric`, coordinax uses a more efficient specialization that avoids forming the full metric matrix. ## Measuring Angles Between Tangent Vectors @@ -146,12 +144,11 @@ This is a tangent-space operation, not a point-to-point operation. The vectors a >>> import quaxed.numpy as jnp >>> import unxt as u ->>> M = cxm.EuclideanManifold(2) >>> at = {"x": u.Q(0, "m"), "y": u.Q(0, "m")} >>> uvec = {"x": u.Q(1, "m"), "y": u.Q(0, "m")} >>> vvec = {"x": u.Q(0, "m"), "y": u.Q(1, "m")} ->>> ang = M.angle_between(cxc.cart2d, uvec, vvec, at=at) +>>> ang = cxm.angle_between(cxc.cart2d, uvec, vvec, at=at) >>> jnp.allclose(u.ustrip("rad", ang), jnp.pi / 2) Array(True, dtype=bool) ``` @@ -159,7 +156,7 @@ Array(True, dtype=bool) For curvilinear charts, the angle is still intrinsic, but the metric weights the coordinate directions at the supplied base point: ```{code-block} python ->>> metric = cxm.HyperSphericalMetric(ndim=2) +>>> metric = cxm.RoundMetric(ndim=2) >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} >>> uvec = {"theta": jnp.array(1.0), "phi": jnp.array(0.0)} >>> vvec = {"theta": jnp.array(1.0), "phi": jnp.array(1.0)} @@ -237,8 +234,7 @@ Use `EmbeddedManifold` when you need explicit manifold objects with atlas compat >>> import unxt as u >>> em = cxm.EmbeddedManifold( -... intrinsic=cxm.HyperSphericalManifold(), -... ambient=cxm.EuclideanManifold(3), +... intrinsic=cxm.S2, ambient=cxm.R3, ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(2.0, "km")), ... ) >>> em.ndim @@ -300,7 +296,7 @@ When needed, build manifolds from explicit chart sets with `CustomAtlas` and `Cu >>> import coordinax.manifolds as cxm >>> A = cxm.CustomAtlas(charts=(cxc.Cart2D, cxc.Polar2D), chart_default=cxc.cart2d) ->>> M = cxm.CustomManifold(A, metric=cxm.EuclideanMetric(2)) +>>> M = cxm.CustomManifold(A, metric=cxm.FlatMetric(2)) >>> M.has_chart(cxc.cart2d) True @@ -316,8 +312,7 @@ Product manifolds combine independent factors and sum dimensions. >>> import coordinax.manifolds as cxm >>> MP = cxm.CartesianProductManifold( -... factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(1)), -... factor_names=("S2", "R1"), +... factors=(cxm.S2, cxm.R1), factor_names=("S2", "R1") ... ) >>> MP.ndim 3 diff --git a/docs/guides/representations.md b/docs/guides/representations.md index 5bf154cc..089b1563 100644 --- a/docs/guides/representations.md +++ b/docs/guides/representations.md @@ -136,7 +136,7 @@ The representation design is intentionally extensible. Future geometric kinds (f - supported basis changes: `CoordinateBasis` $\rightleftarrows$ `PhysicalBasis` - supported representations: tangent representations such as `coord_disp` and `phys_disp` - point representations are not supported as genuine basis-changing inputs; however, `NoBasis -> CoordinateBasis` and `NoBasis -> PhysicalBasis` are supported as identity reinterpretations when the dimensions are compatible -- non-Cartesian support: available for tangent basis changes on charts with basis-change rules (for example `sph3d`), and generally via an explicit metric/manifold +- non-Cartesian support: available for tangent basis changes on charts with basis-change rules (for example `sph3d`), and generally via an explicit manifold ```{code-block} python >>> import coordinax.charts as cxc @@ -169,8 +169,7 @@ The representation design is intentionally extensible. Future geometric kinds (f >>> cxr.change_basis(v_sph, cxc.sph3d, cxr.coord_basis, cxr.phys_basis, at=at_sph) {'r': Q(5, 'm / s'), 'theta': Q(2, 'm / s'), 'phi': Q(2., 'm / s')} ->>> metric = cxm.EuclideanMetric(3) ->>> cxr.change_basis(v_sph, cxc.sph3d, metric, cxr.coord_basis, cxr.phys_basis, at=at_sph) +>>> cxr.change_basis(v_sph, cxc.sph3d, cxm.R3, cxr.coord_basis, cxr.phys_basis, at=at_sph) {'r': Q(5, 'm / s'), 'theta': Q(2, 'm / s'), 'phi': Q(2., 'm / s')} ``` diff --git a/docs/spec.md b/docs/spec.md index 2a069618..97c84731 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -969,7 +969,7 @@ A non-exhaustive table of exported objects are: | `coordinax.charts` | `CartesianProductChart`,
`cartesian_chart`, `guess_chart`, `cdict`, `pt_map`, `jac_pt_map`,
`cart0d`,
`cart1d`, `radial1d`, `time1d`,
`cart2d`, `polar2d`,
`cart3d`, `cyl3d`, `sph3d`, `lonlat_sph3d`, `loncoslat_sph3d`, `math_sph3d`,
`cartnd`,
`minkowskict`, `galileanct` | | `coordinax.representations` | `cconvert`, `change_basis`, `tangent_map`,
`Representation`, `point`, `coord_disp`, `coord_vel`, `coord_acc`, `phys_disp`, `phys_vel`, `phys_acc`,
`PointGeometry`, `point_geom`, `TangentGeometry`, `tangent_geom`,
`NoBasis`, `no_basis`, `CoordinateBasis`, `coord_basis`, `PhysicalBasis`, `phys_basis`,
`Location`, `loc`, `Displacement`, `dpl`, `Velocity`, `vel`, `Acceleration`, `acc`,
`guess_geometry_kind`, `guess_semantic_kind`, `guess_rep` | | `coordinax.vectors` | `Point`, `Tangent`, `Coordinate`, `ToUnitsOptions` | -| `coordinax.manifolds` | `guess_manifold`, `scale_factors`, `angle_between`,
`EuclideanManifold`, `Rn`, `EuclideanMetric`, `R3`,
`EmbeddedManifold`, `EmbeddedChart`
`S2`, `embedded_twosphere`,
`CustomManifold`,`CustomAtlas`,
`CartesianProductManifold`, `galilean_spacetime` | +| `coordinax.manifolds` | `guess_manifold`, `scale_factors`, `angle_between`,
`EuclideanManifold`, `Rn`, `FlatMetric`, `R3`,
`EmbeddedManifold`, `EmbeddedChart`
`S2`, `embedded_twosphere`,
`CustomManifold`,`CustomAtlas`,
`CartesianProductManifold`, `galilean_spacetime` | | `coordinax.transforms` | `act`, `simplify`, `compose`, `materialize_transform`,
`AbstractTransform`, `Identity`, `Composed`, `Translate`, `Rotate`, `Reflect`, `Scale`, `Shear`, `identity`,
`AbstractTransformGroup`, `IdentityGroup`, `DiffeomorphismGroup`, `AffineGroup`, `EuclideanGroup`, `OrthogonalGroup`, `SpecialOrthogonalGroup`, `PoincareGroup`, `LorentzGroup`, `ProperOrthochronousLorentzGroup` | | `coordinax.frames` | `frame_transition`,
`AbstractReferenceFrame`, `FrameTransformError`,
`NoFrame`, `Alice`, `Alex`, `TransformedReferenceFrame` | @@ -983,8 +983,8 @@ These names are importable and supported for inter-package use, but they are **n Semi-public API: -- `QuantityMatrix`: heterogeneous 1-D or 2-D quantity container with a per-element unit structure -- `UnitsMatrix`: immutable, hashable wrapper around a numpy object array of `AbstractUnit` elements, aligned with a `QuantityMatrix`; supports tuple-style indexing, iteration, and `to_tuple()`/`to_string()`. **Not** a subclass of `astropy.StructuredUnit`; bidirectional converters to/from `astropy.StructuredUnit` live in `coordinax.interop.astropy`. +- `QMatrix`: heterogeneous 1-D or 2-D quantity container with a per-element unit structure +- `UnitsMatrix`: immutable, hashable wrapper around a numpy object array of `AbstractUnit` elements, aligned with a `QMatrix`; supports tuple-style indexing, iteration, and `to_tuple()`/`to_string()`. **Not** a subclass of `astropy.StructuredUnit`; bidirectional converters to/from `astropy.StructuredUnit` live in `coordinax.interop.astropy`. - `cdict_units`: extract per-component units from a coordinate dictionary - `pack_uniform_unit`: stack component data into an array using a shared unit - `pack_nonuniform_unit`: stack component data into an array while preserving per-component units @@ -1142,7 +1142,7 @@ The `coordinax.charts` module provides the chart-facing API for representing poi - Input `unxt.AbstractQuantity` with last axis size 1, 2, or 3: call `guess_chart` on the quantity to infer chart dimensionality, then apply the chart-based dispatch. - Input array-like with chart context: split last axis into named components using `chart.components`. Last axis length MUST match the chart's component count. - Input `unxt.AbstractQuantity` with chart context: split last axis into named quantities using `chart.components`. Requires last axis size to match chart. - - Input `QuantityMatrix` with chart context: extract heterogeneous per-component quantities, one for each chart component. + - Input `QMatrix` with chart context: extract heterogeneous per-component quantities, one for each chart component. Failure semantics: @@ -1212,7 +1212,7 @@ The `coordinax.charts` module provides the chart-facing API for representing poi raw array Jacobian. Requires `usys`. - `(at: CDict, from_chart, to_chart, /, *, usys: OptUSys = None)` -> - `Array | QuantityMatrix`. The general dict dispatch. Branches on whether `at` + `Array | QMatrix`. The general dict dispatch. Branches on whether `at` values are plain arrays or quantities: - **Array-valued** (`is_array=True`): stacks `at` into a plain array via @@ -1220,13 +1220,13 @@ The `coordinax.charts` module provides the chart-facing API for representing poi forwarded and must be an `AbstractUnitSystem` unless a more-specific analytical dispatch handles it. - - **Quantity-valued** (`is_array=False`): packs `at` into a 1-D `QuantityMatrix` + - **Quantity-valued** (`is_array=False`): packs `at` into a 1-D `QMatrix` via `pack_to_qmatrix(at, keys=from_chart.components)`, casts to `float`, then computes `J_qq = jax.jacfwd(pt_map_fn)(at_in)`. The jacfwd result is a - `QuantityMatrix` whose `.value` is itself a `QuantityMatrix` encoding the input + `QMatrix` whose `.value` is itself a `QMatrix` encoding the input units, and whose `.unit` encodes the output units. `_repack_q_from_jac` extracts both to build the correct 2-D `UnitsMatrix` and returns - `QuantityMatrix(J_arr, unit=unit_matrix)` of shape `(n_out, n_in)`. + `QMatrix(J_arr, unit=unit_matrix)` of shape `(n_out, n_in)`. **`usys` parameter:** required for the `None`-partial, curried, and plain-`Array` dispatches. Optional (`None`) for the `CDict` generic dispatch's quantity-valued @@ -1822,7 +1822,17 @@ A representation is therefore **not** the same thing as a chart: the chart deter **`cconvert` integration:** `cconvert` dispatches to `tangent_map` when the source representation has `TangentGeometry`, passing `at` through the `at` keyword argument. - **Same-chart basis conversion:** `change_basis(v, chart, from_basis, to_basis, /, *, at)` changes tangent component conventions without changing charts. In v1 it is defined only for Cartesian charts and `CoordinateBasis` $ + **Same-chart basis conversion:** `change_basis(v, chart, from_basis, to_basis, /, *, at)` changes tangent component conventions without changing charts. It is defined for `CoordinateBasis` ↔ `PhysicalBasis` conversions on any chart that carries a manifold (via `chart.M`), and as an identity map for `NoBasis` and same-basis cases. The following primary overloads are defined: + + - `change_basis(v, chart, from_basis, to_basis, /, *, at)` — uses `chart.M` as the implicit manifold. + - `change_basis(v, chart, manifold, from_basis, to_basis, /, *, at)` — uses an explicit `manifold: AbstractManifold`; internally delegates to `manifold.metric`. + + **Basis semantics for `change_basis`:** + + - **`CoordinateBasis` → `PhysicalBasis`**: Call `metric_matrix(manifold, at, chart)`. If the result is a `DiagonalMetric` (i.e. `isinstance(manifold.metric, AbstractDiagonalMetricField)` held when the dispatch registered the rule), scale each component by the scale factor $h_i = \sqrt{g_{ii}}$: $\hat{v}^i = h_i\,v^i$ (no summation). If the result is a `DenseMetric`, compute the Cholesky factor $L$ of the metric matrix and apply the vielbein $E = L^\top$: $\hat{v} = E\,v$. + - **`PhysicalBasis` → `CoordinateBasis`**: inverse of the above. For diagonal metrics, $v^i = \hat{v}^i / h_i$. For general metrics, triangular-solve $E\,v = \hat{v}$. + - **Cartesian charts** (`Cart0D` … `CartND`) with any manifold: always the identity map, returned at `precedence=1` (scale factors are all 1). + - **`NoBasis` → `CoordinateBasis`** or **`NoBasis` → `PhysicalBasis`**: identity map (used for basis-agnostic data). The `NoBasis` → `PhysicalBasis` overload additionally checks that all components share the same physical dimension. **Point-to-Displacement promotion:** `change_basis(v: Point, to_basis, /, *, at)` promotes a `Point` to a `Tangent` with `Displacement` semantics. The component data are unchanged; only the geometric interpretation is recast from a manifold point (affine, `PointGeometry`) to a tangent-space displacement vector (`TangentGeometry`, `Displacement`). The resulting `Tangent` carries the same chart as the input `Point`, and its basis is `to_basis`. This operation is the inverse of treating a displacement as an absolute position, and it respects the affine/tangent distinction enforced throughout the spec. @@ -2636,34 +2646,26 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial cxm.scale_factors(chart, /, *, at, usys=None) ``` - Or via convenience wrappers on metric and manifold objects: - - ``` - metric.scale_factors(chart, at=at, usys=usys) - manifold.scale_factors(chart, at=at, usys=usys) - ``` - **Arguments:** - - `metric_or_manifold`: an `AbstractMetric` instance (metric-level call) or an `AbstractManifold` instance (manifold-level call). When a manifold is passed, `scale_factors` delegates to `manifold.metric`. + - `manifold_or_metric`: an `AbstractManifold` or `AbstractMetricField` instance. - `chart`: the coordinate chart in which the metric is expressed. - `at` (keyword): the base point $p$ in chart coordinates at which the metric is evaluated. - `usys` (keyword, optional): unit system forwarded to metric evaluation when needed. **Return:** - - Always a 1-D `QuantityMatrix` of length `ndim`. - - A `QuantityMatrix` is used even when the diagonal entries are dimensionless, because different coordinate directions may carry different units. + - Always a 1-D `QMatrix` of length `ndim`. + - A `QMatrix` is used even when the diagonal entries are dimensionless, because different coordinate directions may carry different units. **Dispatch behavior:** - - Generic metric dispatch: evaluate `metric.metric_matrix(chart, at=at, usys=usys)` and return `QuantityMatrix.diag()` on the result. If the metric matrix is array-valued, it is first promoted to a dimensionless `QuantityMatrix` and then diagonalized. - - Manifold dispatch: resolve to `scale_factors(manifold.metric, chart, at=at, usys=usys)`. - - `EuclideanMetric` specialization: compute the diagonal more efficiently than forming the full metric matrix. In Cartesian charts this returns a vector of ones directly; in non-Cartesian Euclidean charts it uses the chart-to-Cartesian Jacobian and computes only the squared column norms needed for the diagonal entries. + - Evaluate `metric_matrix(manifold, at, chart)` and extract the diagonal. + - `FlatMetric` specialization: compute the diagonal more efficiently than forming the full metric matrix. In Cartesian charts this returns a vector of ones directly; in non-Cartesian Euclidean charts it uses the chart-to-Cartesian Jacobian and computes only the squared column norms needed for the diagonal entries. **Position dependence:** - - For flat metrics (for example `EuclideanMetric` in Cartesian coordinates), the result may be position-independent numerically, though `at` remains part of the API. + - For flat metrics (for example `FlatMetric` in Cartesian coordinates), the result may be position-independent numerically, though `at` remains part of the API. - For curved or curvilinear cases, the returned diagonal entries depend on the supplied base point. **Examples** @@ -2683,20 +2685,8 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> gdiag = cxm.scale_factors(cxc.sph3d, at=at) >>> gdiag.shape (3,) - >>> gdiag.unit.to_string() - '(, km2 / rad2, km2 / rad2)' - - >>> metric = cxm.HyperSphericalMetric(2) - >>> at_s2 = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} - >>> cxm.scale_factors(metric, cxc.sph2, at=at_s2).value - Array([1., 1.], dtype=float64) ``` - **Notes:** - - - `AbstractMetric.scale_factors` and `AbstractManifold.scale_factors` are thin wrappers over `cxm.scale_factors`. - - The name `scale_factors` in the software API follows the library convention for metric diagonal entries, even though some mathematical texts reserve “scale factor” for $\sqrt{g_{ii}}$. - (software-spec-angle-between)= !!! info `angle_between` @@ -2723,23 +2713,16 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial **Signature:** ``` - cxm.angle_between(metric, chart, u, v, /, *, at, usys=None) + cxm.angle_between(manifold, chart, u, v, /, *, at, usys=None) ``` or ``` cxm.angle_between(chart, u, v, /, *, at, usys=None) ``` - Or via convenience wrappers on metric and manifold objects: - - ``` - metric.angle_between(chart, u, v, at=at, usys=usys) - manifold.angle_between(chart, u, v, at=at, usys=usys) - ``` - **Arguments:** - - `metric_or_manifold`: an `AbstractMetric` instance (metric-level call) or an `AbstractManifold` instance (manifold-level call). When a manifold is passed, `angle_between` delegates to `manifold.metric`. + - `manifold_or_metric`: an `AbstractManifold` or `AbstractMetricField` instance. When a manifold is passed, `angle_between` uses `manifold.metric`. - `chart`: the coordinate chart in whose basis the tangent-vector components are expressed. - `u`, `v`: `CDict` tangent-vector components keyed by `chart.components`. - `at` (keyword): the base point $p$ in chart coordinates at which the metric is evaluated. @@ -2752,8 +2735,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial **Dispatch behavior:** - - Generic metric dispatch: evaluate `metric.metric_matrix(chart, at=at, usys=usys)`, compute the bilinear forms `u^T g v`, `u^T g u`, and `v^T g v`, then return `arccos(...)` of the normalized inner product. - - Manifold dispatch: resolve to `angle_between(chart, u, v, at=at, usys=usys)`. + - Evaluate `metric_matrix(manifold, at, chart)`, compute the bilinear forms `u^T g v`, `u^T g u`, and `v^T g v`, then return `arccos(...)` of the normalized inner product. - The implementation supports full symmetric metric matrices; it is not restricted to diagonal metrics. **Failure semantics:** @@ -2770,7 +2752,6 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> M = cxm.EuclideanManifold(2) >>> at = {"x": u.Q(0.0, "m"), "y": u.Q(0.0, "m")} >>> uvec = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m")} >>> vvec = {"x": u.Q(0.0, "m"), "y": u.Q(1.0, "m")} @@ -2851,11 +2832,11 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial Chart-level transitions can be called directly via `cxc.pt_map`. -(software-spec-abstractmetric)= +(software-spec-abstractmetricfield)= -!!! info `AbstractMetric` +!!! info `AbstractMetricField` - `AbstractMetric` is the abstract base for metric tensors used by manifold objects. + `AbstractMetricField` is the abstract base for metric tensors used by manifold objects. A metric assigns a symmetric, non-degenerate bilinear form to each tangent space: @@ -2868,49 +2849,45 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial **Immutability and JAX-static requirements:** - - All metric classes are immutable frozen dataclasses. - - All metric classes are registered with `@jax.tree_util.register_static`. - - Metric instances therefore flatten as static PyTree nodes (no dynamic leaves). + - Most metric classes are immutable frozen dataclasses registered with `@jax.tree_util.register_static` and therefore flatten as static PyTree nodes (no dynamic leaves). + - Exception: `RoundMetric` from `coordinax._src.metric.field` is an `equinox.Module` with a dynamic `radius` leaf, allowing JIT-compilation and differentiation through the radius parameter. **Core API contract:** - `signature: tuple[int, ...]` (abstract property): encodes metric signature signs in coordinate order and has length equal to the metric dimension. - `ndim: int`: defined as `len(signature)`. - - `metric_matrix(chart, /, *, at, usys=None) -> QuantityMatrix | Array` (abstract method): returns the metric matrix expressed in `chart`, evaluated at base point `at`. - **Metric matrix requirements:** + **Subclassing requirements:** - - Shape is `(ndim, ndim)`. - - Matrix is symmetric. - - The return type follows input/unit context. - - Quantity-valued evaluation returns `QuantityMatrix`. - - Array-valued evaluation may return `Array`. + - Implement `signature`. + - Preserve immutability and static PyTree behavior (or use `equinox.Module` if carrying dynamic parameters). + - Remain JAX-transform compatible (`jit`, `vmap`) as pure functions of inputs. - **Behavioral guarantees:** + **Three-layer metric architecture:** - - Position-independent metrics (for example Minkowski in canonical coordinates) may ignore `at` numerically, but must still satisfy the same interface. - - Position-dependent metrics (for example induced or hyperspherical metrics) evaluate `metric_matrix` at the supplied base point. - - Chart compatibility is a manifold-level concern; metrics assume callers provide charts and points compatible with the surrounding manifold contract. + The metric matrix computation follows a three-layer design: - **Subclassing requirements:** + 1. **Metric field** (`AbstractMetricField` subclass): describes the *kind* of geometry (flat, round, Lorentzian, etc.) without specifying a coordinate representation or computing matrix components. + 2. **Dispatch layer** (`metric_matrix` function): given a manifold, a base point, and a chart, computes and returns the metric matrix as a typed result object. + 3. **Metric matrix** (`AbstractMetricMatrix` subclass: `DiagonalMetric` or `DenseMetric`): encodes the computed matrix and its sparsity structure. - - Implement `signature` and `metric_matrix`. - - Preserve immutability and static PyTree behavior. - - Remain JAX-transform compatible (`jit`, `vmap`) as pure functions of inputs. + The metric matrix is obtained via the dispatch function, not via a method on the metric field object: - See the [Metrics](#software-spec-metrics) section for concrete metric families and formulas. + ```python + import coordinax.api.manifolds as cxmapi - **Methods:** + g = cxmapi.metric_matrix(manifold, point, chart) # → DiagonalMetric or DenseMetric + ``` + + See [metric_matrix dispatch function](#software-spec-metric-matrix-dispatch) and [AbstractMetricMatrix](#software-spec-abstractmetricmatrix) for details. - - `scale_factors(chart, /, *, at, usys=None)`: convenience wrapper around [`cxm.scale_factors`](#software-spec-scale-factors). Returns the 1-D `QuantityMatrix` of diagonal metric entries in `chart` at base point `at`. - - `is_diagonal(chart, /, *, at, usys=None) -> Bool[Array, ""]`: returns `True` if all off-diagonal entries of the metric matrix vanish at `at`, checked numerically via `jnp.allclose`. This is a **point-specific, numerical** check. For a **structural, global** guarantee on a metric's diagonal chart domain (typically orthogonal charts), use `isinstance(metric, AbstractDiagonalMetric)`. - - `cholesky(chart, /, *, at, usys=None) -> QuantityMatrix | Array`: returns the lower-triangular Cholesky factor $L$ satisfying $g = L\,L^\top$. The vielbein is $E = L^\top$; see [Diagonal metrics and orthogonal coordinate systems](#diagonal-metrics-and-orthogonal-coordinate-systems) for the relationship to physical-basis components. Returns a `QuantityMatrix` when the metric matrix carries units, otherwise a plain `Array`. Element $L_{ij}$ carries unit $\sqrt{u_{ij}}$ where $u_{ij}$ is the unit of $g_{ij}$. + See the [Metrics](#software-spec-metrics) section for concrete metric families and formulas. -(software-spec-abstractdiagonalmetric)= +(software-spec-abstractdiagonalmetricfield)= -!!! info `AbstractDiagonalMetric` +!!! info `AbstractDiagonalMetricField` - `AbstractDiagonalMetric` is an abstract subclass of `AbstractMetric` for metrics whose matrix is diagonal at every base point in every compatible chart. + `AbstractDiagonalMetricField` is an abstract subclass of `AbstractMetricField` for metrics whose matrix is diagonal at every base point in every compatible chart. A metric is **diagonal** (equivalently, the coordinate chart is an **orthogonal coordinate system**) when all off-diagonal entries of the metric matrix vanish: @@ -2926,7 +2903,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial **Role: structural marker, not behavioral interface.** - `AbstractDiagonalMetric` adds no new abstract methods beyond those of `AbstractMetric`. Its sole purpose is to **declare** that `metric_matrix` will always return a diagonal matrix. This allows: + `AbstractDiagonalMetricField` adds no new abstract methods beyond those of `AbstractMetricField`. Its sole purpose is to **declare** that `metric_matrix` will always return a diagonal matrix. This allows: - Dispatch specialisations that compute `scale_factors` more efficiently (e.g., extracting only the diagonal of $g$, or using squared Jacobian @@ -2936,60 +2913,142 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial **Subclassing contract:** - Subclasses must implement the two abstract members inherited from - `AbstractMetric`: + Subclasses must implement the abstract member inherited from + `AbstractMetricField`: - `signature` (abstract property): a tuple of $\pm 1$ of length `ndim` encoding the metric signature in coordinate order. Positive entries are Riemannian (space-like); a ``-1`` entry is pseudo-Riemannian (time-like). - - `metric_matrix(chart, /, *, at, usys=None)` (abstract method): must - return a diagonal `QuantityMatrix` (or plain `Array`) of shape - `(ndim, ndim)` with all off-diagonal entries exactly zero. - All other behavioral requirements of `AbstractMetric` also apply: - immutability (frozen dataclass), static JAX PyTree registration, and - `jit`/`vmap` compatibility. + All other behavioral requirements of `AbstractMetricField` also apply: + immutability (frozen dataclass or `equinox.Module`), static JAX PyTree + registration, and `jit`/`vmap` compatibility. - **Relationship to `AbstractMetric.is_diagonal`:** + **Relationship to `metric_representation`:** - `AbstractMetric.is_diagonal(chart, at=at)` inspects the metric matrix at - a **specific base point** and returns a `bool`. - `AbstractDiagonalMetric` makes this an unconditional **structural - promise** across all base points: instances are always diagonal, - regardless of which chart or point is supplied. + `metric_representation(manifold, chart)` returns the *type* of metric matrix + (`DiagonalMetric` or `DenseMetric`) that will be produced for a given + manifold–chart pair. When `isinstance(manifold.metric, AbstractDiagonalMetricField)`, + dispatch rules can return a `DiagonalMetric` (storing only the diagonal) + rather than a full `DenseMetric`. This is a **structural, global** + guarantee — instances are always diagonal regardless of the base point. **Concrete subclasses (built-in):** | Class | Manifold | Diagonal in | |-------|----------|-------------| - | [`EuclideanMetric`](#software-spec-euclideanmetric) | $\mathbb{R}^n$ | Cartesian charts; orthogonal curvilinear charts via $g = J^\top J$ | - | [`HyperSphericalMetric`](#software-spec-hypersphericalmetric) | $S^{n-1}$ | Intrinsic hyperspherical chart; cumulative-sine rule $g_{kk} = \prod_{j>> from coordinax._src.manifolds.diagonal import AbstractDiagonalMetric + >>> from coordinax._src.base import AbstractDiagonalMetricField >>> import coordinax.manifolds as cxm - >>> isinstance(cxm.EuclideanMetric(3), AbstractDiagonalMetric) + >>> isinstance(cxm.FlatMetric(3), AbstractDiagonalMetricField) True - >>> isinstance(cxm.MinkowskiMetric(), AbstractDiagonalMetric) + >>> isinstance(cxm.MinkowskiMetric(), AbstractDiagonalMetricField) True >>> import unxt as u >>> isinstance( - ... cxm.InducedMetric( + ... cxm.PullbackMetric( ... cxm.TwoSphereIn3D(radius=u.Q(1.0, "m")), - ... cxm.EuclideanMetric(3), + ... cxm.FlatMetric(3), ... ), - ... AbstractDiagonalMetric, + ... AbstractDiagonalMetricField, ... ) False ``` +(software-spec-metric-matrix-dispatch)= + +!!! info `metric_matrix` and `metric_representation` dispatch functions + + The metric matrix computation is separated from the metric field type. + Two standalone dispatch functions bridge metric fields to concrete matrix results: + + **`metric_matrix(manifold, point, chart) → AbstractMetricMatrix`** + + Evaluates the metric at `point` in `chart` and returns a typed matrix object: + + ``` + cxmapi.metric_matrix(manifold, point, chart) → DiagonalMetric | DenseMetric + ``` + + - `manifold`: an `AbstractManifold` instance (e.g. `cxm.R3`, `cxm.S2`, `EmbeddedManifold`). + - `point`: a coordinate dict in the given chart's coordinate system. + - `chart`: the coordinate chart in which the metric is expressed. + + Returns a `DiagonalMetric` when the manifold's metric is an `AbstractDiagonalMetricField` and the chart is orthogonal; otherwise returns a `DenseMetric`. + + **`metric_representation(manifold, chart) → type[AbstractMetricMatrix]`** + + Returns the *type* of metric matrix that `metric_matrix` will produce for the given manifold–chart pair, without evaluating at a point: + + ``` + cxmapi.metric_representation(manifold, chart) → type[DiagonalMetric] | type[DenseMetric] + ``` + + **Example:** + + ```pycon + >>> import jax.numpy as jnp + >>> import coordinax.api.manifolds as cxmapi + >>> import coordinax.charts as cxc + >>> import coordinax.manifolds as cxm + + >>> at = {"x": jnp.array(0.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + >>> g = cxmapi.metric_matrix(cxm.R3, at, cxc.cart3d) + >>> g.diagonal + Array([1., 1., 1.], dtype=float64) + + >>> cxmapi.metric_representation(cxm.R3, cxc.cart3d) + + ``` + +(software-spec-abstractmetricmatrix)= + +!!! info `AbstractMetricMatrix`, `DiagonalMetric`, `DenseMetric` + + These types encode the result of `metric_matrix` together with its sparsity structure. + + **`AbstractMetricMatrix`** (base class): + + - `ndim: int` — dimension of the metric. + - `to_dense() → DenseMetric` — expand to a full $(n \times n)$ matrix. + + **`DiagonalMetric`** (returned for orthogonal charts): + + - Stores only the 1-D diagonal `(g_{11}, \ldots, g_{nn})` as an `Array` or `QMatrix`. + - `diagonal: Array | QMatrix` — the diagonal entries. + - `inverse → DiagonalMetric` — reciprocal of each diagonal entry. + - `determinant → Array | Quantity` — product of diagonal entries. + - `__matmul__(v) → Array` — element-wise product (O(n), not O(n²)). + + **`DenseMetric`** (returned for non-orthogonal charts and embedded manifolds): + + - Stores the full $(n \times n)$ matrix as an `Array` or `QMatrix`. + - `matrix: Array | QMatrix` — the full matrix. + - `inverse → DenseMetric` — via `jnp.linalg.inv`; preserves unit tracking. + - `determinant → Array | Quantity` — via a custom `det_p` JAX primitive; preserves unit tracking. + - `__matmul__(v) → Array | QMatrix` — full matrix-vector product. + + **Unit tracking:** + + Both types propagate units through `QMatrix` fields. For metrics + induced by Jacobian pullback, units are `cart_unit² / (intrinsic_unit_i × intrinsic_unit_j)`. + + **Basis change integration:** + + `change_basis` uses the metric matrix type to select the efficient path: + - `DiagonalMetric` → scale each component by the scale factor $h_i = \sqrt{g_{ii}}$. + - `DenseMetric` → compute Cholesky vielbein $E = L^\top$ and apply $\hat{v} = E\,v$. + (software-spec-abstractmanifold)= !!! info `AbstractManifold` @@ -3027,7 +3086,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial The manifold provides thin wrappers around coordinate transformations that ensure atlas compatibility before delegating to chart‑level machinery. - ``pt_map(...)`` performs chart transitions while checking that both charts belong to the manifold. - - ``scale_factors(chart, /, *, at, usys=None)``: convenience wrapper that delegates to the manifold metric. Returns the 1-D `QuantityMatrix` of diagonal metric entries in `chart` at base point `at`. See the [`scale_factors` functional API section](#software-spec-scale-factors) for full semantics. + - ``scale_factors(chart, /, *, at, usys=None)``: convenience wrapper that delegates to the manifold metric. Returns the 1-D `QMatrix` of diagonal metric entries in `chart` at base point `at`. See the [`scale_factors` functional API section](#software-spec-scale-factors) for full semantics. Pre-defined manifolds: @@ -3138,11 +3197,11 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial False ``` -(software-spec-euclideanmetric)= +(software-spec-flatmetric)= -!!! info `EuclideanMetric` +!!! info `FlatMetric` - `EuclideanMetric` is the flat Riemannian metric on $\mathbb{R}^n$. + `FlatMetric` is the flat Riemannian metric on $\mathbb{R}^n$. In Cartesian coordinates, the metric matrix is the identity: @@ -3159,16 +3218,16 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial Construction: ```text - EuclideanMetric(ndim: int) + FlatMetric(ndim: int) ``` Semantics: - `signature = (1,) * ndim`. - - `metric_matrix(chart, /, *, at, usys=None)` returns a `QuantityMatrix`. - - For Cartesian charts, `metric_matrix` returns a dimensionless identity matrix of shape `(ndim, ndim)`. - - For compatible non-Cartesian charts, `metric_matrix` is computed as `J^T J`, where `J = jac_pt_map(at, chart, chart.cartesian, usys=usys)`. - - This pullback is diagonal exactly for orthogonal charts. Therefore the `AbstractDiagonalMetric` interpretation of `EuclideanMetric` is scoped to that orthogonal chart domain; atlas compatibility alone does not guarantee diagonality. + - `metric_matrix(manifold, point, chart)` returns a `DiagonalMetric` for orthogonal charts. + - For Cartesian charts, the diagonal is all ones (dimensionless). + - For compatible non-Cartesian charts, the diagonal is computed from `J^T J`, where `J = jac_pt_map(at, chart, chart.cartesian, usys=usys)`. + - This pullback is diagonal exactly for orthogonal charts. Therefore the `AbstractDiagonalMetricField` interpretation of `FlatMetric` is scoped to that orthogonal chart domain; atlas compatibility alone does not guarantee diagonality. - If a chart has no global Cartesian sibling, the current implementation falls back to a dimensionless identity matrix. **Example** @@ -3178,14 +3237,15 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> m = cxm.EuclideanMetric(3) - >>> at = {"x": jnp.array(0.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} - >>> m.metric_matrix(cxc.cart3d, at=at) - QuantityMatrix([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]], '((, , ), (, , ), (, , ))') + >>> import coordinax.api.manifolds as cxmapi + >>> m = cxm.FlatMetric(3) >>> m.signature (1, 1, 1) + >>> m.ndim + 3 + >>> at = {"x": jnp.array(0.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + >>> cxmapi.metric_matrix(cxm.R3, at, cxc.cart3d).diagonal + Array([1., 1., 1.], dtype=float64) ``` (software-spec-euclideanmanifold)= @@ -3204,7 +3264,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial A `EuclideanManifold(n)` provides both Euclidean smooth and metric structure: - `atlas = EuclideanAtlas(n)` - - `metric = EuclideanMetric(n)` + - `metric = FlatMetric(n)` with manifold dimension $ \dim M = n. $ @@ -3228,7 +3288,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.manifolds as cxm >>> import coordinax.charts as cxc - >>> M = cxm.EuclideanManifold(3) + >>> M = cxm.R3 >>> M.ndim 3 >>> M.default_chart @@ -3275,11 +3335,11 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial As with `EuclideanAtlas`, membership is determined by chart registration rather than by hard-coded enumeration using the ``register`` class method. -(software-spec-hypersphericalmetric)= +(software-spec-roundmetric)= -!!! info `HyperSphericalMetric` +!!! info `RoundMetric` - `HyperSphericalMetric` is the round Riemannian metric on the unit sphere in hyperspherical coordinates. + `RoundMetric` is the round Riemannian metric on the unit sphere in hyperspherical coordinates. For $S^2$ with chart $(\theta, \phi)$, the metric is @@ -3301,15 +3361,14 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial Construction: ```text - HyperSphericalMetric(ndim: int) + RoundMetric(ndim: int) ``` Semantics: - `signature = (1,) * ndim`. - - `metric_matrix(chart, /, *, at, usys=None)` returns either a plain array or a `QuantityMatrix`, depending on whether inputs are unitful. + - `metric_matrix(manifold, point, chart)` returns a `DiagonalMetric` in the intrinsic hyperspherical chart basis. - Angular inputs are interpreted in radians by default, or via `usys["angle"]` when a unit system is provided. - - The returned metric is diagonal in the intrinsic hyperspherical chart basis. **Example** @@ -3318,13 +3377,15 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm - >>> m = cxm.HyperSphericalMetric(2) - >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} - >>> m.metric_matrix(cxc.sph2, at=at) - Array([[1., 0.], - [0., 1.]], dtype=float64) + >>> import coordinax.api.manifolds as cxmapi + >>> m = cxm.RoundMetric(2) >>> m.signature (1, 1) + >>> m.ndim + 2 + >>> at = {"theta": jnp.array(jnp.pi / 2), "phi": jnp.array(0.0)} + >>> cxmapi.metric_matrix(cxm.S2, at, cxc.sph2).diagonal + Array([1., 1.], dtype=float64) ``` (software-spec-twospheremanifold)= @@ -3348,7 +3409,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial Structure: - `atlas = HyperSphericalAtlas()` - - `metric = HyperSphericalMetric(ndim)` + - `metric = RoundMetric(ndim)` The intrinsic dimension is $ \dim S^2 = 2$. @@ -3380,7 +3441,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.manifolds as cxm >>> import coordinax.charts as cxc - >>> M = cxm.HyperSphericalManifold() + >>> M = cxm.HyperSphericalManifold(2) >>> M.ndim 2 @@ -3463,9 +3524,8 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial - `signature = (-1, 1, 1, 1)`. - `ndim = 4`. - - `metric_matrix(chart, /, *, at, usys=None)` returns a `QuantityMatrix`. - - In the canonical `MinkowskiCT` chart, `metric_matrix` returns `diag(-1, 1, 1, 1)` directly. - - For other registered charts, `metric_matrix` returns `J^T η J`. + - `metric_matrix(manifold, point, chart)` returns a `DiagonalMetric` in the canonical `MinkowskiCT` chart with diagonal `(-1, 1, 1, 1)`. + - For other registered charts, the full matrix is `J^T η J` (returned as a `DenseMetric`). **Example** @@ -3474,20 +3534,21 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm + >>> import coordinax.api.manifolds as cxmapi >>> m = cxm.MinkowskiMetric() + >>> m.signature + (-1, 1, 1, 1) + >>> m.ndim + 4 + >>> M = cxm.MinkowskiManifold() >>> at = { ... "ct": jnp.array(0.0), ... "x": jnp.array(0.0), ... "y": jnp.array(0.0), ... "z": jnp.array(0.0), ... } - >>> m.metric_matrix(cxc.minkowskict, at=at).value - Array([[-1., 0., 0., 0.], - [ 0., 1., 0., 0.], - [ 0., 0., 1., 0.], - [ 0., 0., 0., 1.]], dtype=float64) - >>> m.signature - (-1, 1, 1, 1) + >>> cxmapi.metric_matrix(M, at, cxc.minkowskict).diagonal + Array([-1., 1., 1., 1.], dtype=float64) ``` (software-spec-minkowskimanifold)= @@ -3597,19 +3658,19 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial ```text CustomMetric( - metric_matrix: Callable[[AbstractChart], QuantityMatrix | Array], + metric_matrix: Callable[[AbstractChart], QMatrix | Array], signature: tuple[int, ...], ) ``` Semantics: - - `metric_matrix(chart, /, *, at, usys=None)` is supplied by the caller and must satisfy the [`AbstractMetric`](#software-spec-abstractmetric) contract. + - The `metric_matrix` callable is supplied by the caller and must satisfy the [`AbstractMetricField`](#software-spec-abstractmetricfield) contract. - `signature` is the metric signature as a tuple of `+1` and `-1` entries. - `ndim = len(signature)`. - `CustomMetric` is immutable and registered as a static JAX PyTree, matching the behavior required of all concrete metric types. - This type exists so users can define metrics for custom manifolds without subclassing `AbstractMetric`. + This type exists so users can define metrics for custom manifolds without subclassing `AbstractMetricField`. **Example** @@ -3642,7 +3703,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial Construction: ```text - CustomManifold(atlas: CustomAtlas, metric: AbstractMetric) + CustomManifold(atlas: CustomAtlas, metric: AbstractMetricField) ``` The manifold is intentionally thin: it forwards chart-membership checks, default-chart selection, and point transition wrappers to the provided atlas, while storing an explicit metric object for geometric computations. @@ -3663,7 +3724,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial ... charts=(cxc.Cart2D, cxc.Polar2D), ... chart_default=cxc.cart2d, ... ) - >>> M = cxm.CustomManifold(A, cxm.EuclideanMetric(2)) + >>> M = cxm.CustomManifold(A, cxm.FlatMetric(2)) >>> M.ndim 2 @@ -3753,11 +3814,11 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial False ``` -(software-spec-cartesianproductmetric)= +(software-spec-productmetric)= -!!! info `CartesianProductMetric` +!!! info `ProductMetric` - `CartesianProductMetric` is the canonical metric on a Cartesian product manifold. + `ProductMetric` is the canonical metric on a Cartesian product manifold. For factor manifolds $(M_i, g_i)$, the product manifold $$ @@ -3772,14 +3833,14 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial Construction: ```text - CartesianProductMetric(factors: tuple[AbstractMetric, ...]) + ProductMetric(factors: tuple[AbstractMetricField, ...]) ``` Semantics: - `signature` is the concatenation of factor signatures in product order. - `ndim = sum(metric.ndim for metric in factors)`. - - `metric_matrix(chart, /, *, at, usys=None)` requires a product chart and returns a block-diagonal matrix with one block per factor metric. + - `metric_matrix(manifold, point, chart)` requires a product chart and returns a block-diagonal `DenseMetric` with one block per factor metric. - Each block is the factor metric matrix evaluated at the corresponding factor point extracted from `at` using product-chart factor splitting. **Example** @@ -3790,10 +3851,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.manifolds as cxm >>> import unxt as u - >>> M = cxm.CartesianProductManifold( - ... factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(1)), - ... factor_names=("S2", "R1"), - ... ) + >>> M = cxm.CartesianProductManifold(factors=(cxm.S2, cxm.R1), factor_names=("S2", "R1")) >>> metric = M.metric >>> metric.signature (1, 1, 1) @@ -3806,8 +3864,9 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial ... "S2.phi": u.Angle(0.0, "rad"), ... "R1.x": u.Q(1.0, "m"), ... } - >>> g = metric.metric_matrix(chart, at=at) - >>> g.shape + >>> import coordinax.api.manifolds as cxmapi + >>> g = cxmapi.metric_matrix(M, at, chart) + >>> g.to_dense().matrix.value.shape (3, 3) ``` @@ -3836,7 +3895,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial - `ndim = sum(factor.ndim for factor in factors)` - `atlas = CartesianProductAtlas(...)` formed from factor atlases. - - `metric = CartesianProductMetric(...)` formed from factor metrics. + - `metric = ProductMetric(...)` formed from factor metrics. - `default_chart = atlas.default_chart()` - Factor names must be unique and are used as keys when accessing the product atlas. @@ -3858,10 +3917,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.manifolds as cxm >>> import coordinax.charts as cxc - >>> M = cxm.CartesianProductManifold( - ... factors=(cxm.HyperSphericalManifold(), cxm.EuclideanManifold(1)), - ... factor_names=("S2", "R1"), - ... ) + >>> M = cxm.CartesianProductManifold(factors=(cxm.S2, cxm.R1), factor_names=("S2", "R1")) >>> M.ndim 3 @@ -3978,13 +4034,13 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial CustomEmbeddingMap(intrinsic=Cart1D(), ambient=Cart2D(), ...) ``` -(software-spec-inducedmetric)= +(software-spec-pullbackmetric)= -!!! info `InducedMetric` +!!! info `PullbackMetric` - `InducedMetric` is the pullback metric on an embedded manifold. + `PullbackMetric` is the pullback metric on an embedded manifold. - Given an embedding $\iota : N \hookrightarrow M$, `InducedMetric` constructs the intrinsic metric on $N$ from the ambient metric on $M$ by pullback: + Given an embedding $\iota : N \hookrightarrow M$, `PullbackMetric` constructs the intrinsic metric on $N$ from the ambient metric on $M$ by pullback: $$ g_N = \iota^* g_M, @@ -4001,9 +4057,9 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial Construction: ```text - InducedMetric( + PullbackMetric( embed_map: AbstractEmbeddingMap, - ambient_metric: AbstractMetric, + ambient_metric: AbstractMetricField, ) ``` @@ -4011,7 +4067,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial - `embed_map` defines the intrinsic and ambient charts used to compute the pullback. - `ambient_metric` is evaluated at the embedded point `embed_map.embed(at, usys=usys)`. - - `metric_matrix(chart, /, *, at, usys=None)` computes the embedding Jacobian and returns `J^T G J` as a `QuantityMatrix`. + - `metric_matrix(manifold, point, chart)` (called on the `EmbeddedManifold`) computes the embedding Jacobian and returns `J^T G J` as a `DenseMetric`. - The current implementation ignores the `chart` argument numerically and computes the induced metric in the embedding map's intrinsic chart. - `signature = (1,) * embed_map.intrinsic.ndim`. - `ndim = embed_map.intrinsic.ndim`. @@ -4026,14 +4082,18 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> import coordinax.charts as cxc >>> import coordinax.manifolds as cxm + >>> import coordinax.api.manifolds as cxmapi >>> embed_map = cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")) - >>> metric = cxm.InducedMetric(embed_map, cxm.EuclideanMetric(3)) + >>> M_emb = cxm.EmbeddedManifold( + ... intrinsic=cxm.S2, + ... ambient=cxm.R3, + ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(1.0, "km")), + ... ) >>> at = {"theta": u.Q(jnp.pi / 2, "rad"), "phi": u.Q(0.0, "rad")} - >>> metric.metric_matrix(cxc.sph2, at=at) - QuantityMatrix([[1., 0.], - [0., 1.]], '((km2 / rad2, km2 / rad2), (km2 / rad2, km2 / rad2))') - >>> metric.signature - (1, 1) + >>> g = cxmapi.metric_matrix(M_emb, at, cxc.sph2) + >>> g.matrix.value + Array([[1., 0.], + [0., 1.]], dtype=float64, weak_type=True) ``` (software-spec-embeddedmanifold)= @@ -4065,7 +4125,7 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial - `embed_map: AbstractEmbeddingMap` — the smooth embedding defining coordinates transformation - `atlas = intrinsic.atlas` — uses the intrinsic manifold's atlas - `ndim = intrinsic.ndim` — the dimension is that of the intrinsic manifold - - `metric = InducedMetric(embed_map, ambient.metric)` — the metric is derived from the embedding and the ambient manifold metric, not passed separately at construction time + - `metric = PullbackMetric(embed_map, ambient.metric)` — the metric is derived from the embedding and the ambient manifold metric, not passed separately at construction time Embedding API: @@ -4108,8 +4168,8 @@ $$g_{ij}(q) = g_p\!\left(\frac{\partial}{\partial q^i}, \frac{\partial}{\partial >>> # Create a 2-sphere embedded in R^3 with radius 2 km >>> manifold = cxm.EmbeddedManifold( - ... intrinsic=cxm.HyperSphericalManifold(), - ... ambient=cxm.EuclideanManifold(3), + ... intrinsic=cxm.S2, + ... ambient=cxm.R3, ... embed_map=cxm.TwoSphereIn3D(radius=u.Q(2.0, "km")), ... ) >>> manifold.metric.signature diff --git a/docs/tutorials/point_objects.md b/docs/tutorials/point_objects.md index 8a51caf0..566e7e0f 100644 --- a/docs/tutorials/point_objects.md +++ b/docs/tutorials/point_objects.md @@ -157,22 +157,23 @@ Q(1, 'm') ## Reading Metric Diagonals -Because a vector carries both its chart and its manifold, you can ask the manifold for the metric diagonal entries at the represented location: +Because a vector carries both its chart and its manifold, you can ask for the metric diagonal entries at the represented location: ```{code-block} python +>>> import coordinax.manifolds as cxm >>> v = cx.Point.from_( ... {"r": u.Q(2, "km"), "theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0, "rad")}, ... cxc.sph3d, ... ) ->>> gdiag = v.M.scale_factors(v.chart, at=v.data) +>>> gdiag = cxm.scale_factors(v.chart, at=v.data) >>> gdiag.shape (3,) >>> gdiag.unit.to_string() '(, km2 / rad2, km2 / rad2)' ``` -`scale_factors` returns the diagonal metric entries $g_{ii}$ as a 1-D `QuantityMatrix`, so each direction can keep its own unit. +`scale_factors` returns the diagonal metric entries $g_{ii}$ as a 1-D `QMatrix`, so each direction can keep its own unit. ## Changing The Chart (Coordinate Conversion) From 24c8623029c978452ef74696a2dfd5fdf39ebf11 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:36:29 -0400 Subject: [PATCH 11/15] =?UTF-8?q?=F0=9F=94=A7=20config:=20add=20import=20a?= =?UTF-8?q?bbreviations=20and=20ARG002=20test=20lint=20ignore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add jax.extend.core and jax.tree_util import abbreviations to pyproject.toml. Add ARG002 (unused method argument) to per-file test ignores to allow pytest fixture parameters that are injected but not referenced in the body. --- pyproject.toml | 3 +++ tests/unit/internal/test_quantity_matrix.py | 7 ++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c2bed5a9..8ec35d72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -376,6 +376,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "**/tests/**" = [ "ANN", + "ARG002", # Unused method argument "ARG005", # Unused lambda argument "D102", "D401", # First line of docstring should be in imperative mood @@ -418,6 +419,8 @@ ignore = [ "coordinax.vectors" = "cxv" equinox = "eqx" functools = "ft" +"jax.extend.core" = "jexc" +"jax.tree_util" = "jtu" unxt = "u" diff --git a/tests/unit/internal/test_quantity_matrix.py b/tests/unit/internal/test_quantity_matrix.py index 3a9c93c5..43f7e8ed 100644 --- a/tests/unit/internal/test_quantity_matrix.py +++ b/tests/unit/internal/test_quantity_matrix.py @@ -51,7 +51,7 @@ def unit_2x2(): @pytest.fixture def qm_2x2(unit_2x2): """Return a 2x2 QMatrix with values 1-4.""" - return QMat(value=jnp.array([[1, 2], [3, 4]]), unit=unit_2x2) + return QMat(value=jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=unit_2x2) @pytest.fixture @@ -456,10 +456,7 @@ def test_same_units(self, qm_2x2, unit_2x2): def test_result_keeps_lhs_units(self, qm_2x2, unit_2x2, unit_2x2_alt): """Result units come from the LHS.""" - other = QMat( - value=jnp.array([[1.0, 1000.0], [3000.0, 180.0]]), - unit=unit_2x2_alt, - ) + other = QMat(value=jnp.array([[1, 1000], [3000, 180]]), unit=unit_2x2_alt) result = _add(qm_2x2, other) assert result.unit == unit_2x2 From d8c2d48cf9e636d08a1f000f969ca93140305493 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 18:48:43 -0400 Subject: [PATCH 12/15] =?UTF-8?q?=F0=9F=94=A7=20config:=20add=20ty=20type-?= =?UTF-8?q?ignore=20comments=20to=20silence=20known=20false=20positives?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cast u.ustrip() results to Array/jax.Array so .ndim/.shape resolve; cast q.unit to u.AbstractUnit; fix ty: ignore placement for J.T attribute access; add ty: ignore[possibly-missing-submodule] on jax.core references. --- src/coordinax/_src/euclidean/register_metric.py | 10 +++++----- src/coordinax/_src/internal/quantity_matrix.py | 2 +- src/coordinax/_src/internal/quantity_matrix/_inv.py | 2 +- .../internal/quantity_matrix/_register_primitives.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/coordinax/_src/euclidean/register_metric.py b/src/coordinax/_src/euclidean/register_metric.py index 012f38c0..bbc84837 100644 --- a/src/coordinax/_src/euclidean/register_metric.py +++ b/src/coordinax/_src/euclidean/register_metric.py @@ -15,7 +15,7 @@ __all__: tuple[str, ...] = () -from typing import Any +from typing import Any, cast import jax.numpy as jnp import plum @@ -61,8 +61,8 @@ def _angle_rad(q: Any, /) -> Any: def _angle_unit(q: Any, /) -> u.AbstractUnit: """Return the unit of an angular coordinate, or dimensionless if plain array.""" if isinstance(q, u.AbstractQuantity): - return q.unit - return u.unit("") + return cast("u.AbstractUnit", q.unit) + return u.unit("") # ty: ignore[invalid-return-type] # ===================================================================== @@ -484,5 +484,5 @@ def metric_matrix( unit_tup = tuple(tuple(u.unit("") for _ in range(n)) for _ in range(n)) return DenseMetric(QMatrix(jnp.eye(n), unit=UnitsMatrix(unit_tup))) J = cxcapi.jac_pt_map(point, chart, cart_chart, usys=None) - JT = J.T - return DenseMetric(JT @ J) # ty: ignore[unsupported-operator] + JT = J.T # ty: ignore[unresolved-attribute] + return DenseMetric(JT @ J) diff --git a/src/coordinax/_src/internal/quantity_matrix.py b/src/coordinax/_src/internal/quantity_matrix.py index 4aed3ca1..5065aa48 100644 --- a/src/coordinax/_src/internal/quantity_matrix.py +++ b/src/coordinax/_src/internal/quantity_matrix.py @@ -1099,7 +1099,7 @@ def dot_general_qm_qty( """ rhs_unit = u.unit_of(rhs) - rhs_val = u.ustrip(AllowValue, rhs_unit, rhs) + rhs_val = cast("Array", u.ustrip(AllowValue, rhs_unit, rhs)) if rhs_val.ndim == 1: n = rhs_val.shape[0] rhs_qm = QMatrix(rhs_val, unit=UnitsMatrix(tuple(rhs_unit for _ in range(n)))) diff --git a/src/coordinax/_src/internal/quantity_matrix/_inv.py b/src/coordinax/_src/internal/quantity_matrix/_inv.py index 0c6d37d0..0d73bdfe 100644 --- a/src/coordinax/_src/internal/quantity_matrix/_inv.py +++ b/src/coordinax/_src/internal/quantity_matrix/_inv.py @@ -100,7 +100,7 @@ def _inv_impl(x: Array, /) -> Array: # ── 2. Abstract evaluation rule ─────────────────────────────────────────── -def _inv_abstract_eval(x: "jax.core.ShapedArray", /) -> "jax.core.ShapedArray": +def _inv_abstract_eval(x: "jax.core.ShapedArray", /) -> "jax.core.ShapedArray": # ty: ignore[possibly-missing-submodule] if x.ndim < 2: raise ValueError(f"inv_p requires at least 2-D input, got ndim={x.ndim}") if x.shape[-1] != x.shape[-2]: diff --git a/src/coordinax/_src/internal/quantity_matrix/_register_primitives.py b/src/coordinax/_src/internal/quantity_matrix/_register_primitives.py index a296445a..db946cdd 100644 --- a/src/coordinax/_src/internal/quantity_matrix/_register_primitives.py +++ b/src/coordinax/_src/internal/quantity_matrix/_register_primitives.py @@ -9,7 +9,7 @@ - ``lax.reduce_sum_p`` — summation reduction """ -from typing import Any +from typing import Any, cast import jax import jax.numpy as jnp @@ -495,7 +495,7 @@ def dot_general_qm_qty( """ rhs_unit = u.unit_of(rhs) - rhs_val = u.ustrip(AllowValue, rhs_unit, rhs) + rhs_val = cast("jax.Array", u.ustrip(AllowValue, rhs_unit, rhs)) if rhs_val.ndim == 1: n = rhs_val.shape[0] rhs_qm = QMatrix(rhs_val, unit=UnitsMatrix(tuple(rhs_unit for _ in range(n)))) @@ -732,7 +732,7 @@ def gather_qm( # Number of output elements — start_indices.shape is always concrete in JAX. out_size = start_indices.shape[0] - if isinstance(start_indices, jax.core.Tracer): + if isinstance(start_indices, jax.core.Tracer): # ty: ignore[possibly-missing-submodule] # JIT path: indices are traced — fall back to uniform-unit check. out_unit = _jit_fallback_uniform_unit(x.unit, out_size) else: From 8121ac70ee8daa5b410998308e7e8b9b6262ef23 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 19:22:45 -0400 Subject: [PATCH 13/15] =?UTF-8?q?=E2=9C=85=20test:=20simplify=20numeric=20?= =?UTF-8?q?literals=20in=20test=20assertions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace float literals (1.0, 2.0, etc.) with integer literals where values are whole numbers, and normalize rtol=0.0 to rtol=0 in assert_allclose calls. --- .../tests/unit/test_frame_transforms.py | 6 +- tests/unit/internal/test_quantity_matrix.py | 316 +++++++++--------- 2 files changed, 155 insertions(+), 167 deletions(-) diff --git a/packages/coordinax.astro/tests/unit/test_frame_transforms.py b/packages/coordinax.astro/tests/unit/test_frame_transforms.py index b5db6ca2..67348a6e 100644 --- a/packages/coordinax.astro/tests/unit/test_frame_transforms.py +++ b/packages/coordinax.astro/tests/unit/test_frame_transforms.py @@ -100,7 +100,7 @@ def test_icrs_to_galactocentric_matches_astropy_positions(xyz_pc) -> None: got = cxfm.act(op, None, u.Q(jnp.asarray(xyz_pc), "pc")).ustrip("pc") expected = _astropy_icrs_to_gcf_xyz_pc(xyz_pc, gcf) - np.testing.assert_allclose(got, expected, rtol=0.0, atol=1e-6) + np.testing.assert_allclose(got, expected, rtol=0, atol=1e-6) @pytest.mark.parametrize( @@ -114,7 +114,7 @@ def test_galactocentric_to_icrs_matches_astropy_positions(xyz_pc) -> None: got = cxfm.act(op, None, u.Q(jnp.asarray(xyz_pc), "pc")).ustrip("pc") expected = _astropy_gcf_to_icrs_xyz_pc(xyz_pc, gcf) - np.testing.assert_allclose(got, expected, rtol=0.0, atol=1e-6) + np.testing.assert_allclose(got, expected, rtol=0, atol=1e-6) def test_icrs_galactocentric_transitions_are_inverse_for_positions() -> None: @@ -176,7 +176,7 @@ def test_gcf_icrs_gcf_roundtrip(self, q: u.AbstractQuantity) -> None: bwd = cxf.frame_transition(icrs, gcf) back = cxfm.act(bwd, None, cxfm.act(fwd, None, q)) - np.testing.assert_allclose(back.ustrip("pc"), q.ustrip("pc"), rtol=s, atol=1e-6) + np.testing.assert_allclose(back.ustrip("pc"), q.ustrip("pc"), rtol=0, atol=1e-6) @given( q=ust.quantities( diff --git a/tests/unit/internal/test_quantity_matrix.py b/tests/unit/internal/test_quantity_matrix.py index 43f7e8ed..1c9f0584 100644 --- a/tests/unit/internal/test_quantity_matrix.py +++ b/tests/unit/internal/test_quantity_matrix.py @@ -407,7 +407,7 @@ class TestConvertValuePoint: def test_noop_same_units(self, unit_1d): """If from_units == to_units no conversion happens.""" - val = jnp.array([7.0, 8.0, 9.0]) + val = jnp.array([7, 8, 9]) out = _convert_value_vector(val, unit_1d, unit_1d) assert jnp.array_equal(out, val) @@ -462,15 +462,12 @@ def test_result_keeps_lhs_units(self, qm_2x2, unit_2x2, unit_2x2_alt): def test_mixed_unit_values(self, qm_2x2, unit_2x2_alt): """Values are correctly converted before addition.""" - other = QMat( - value=jnp.array([[1.0, 1000.0], [3000.0, 180.0]]), - unit=unit_2x2_alt, - ) + other = QMat(value=jnp.array([[1, 1000], [3000, 180]]), unit=unit_2x2_alt) res_val = _add(qm_2x2, other).value - assert jnp.isclose(res_val[0, 0], 1001.0) # 1 + 1000 m = 1001 - assert jnp.isclose(res_val[0, 1], 3.0) # 2 + 1.0 s = 3 - assert jnp.isclose(res_val[1, 0], 6.0) # 3 + 3.0 kg = 6 - assert jnp.isclose(res_val[1, 1], 4.0 + math.pi, atol=1e-4) # 4+pi rad≈7.14159 + assert jnp.isclose(res_val[0, 0], 1001) # 1 + 1000 m = 1001 + assert jnp.isclose(res_val[0, 1], 3) # 2 + 1.0 s = 3 + assert jnp.isclose(res_val[1, 0], 6) # 3 + 3.0 kg = 6 + assert jnp.isclose(res_val[1, 1], 4 + math.pi, atol=1e-4) # 4+pi rad≈7.14159 def test_add_zeros(self, qm_2x2, unit_2x2): """Adding zeros gives original values.""" @@ -480,8 +477,8 @@ def test_add_zeros(self, qm_2x2, unit_2x2): def test_commutativity_same_units(self, unit_2x2): """A + b == b + a when units are the same.""" - a = QMat(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=unit_2x2) - b = QMat(jnp.array([[5.0, 6.0], [7.0, 8.0]]), unit=unit_2x2) + a = QMat(jnp.array([[1, 2], [3, 4]]), unit=unit_2x2) + b = QMat(jnp.array([[5, 6], [7, 8]]), unit=unit_2x2) r1 = _add(a, b) r2 = _add(b, a) assert jnp.allclose(r1.value, r2.value) @@ -496,16 +493,16 @@ def test_batch_addition(self, unit_2x2): def test_1x1(self): """1x1 addition.""" - a = QMat(jnp.array([[3.0]]), unit=((_m,),)) - b = QMat(jnp.array([[7.0]]), unit=((_m,),)) + a = QMat(jnp.array([[3]]), unit=((_m,),)) + b = QMat(jnp.array([[7]]), unit=((_m,),)) result = _add(a, b) - assert jnp.isclose(result.value[0, 0], 10.0) + assert jnp.isclose(result.value[0, 0], 10) def test_1d_addition_same_units(self, qm_1d, unit_1d): """1D vector addition, same units.""" - other = QMat(jnp.array([10.0, 20.0, 30.0]), unit=unit_1d) + other = QMat(jnp.array([10, 20, 30]), unit=unit_1d) result = _add(qm_1d, other) - expected = jnp.array([11.0, 22.0, 33.0]) + expected = jnp.array([11, 22, 33]) assert jnp.allclose(result.value, expected) assert result.unit == unit_1d @@ -513,9 +510,9 @@ def test_1d_addition_mixed_units(self, qm_1d, unit_1d_alt): """1D vector addition with unit conversion.""" other = QMat(jnp.array([1.0, 1000.0, 3000.0]), unit=unit_1d_alt) result = _add(qm_1d, other) - assert jnp.isclose(result.value[0], 1001.0) # 1 + 1000 m - assert jnp.isclose(result.value[1], 3.0) # 2 + 1.0 s - assert jnp.isclose(result.value[2], 6.0) # 3 + 3.0 kg + assert jnp.isclose(result.value[0], 1001) # 1 + 1000 m + assert jnp.isclose(result.value[1], 3) # 2 + 1.0 s + assert jnp.isclose(result.value[2], 6) # 3 + 3.0 kg assert result.unit == qm_1d.unit def test_1d_batch_addition(self, unit_1d): @@ -536,8 +533,8 @@ def test_direct_operator_mixed_units(self): assert result.unit.to_string() == "((km, deg), (km, deg))" assert jnp.isclose(result.value[0, 0], 1.001) assert jnp.isclose(result.value[1, 0], 3.003) - assert jnp.isclose(result.value[0, 1], 2.0 + 2.0 * (180.0 / jnp.pi), atol=1e-8) - assert jnp.isclose(result.value[1, 1], 4.0 + (720.0 / jnp.pi), atol=1e-8) + assert jnp.isclose(result.value[0, 1], 2 + 2 * (180 / jnp.pi), atol=1e-8) + assert jnp.isclose(result.value[1, 1], 4 + 720 / jnp.pi, atol=1e-8) # --------------------------------------------------------------------------- @@ -555,35 +552,26 @@ class TestSubtraction: def test_same_units(self, qm_2x2, unit_2x2): """Simple sub, same units.""" - other = QMat( - value=jnp.array([[10.0, 20.0], [30.0, 40.0]]), - unit=unit_2x2, - ) + other = QMat(value=jnp.array([[10, 20], [30, 40]]), unit=unit_2x2) result = _sub(other, qm_2x2) - expected = jnp.array([[9.0, 18.0], [27.0, 36.0]]) + expected = jnp.array([[9, 18], [27, 36]]) assert jnp.allclose(result.value, expected) assert result.unit == unit_2x2 def test_result_keeps_lhs_units(self, qm_2x2, unit_2x2, unit_2x2_alt): """Result units come from the LHS.""" - other = QMat( - value=jnp.array([[1.0, 1000.0], [3000.0, 180.0]]), - unit=unit_2x2_alt, - ) + other = QMat(value=jnp.array([[1, 1000], [3000, 180]]), unit=unit_2x2_alt) result = _sub(qm_2x2, other) assert result.unit == unit_2x2 def test_mixed_unit_values(self, qm_2x2, unit_2x2_alt): """Values are correctly converted before subtraction.""" - other = QMat( - value=jnp.array([[1.0, 1000.0], [3000.0, 180.0]]), - unit=unit_2x2_alt, - ) + other = QMat(value=jnp.array([[1, 1000], [3000, 180]]), unit=unit_2x2_alt) res_val = _sub(qm_2x2, other).value - assert jnp.isclose(res_val[0, 0], -999.0) # 1 - 1000 m - assert jnp.isclose(res_val[0, 1], 1.0) # 2 - 1.0 s - assert jnp.isclose(res_val[1, 0], 0.0) # 3 - 3.0 kg - assert jnp.isclose(res_val[1, 1], 4.0 - math.pi, atol=1e-4) # 4-pi rad + assert jnp.isclose(res_val[0, 0], -999) # 1 - 1000 m + assert jnp.isclose(res_val[0, 1], 1) # 2 - 1.0 s + assert jnp.isclose(res_val[1, 0], 0) # 3 - 3.0 kg + assert jnp.isclose(res_val[1, 1], 4 - math.pi, atol=1e-4) # 4-pi rad def test_sub_zeros(self, qm_2x2, unit_2x2): """Subtracting zeros gives original values.""" @@ -599,8 +587,8 @@ def test_self_subtraction(self, qm_2x2, unit_2x2): def test_anticommutativity_same_units(self, unit_2x2): """A - b == -(b - a) when units are the same.""" - a = QMat(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=unit_2x2) - b = QMat(jnp.array([[5.0, 6.0], [7.0, 8.0]]), unit=unit_2x2) + a = QMat(jnp.array([[1, 2], [3, 4]]), unit=unit_2x2) + b = QMat(jnp.array([[5, 6], [7, 8]]), unit=unit_2x2) r1 = _sub(a, b) r2 = _sub(b, a) assert jnp.allclose(r1.value, -r2.value) @@ -615,16 +603,16 @@ def test_batch_subtraction(self, unit_2x2): def test_1x1(self): """1x1 subtraction.""" - a = QMat(jnp.array([[7.0]]), unit=((_m,),)) - b = QMat(jnp.array([[3.0]]), unit=((_m,),)) + a = QMat(jnp.array([[7]]), unit=((_m,),)) + b = QMat(jnp.array([[3]]), unit=((_m,),)) result = _sub(a, b) - assert jnp.isclose(result.value[0, 0], 4.0) + assert jnp.isclose(result.value[0, 0], 4) def test_1d_subtraction_same_units(self, qm_1d, unit_1d): """1D vector subtraction, same units.""" - other = QMat(jnp.array([10.0, 20.0, 30.0]), unit=unit_1d) + other = QMat(jnp.array([10, 20, 30]), unit=unit_1d) result = _sub(other, qm_1d) - expected = jnp.array([9.0, 18.0, 27.0]) + expected = jnp.array([9, 18, 27]) assert jnp.allclose(result.value, expected) assert result.unit == unit_1d @@ -632,9 +620,9 @@ def test_1d_subtraction_mixed_units(self, qm_1d, unit_1d_alt): """1D vector subtraction with unit conversion.""" other = QMat(jnp.array([1.0, 1000.0, 3000.0]), unit=unit_1d_alt) result = _sub(qm_1d, other) - assert jnp.isclose(result.value[0], -999.0) # 1 - 1000 m - assert jnp.isclose(result.value[1], 1.0) # 2 - 1.0 s - assert jnp.isclose(result.value[2], 0.0) # 3 - 3.0 kg + assert jnp.isclose(result.value[0], -999) # 1 - 1000 m + assert jnp.isclose(result.value[1], 1) # 2 - 1.0 s + assert jnp.isclose(result.value[2], 0) # 3 - 3.0 kg assert result.unit == qm_1d.unit @@ -653,13 +641,13 @@ class TestDotProduct: def test_simple_matmul_uniform_units(self): """2x2 @ 2x1 with uniform units along contraction axis.""" - a = QMat(jnp.array([[2.0, 3.0], [4.0, 5.0]]), unit=((_m, _m), (_kg, _kg))) - b = QMat(jnp.array([[10.0], [20.0]]), unit=((_s,), (_s,))) + a = QMat(jnp.array([[2, 3], [4, 5]]), unit=((_m, _m), (_kg, _kg))) + b = QMat(jnp.array([[10], [20]]), unit=((_s,), (_s,))) result = _matmul(a, b) # C[0,0] = 2*10 + 3*20 = 80 in m*s # C[1,0] = 4*10 + 5*20 = 140 in kg*s - assert jnp.isclose(result.value[0, 0], 80.0) - assert jnp.isclose(result.value[1, 0], 140.0) + assert jnp.isclose(result.value[0, 0], 80) + assert jnp.isclose(result.value[1, 0], 140) assert result.unit == ( (_m * _s,), (_kg * _s,), @@ -673,36 +661,36 @@ def test_matmul_with_unit_conversion(self): # B is 2x1 with units [[s], [s]] # C[0,0] = m*s + km*s -> converted to m*s (ref is j=0) # = 2*10 + (3 km = 3000 m)*20 s = 20 + 60000 = 60020 - a = QMat(jnp.array([[2.0, 3.0], [4.0, 5.0]]), unit=((_m, _km), (_kg, _kg))) - b = QMat(jnp.array([[10.0], [20.0]]), unit=((_s,), (_s,))) + a = QMat(jnp.array([[2, 3], [4, 5]]), unit=((_m, _km), (_kg, _kg))) + b = QMat(jnp.array([[10], [20]]), unit=((_s,), (_s,))) result = _matmul(a, b) - assert jnp.isclose(result.value[0, 0], 60020.0) + assert jnp.isclose(result.value[0, 0], 60020) # C[1,0] = 4*10 + 5*20 = 140 (uniform kg*s, no conversion) - assert jnp.isclose(result.value[1, 0], 140.0) + assert jnp.isclose(result.value[1, 0], 140) def test_matmul_2x2_by_2x2(self): """Square 2x2 @ 2x2 matmul.""" - a = QMat(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=((_m, _m), (_m, _m))) - b = QMat(jnp.array([[5.0, 6.0], [7.0, 8.0]]), unit=((_s, _s), (_s, _s))) + a = QMat(jnp.array([[1, 2], [3, 4]]), unit=((_m, _m), (_m, _m))) + b = QMat(jnp.array([[5, 6], [7, 8]]), unit=((_s, _s), (_s, _s))) result = _matmul(a, b) # Standard matmul: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]] # = [[19, 22], [43, 50]] - expected = jnp.array([[19.0, 22.0], [43.0, 50.0]]) + expected = jnp.array([[19, 22], [43, 50]]) assert jnp.allclose(result.value, expected) ms = _m * _s assert result.unit == ((ms, ms), (ms, ms)) def test_matmul_identity(self): """Multiply by identity matrix.""" - a = QMat(jnp.array([[3.0, 7.0], [11.0, 13.0]]), unit=((_m, _m), (_m, _m))) + a = QMat(jnp.array([[3, 7], [11, 13]]), unit=((_m, _m), (_m, _m))) identity = QMat(jnp.eye(2), unit=((_dimless, _dimless), (_dimless, _dimless))) result = _matmul(a, identity) assert jnp.allclose(result.value, a.value) def test_matmul_output_units(self): """Output unit[i][k] = lhs.unit[i][0] * rhs.unit[0][k].""" - a = QMat(jnp.array([[1.0, 1.0]]), unit=((_m, _m),)) - b = QMat(jnp.array([[2.0, 3.0], [4.0, 5.0]]), unit=((_s, _kg), (_s, _kg))) + a = QMat(jnp.array([[1, 1]]), unit=((_m, _m),)) + b = QMat(jnp.array([[2, 3], [4, 5]]), unit=((_s, _kg), (_s, _kg))) result = _matmul(a, b) # Output shape: 1x2 assert result.shape[-2] == 1 @@ -714,10 +702,10 @@ def test_matmul_output_units(self): def test_matmul_1x1(self): """1x1 @ 1x1 is scalar product.""" - a = QMat(jnp.array([[3.0]]), unit=((_m,),)) - b = QMat(jnp.array([[7.0]]), unit=((_s,),)) + a = QMat(jnp.array([[3]]), unit=((_m,),)) + b = QMat(jnp.array([[7]]), unit=((_s,),)) result = _matmul(a, b) - assert jnp.isclose(result.value[0, 0], 21.0) + assert jnp.isclose(result.value[0, 0], 21) assert result.unit == ((_m * _s,),) def test_matmul_rhs_unit_conversion(self): @@ -726,39 +714,39 @@ def test_matmul_rhs_unit_conversion(self): # B: 2x1, units [[s], [min]] # C[0,0] = m*s + m*min -> ref = m*s # = 1*1 + 1*1 min -> 1*1 + 1*60 s = 1 + 60 = 61 in m*s - a = QMat(jnp.array([[1.0, 1.0]]), ((_m, _m),)) - b = QMat(jnp.array([[1.0], [1.0]]), ((_s,), (_min,))) + a = QMat(jnp.array([[1, 1]]), ((_m, _m),)) + b = QMat(jnp.array([[1], [1]]), ((_s,), (_min,))) result = _matmul(a, b) - assert jnp.isclose(result.value[0, 0], 61.0) + assert jnp.isclose(result.value[0, 0], 61) assert result.unit == ((_m * _s,),) def test_1d_dot_product_uniform_units(self): """1D @ 1D vector dot product with uniform units.""" - a = QMat(jnp.array([2.0, 3.0]), unit=(_m, _m)) - b = QMat(jnp.array([4.0, 5.0]), unit=(_s, _s)) + a = QMat(jnp.array([2, 3]), unit=(_m, _m)) + b = QMat(jnp.array([4, 5]), unit=(_s, _s)) result = _matmul(a, b) # Result should be a scalar Quantity, not a QMatrix # 2*4 + 3*5 = 8 + 15 = 23 in m*s assert isinstance(result, u.Q) - assert jnp.isclose(result.value, 23.0) + assert jnp.isclose(result.value, 23) assert result.unit == _m * _s def test_1d_dot_product_mixed_units(self): """1D @ 1D with mixed units requiring conversion.""" # a: [1 m, 1 km], b: [1 s, 1 s] # Result = 1*1 + 1000*1 = 1 + 1000 = 1001 in m*s - a = QMat(jnp.array([1.0, 1.0]), unit=(_m, _km)) - b = QMat(jnp.array([1.0, 1.0]), unit=(_s, _s)) + a = QMat(jnp.array([1, 1]), unit=(_m, _km)) + b = QMat(jnp.array([1, 1]), unit=(_s, _s)) result = _matmul(a, b) assert isinstance(result, u.Q) - assert jnp.isclose(result.value, 1001.0) + assert jnp.isclose(result.value, 1001) assert result.unit == _m * _s def test_1d_dot_product_batch(self): """1D @ 1D with batch dimensions.""" # Batch of 3 vectors, each length 2 - a = QMat(jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), unit=(_m, _m)) - b = QMat(jnp.array([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]), unit=(_s, _s)) + a = QMat(jnp.array([[1, 2], [3, 4], [5, 6]]), unit=(_m, _m)) + b = QMat(jnp.array([[7, 8], [9, 10], [11, 12]]), unit=(_s, _s)) @quax.quaxify def dot_batched(x, y): @@ -766,9 +754,9 @@ def dot_batched(x, y): result = jax.vmap(dot_batched)(a, b) # [1*7 + 2*8, 3*9 + 4*10, 5*11 + 6*12] = [23, 67, 127] - assert jnp.isclose(result.value[0], 23.0) - assert jnp.isclose(result.value[1], 67.0) - assert jnp.isclose(result.value[2], 127.0) + assert jnp.isclose(result.value[0], 23) + assert jnp.isclose(result.value[1], 67) + assert jnp.isclose(result.value[2], 127) # --------------------------------------------------------------------------- @@ -781,18 +769,18 @@ class TestJaxIntegration: def test_jit_add(self, unit_2x2): """jit-compiled addition.""" - a = QMat(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=unit_2x2) - b = QMat(jnp.array([[5.0, 6.0], [7.0, 8.0]]), unit=unit_2x2) + a = QMat(jnp.array([[1, 2], [3, 4]]), unit=unit_2x2) + b = QMat(jnp.array([[5, 6], [7, 8]]), unit=unit_2x2) result = jax.jit(_add)(a, b) - expected = jnp.array([[6.0, 8.0], [10.0, 12.0]]) + expected = jnp.array([[6, 8], [10, 12]]) assert jnp.allclose(result.value, expected) def test_jit_matmul(self): """jit-compiled matmul.""" - a = QMat(jnp.array([[2.0, 3.0]]), unit=((_m, _m),)) - b = QMat(jnp.array([[4.0], [5.0]]), unit=((_s,), (_s,))) + a = QMat(jnp.array([[2, 3]]), unit=((_m, _m),)) + b = QMat(jnp.array([[4], [5]]), unit=((_s,), (_s,))) result = jax.jit(_matmul)(a, b) - assert jnp.isclose(result.value[0, 0], 23.0) + assert jnp.isclose(result.value[0, 0], 23) def test_pytree_flatten_unflatten(self, qm_2x2, unit_2x2): """QMatrix is a proper PyTree.""" @@ -825,7 +813,7 @@ class TestPlumConversion: def test_QMatrix_to_quantity_uniform_1d(self): """1D uniform-unit ``QMatrix`` converts to ``u.Q``.""" - qm = QMat(value=jnp.array([1.0, 2.0, 3.0]), unit=(_m, _m, _m)) + qm = QMat(value=jnp.array([1, 2, 3]), unit=(_m, _m, _m)) result = plum.convert(qm, u.Q) @@ -922,25 +910,25 @@ class TestMatVec: def test_identity_uniform_units(self): """Identity 3x3 @ uniform-unit vector → same vector.""" A = QMat(jnp.eye(3), unit=((_dimless, _dimless, _dimless),) * 3) - v = QMat(jnp.array([1.0, 2.0, 3.0]), unit=(_m, _m, _m)) + v = QMat(jnp.array([1, 2, 3]), unit=(_m, _m, _m)) w = _matmul(A, v) assert isinstance(w, QMat) assert w.ndim == 1 - assert jnp.allclose(w.value, jnp.array([1.0, 2.0, 3.0])) + assert jnp.allclose(w.value, jnp.array([1, 2, 3])) def test_uniform_units_values(self): """2x2 @ 2: correct values with uniform units.""" - A = QMat(jnp.array([[2.0, 3.0], [4.0, 5.0]]), unit=((_m, _m), (_m, _m))) - v = QMat(jnp.array([10.0, 20.0]), unit=(_s, _s)) + A = QMat(jnp.array([[2, 3], [4, 5]]), unit=((_m, _m), (_m, _m))) + v = QMat(jnp.array([10, 20]), unit=(_s, _s)) w = _matmul(A, v) # [2*10 + 3*20, 4*10 + 5*20] = [80, 140] - assert jnp.isclose(w.value[0], 80.0) - assert jnp.isclose(w.value[1], 140.0) + assert jnp.isclose(w.value[0], 80) + assert jnp.isclose(w.value[1], 140) def test_output_is_1d_QMatrix(self): """Result of 2D @ 1D is a 1D ``QMatrix``, not a 2D one.""" - A = QMat(jnp.array([[1.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) - v = QMat(jnp.array([3.0, 7.0]), unit=(_s, _s)) + A = QMat(jnp.array([[1, 0], [0, 1]]), unit=((_m, _m), (_m, _m))) + v = QMat(jnp.array([3, 7]), unit=(_s, _s)) w = _matmul(A, v) assert isinstance(w, QMat) assert w.ndim == 1 @@ -948,8 +936,8 @@ def test_output_is_1d_QMatrix(self): def test_output_units_are_product(self): """Output unit[i] == lhs.unit[i][0] * rhs.unit[0].""" - A = QMat(jnp.array([[1.0, 1.0], [1.0, 1.0]]), unit=((_m, _m), (_kg, _kg))) - v = QMat(jnp.array([1.0, 1.0]), unit=(_s, _s)) + A = QMat(jnp.array([[1, 1], [1, 1]]), unit=((_m, _m), (_kg, _kg))) + v = QMat(jnp.array([1, 1]), unit=(_s, _s)) w = _matmul(A, v) assert w.unit[0] == _m * _s assert w.unit[1] == _kg * _s @@ -960,11 +948,11 @@ def test_lhs_unit_conversion_on_contraction_axis(self): # v: [1, 1] in [s, s] # ref[i] = m*s, scale[i,1] = 1000 (km*s → m*s) # w[0] = 1*1 + 1000*1 = 1001, w[1] = 1001 - A = QMat(jnp.array([[1.0, 1.0], [1.0, 1.0]]), unit=((_m, _km), (_m, _km))) - v = QMat(jnp.array([1.0, 1.0]), unit=(_s, _s)) + A = QMat(jnp.array([[1, 1], [1, 1]]), unit=((_m, _km), (_m, _km))) + v = QMat(jnp.array([1, 1]), unit=(_s, _s)) w = _matmul(A, v) - assert jnp.isclose(w.value[0], 1001.0) - assert jnp.isclose(w.value[1], 1001.0) + assert jnp.isclose(w.value[0], 1001) + assert jnp.isclose(w.value[1], 1001) assert w.unit[0] == _m * _s assert w.unit[1] == _m * _s @@ -974,8 +962,8 @@ def test_rhs_unit_conversion_on_contraction_axis(self): # v: [1, 1] in [s, ms] # ref[i] = m*s, scale[i,1] = uconvert_value(m*s, m*ms, 1.0) = 0.001 # w[0] = 1*1 + 0.001*1 = 1.001, w[1] = 1.001 - A = QMat(jnp.array([[1.0, 1.0], [1.0, 1.0]]), unit=((_m, _m), (_m, _m))) - v = QMat(jnp.array([1.0, 1.0]), unit=(_s, _ms)) + A = QMat(jnp.array([[1, 1], [1, 1]]), unit=((_m, _m), (_m, _m))) + v = QMat(jnp.array([1, 1]), unit=(_s, _ms)) w = _matmul(A, v) assert jnp.isclose(w.value[0], 1.001) assert jnp.isclose(w.value[1], 1.001) @@ -984,28 +972,28 @@ def test_rhs_unit_conversion_on_contraction_axis(self): def test_non_square_3x2_at_2(self): """Non-square 3x2 @ 2 → 3.""" A = QMat( - jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + jnp.array([[1, 2], [3, 4], [5, 6]]), unit=((_m, _km), (_m, _km), (_m, _km)), ) - v = QMat(jnp.array([1.0, 1.0]), unit=(_s, _s)) + v = QMat(jnp.array([1, 1]), unit=(_s, _s)) w = _matmul(A, v) # scale[i,1] = 1000 (km*s → m*s) # w[0] = 1*1 + 1000*2*1 = 2001 # w[1] = 1*3 + 1000*4*1 = 4003 # w[2] = 1*5 + 1000*6*1 = 6005 assert w.shape == (3,) - assert jnp.isclose(w.value[0], 2001.0) - assert jnp.isclose(w.value[1], 4003.0) - assert jnp.isclose(w.value[2], 6005.0) + assert jnp.isclose(w.value[0], 2001) + assert jnp.isclose(w.value[1], 4003) + assert jnp.isclose(w.value[2], 6005) assert all(w.unit[i] == _m * _s for i in range(3)) def test_jit_compatible(self): """jit-compiled matrix-vector multiply works.""" - A = QMat(jnp.array([[2.0, 3.0], [4.0, 5.0]]), unit=((_m, _m), (_m, _m))) - v = QMat(jnp.array([10.0, 20.0]), unit=(_s, _s)) + A = QMat(jnp.array([[2, 3], [4, 5]]), unit=((_m, _m), (_m, _m))) + v = QMat(jnp.array([10, 20]), unit=(_s, _s)) w = jax.jit(_matmul)(A, v) - assert jnp.isclose(w.value[0], 80.0) - assert jnp.isclose(w.value[1], 140.0) + assert jnp.isclose(w.value[0], 80) + assert jnp.isclose(w.value[1], 140) def test_batch_dimensions(self): """Leading batch dimensions are preserved.""" @@ -1020,18 +1008,18 @@ def mv(a, b): w = jax.vmap(mv)(A, v) # Each 2x2 ones @ [1, 1] = [2, 2] assert w.shape == (3, 2) - assert jnp.allclose(w.value, 2.0 * jnp.ones((3, 2))) + assert jnp.allclose(w.value, 2 * jnp.ones((3, 2))) def test_different_per_row_output_units(self): """Each output row can have a different unit.""" # A: row 0 in m, row 1 in kg; v in s - A = QMat(jnp.array([[1.0, 1.0], [1.0, 1.0]]), unit=((_m, _m), (_kg, _kg))) - v = QMat(jnp.array([1.0, 1.0]), unit=(_s, _s)) + A = QMat(jnp.array([[1, 1], [1, 1]]), unit=((_m, _m), (_kg, _kg))) + v = QMat(jnp.array([1, 1]), unit=(_s, _s)) w = _matmul(A, v) assert w.unit[0] == _m * _s assert w.unit[1] == _kg * _s - assert jnp.isclose(w.value[0], 2.0) - assert jnp.isclose(w.value[1], 2.0) + assert jnp.isclose(w.value[0], 2) + assert jnp.isclose(w.value[1], 2) # --------------------------------------------------------------------------- @@ -1050,17 +1038,17 @@ class TestDiagAndGather: def test_diag_values(self): """jnp.diag extracts the correct diagonal values.""" A = QMat( - jnp.diag(jnp.array([1.0, 4.0, 9.0])), + jnp.diag(jnp.array([1, 4, 9])), unit=(("m", "m", "m"), ("m", "m", "m"), ("m", "m", "m")), ) d = _diag(A) assert isinstance(d, QMat) - assert jnp.allclose(d.value, jnp.array([1.0, 4.0, 9.0])) + assert jnp.allclose(d.value, jnp.array([1, 4, 9])) def test_diag_unit_is_1d(self): """jnp.diag result has a 1-D UnitsMatrix, not 2-D.""" A = QMat( - jnp.diag(jnp.array([1.0, 4.0, 9.0])), + jnp.diag(jnp.array([1, 4, 9])), unit=(("m", "m", "m"), ("m", "m", "m"), ("m", "m", "m")), ) d = _diag(A) @@ -1070,7 +1058,7 @@ def test_diag_unit_is_1d(self): def test_diag_uniform_units(self): """Diagonal of uniform-unit matrix keeps that unit.""" A = QMat( - jnp.diag(jnp.array([1.0, 2.0, 3.0])), + jnp.diag(jnp.array([1, 2, 3])), unit=((_m, _m, _m), (_m, _m, _m), (_m, _m, _m)), ) d = _diag(A) @@ -1085,7 +1073,7 @@ def test_diag_heterogeneous_units_picks_diagonal(self): are concrete Python/NumPy values. """ A = QMat( - jnp.diag(jnp.array([1.0, 2.0, 3.0])), + jnp.diag(jnp.array([1, 2, 3])), unit=((_m, _s, _kg), (_m, _s, _kg), (_m, _s, _kg)), ) @@ -1102,7 +1090,7 @@ def _fancy_diag(mat): def test_diag_dimensionless_unit_string(self): """Dimensionless diagonal has '(, , )'-style repr, not '((, , ), ...)'.""" A = QMat( - jnp.diag(jnp.array([1.0, 4.0, 4.0])), + jnp.diag(jnp.array([1, 4, 4])), unit=(("", "", ""), ("", "", ""), ("", "", "")), ) d = _diag(A) @@ -1116,12 +1104,12 @@ def test_diag_dimensionless_unit_string(self): def test_diag_under_jit_uniform_units(self): """jnp.diag under jit works for uniform-unit QMatrix.""" A = QMat( - jnp.diag(jnp.array([1.0, 4.0, 9.0])), + jnp.diag(jnp.array([1, 4, 9])), unit=((_m, _m, _m), (_m, _m, _m), (_m, _m, _m)), ) d = jax.jit(_diag)(A) assert isinstance(d, QMat) - assert jnp.allclose(d.value, jnp.array([1.0, 4.0, 9.0])) + assert jnp.allclose(d.value, jnp.array([1, 4, 9])) assert d.unit.ndim == 1 assert all(d.unit[i] == _m for i in range(3)) @@ -1137,17 +1125,17 @@ class TestDiagMethod: def test_uniform_units_values(self): """Diagonal values are correct for a uniform-unit matrix.""" A = QMat( - jnp.diag(jnp.array([1.0, 4.0, 9.0])), + jnp.diag(jnp.array([1, 4, 9])), unit=((_m, _m, _m), (_m, _m, _m), (_m, _m, _m)), ) d = A.diag() assert isinstance(d, QMat) - assert jnp.allclose(d.value, jnp.array([1.0, 4.0, 9.0])) + assert jnp.allclose(d.value, jnp.array([1, 4, 9])) def test_uniform_units_shape(self): """Result is 1-D with length equal to the diagonal.""" A = QMat( - jnp.diag(jnp.array([1.0, 4.0, 9.0])), + jnp.diag(jnp.array([1, 4, 9])), unit=((_m, _m, _m), (_m, _m, _m), (_m, _m, _m)), ) d = A.diag() @@ -1159,7 +1147,7 @@ def test_uniform_units_shape(self): def test_uniform_units_preserved(self): """Units on the diagonal are preserved unchanged.""" A = QMat( - jnp.diag(jnp.array([1.0, 2.0, 3.0])), + jnp.diag(jnp.array([1, 2, 3])), unit=((_m, _m, _m), (_m, _m, _m), (_m, _m, _m)), ) d = A.diag() @@ -1168,16 +1156,16 @@ def test_uniform_units_preserved(self): def test_heterogeneous_units_values(self): """Diagonal values are correct for a heterogeneous-unit matrix.""" A = QMat( - jnp.diag(jnp.array([1.0, 2.0, 3.0])), + jnp.diag(jnp.array([1, 2, 3])), unit=((_m, _s, _kg), (_m, _s, _kg), (_m, _s, _kg)), ) d = A.diag() - assert jnp.allclose(d.value, jnp.array([1.0, 2.0, 3.0])) + assert jnp.allclose(d.value, jnp.array([1, 2, 3])) def test_heterogeneous_units_correct(self): """Each diagonal unit is taken from ``self.unit[i, i]``.""" A = QMat( - jnp.diag(jnp.array([1.0, 2.0, 3.0])), + jnp.diag(jnp.array([1, 2, 3])), unit=((_m, _s, _kg), (_m, _s, _kg), (_m, _s, _kg)), ) d = A.diag() @@ -1188,7 +1176,7 @@ def test_heterogeneous_units_correct(self): def test_heterogeneous_units_to_string(self): """1-D unit string has the correct format.""" A = QMat( - jnp.diag(jnp.array([1.0, 2.0, 3.0])), + jnp.diag(jnp.array([1, 2, 3])), unit=((_m, _s, _kg), (_m, _s, _kg), (_m, _s, _kg)), ) d = A.diag() @@ -1197,13 +1185,13 @@ def test_heterogeneous_units_to_string(self): def test_2x2_square(self): """2x2 matrix diagonal.""" A = QMat( - jnp.array([[5.0, 0.0], [0.0, 7.0]]), + jnp.array([[5, 0], [0, 7]]), unit=((_m, _s), (_kg, _rad)), ) d = A.diag() assert d.shape == (2,) - assert jnp.isclose(d.value[0], 5.0) - assert jnp.isclose(d.value[1], 7.0) + assert jnp.isclose(d.value[0], 5) + assert jnp.isclose(d.value[1], 7) assert d.unit[0] == _m assert d.unit[1] == _rad @@ -1211,13 +1199,13 @@ def test_non_square_picks_min_dim(self): """For non-square matrices the diagonal length is min(rows, cols).""" # 2x3 matrix → diagonal of length 2 A = QMat( - jnp.arange(6.0).reshape(2, 3), + jnp.arange(6).reshape(2, 3), unit=((_m, _s, _kg), (_rad, _km, _ms)), ) d = A.diag() assert d.shape == (2,) - assert jnp.isclose(d.value[0], 0.0) # A[0,0] - assert jnp.isclose(d.value[1], 4.0) # A[1,1] + assert jnp.isclose(d.value[0], 0) # A[0,0] + assert jnp.isclose(d.value[1], 4) # A[1,1] assert d.unit[0] == _m # unit[0,0] assert d.unit[1] == _km # unit[1,1] @@ -1229,39 +1217,39 @@ def test_1d_raises(self, qm_1d): def test_jit_uniform_units(self): """Works under jax.jit with uniform units.""" A = QMat( - jnp.diag(jnp.array([1.0, 4.0, 9.0])), + jnp.diag(jnp.array([1, 4, 9])), unit=((_m, _m, _m), (_m, _m, _m), (_m, _m, _m)), ) d = jax.jit(lambda x: x.diag())(A) assert isinstance(d, QMat) - assert jnp.allclose(d.value, jnp.array([1.0, 4.0, 9.0])) + assert jnp.allclose(d.value, jnp.array([1, 4, 9])) assert d.unit[0] == _m def test_jit_heterogeneous_units(self): """Works under jax.jit with heterogeneous units (the key advantage).""" A = QMat( - jnp.diag(jnp.array([1.0, 2.0, 3.0])), + jnp.diag(jnp.array([1, 2, 3])), unit=((_m, _s, _kg), (_m, _s, _kg), (_m, _s, _kg)), ) d = jax.jit(lambda x: x.diag())(A) assert d.unit[0] == _m assert d.unit[1] == _s assert d.unit[2] == _kg - assert jnp.allclose(d.value, jnp.array([1.0, 2.0, 3.0])) + assert jnp.allclose(d.value, jnp.array([1, 2, 3])) def test_batch_dimensions(self): """Batch leading dimensions are preserved.""" # (3, 2, 2) batched matrix - base = jnp.diag(jnp.array([1.0, 4.0])) + base = jnp.diag(jnp.array([1, 4])) A = QMat( jnp.stack([base, 2 * base, 3 * base]), # (3, 2, 2) unit=((_m, _m), (_s, _s)), ) d = A.diag() assert d.shape == (3, 2) - assert jnp.isclose(d.value[0, 0], 1.0) - assert jnp.isclose(d.value[1, 0], 2.0) - assert jnp.isclose(d.value[2, 1], 12.0) + assert jnp.isclose(d.value[0, 0], 1) + assert jnp.isclose(d.value[1, 0], 2) + assert jnp.isclose(d.value[2, 1], 12) # --------------------------------------------------------------------------- @@ -1405,7 +1393,7 @@ class TestQMatrixTranspose: def test_2d_square_values(self, qm_2x2): """Value array is transposed correctly for a square matrix.""" t = qm_2x2.T - expected = jnp.array([[1.0, 3.0], [2.0, 4.0]]) + expected = jnp.array([[1, 3], [2, 4]]) assert jnp.allclose(t.value, expected) def test_2d_square_units(self, qm_2x2): @@ -1431,9 +1419,9 @@ def test_2d_square_returns_QMatrix(self, qm_2x2): def test_2d_nonsquare_values(self): """Transposing a 2x3 matrix gives a 3x2 with correct values.""" - a = QMat(jnp.arange(6.0).reshape(2, 3), unit=((_m, _s, _kg), (_rad, _km, _ms))) + a = QMat(jnp.arange(6).reshape(2, 3), unit=((_m, _s, _kg), (_rad, _km, _ms))) t = a.T - expected = jnp.arange(6.0).reshape(2, 3).T + expected = jnp.arange(6).reshape(2, 3).T assert jnp.allclose(t.value, expected) def test_2d_nonsquare_shape(self): @@ -1492,7 +1480,7 @@ def test_batch_unit_structure_unchanged(self): def test_batch_matrix_transpose_via_quax(self): """``matrix_transpose`` on a batched ``(B, N, M)`` preserves batch axes.""" - a = QMat(jnp.arange(12.0).reshape(3, 2, 2), unit=((_m, _s), (_kg, _rad))) + a = QMat(jnp.arange(12).reshape(3, 2, 2), unit=((_m, _s), (_kg, _rad))) t = qnp.matrix_transpose(a) # Batch axis preserved; last two swapped: (3,2,2) → (3,2,2) assert t.shape == (3, 2, 2) @@ -1507,7 +1495,7 @@ def test_batch_matrix_transpose_via_quax(self): def test_vmap_transpose(self): """``jax.vmap`` over a batch of 2-D matrices gives the per-element transpose.""" - a = QMat(jnp.arange(12.0).reshape(3, 2, 2), unit=((_m, _s), (_kg, _rad))) + a = QMat(jnp.arange(12).reshape(3, 2, 2), unit=((_m, _s), (_kg, _rad))) @quax.quaxify def single_T(x): @@ -1530,7 +1518,7 @@ def single_T(x): def test_jit_values(self, qm_2x2): """``jax.jit`` preserves transpose values.""" t = jax.jit(_transpose)(qm_2x2) - expected = jnp.array([[1.0, 3.0], [2.0, 4.0]]) + expected = jnp.array([[1, 3], [2, 4]]) assert jnp.allclose(t.value, expected) def test_jit_units(self, qm_2x2): @@ -1638,19 +1626,19 @@ class TestDetQMatrix: def test_returns_abstract_quantity(self): """Det of a 2×2 QMatrix returns an AbstractQuantity.""" - A = QMat(jnp.array([[2.0, 0.0], [0.0, 3.0]]), unit=((_m, _m), (_m, _m))) + A = QMat(jnp.array([[2, 0], [0, 3]]), unit=((_m, _m), (_m, _m))) result = quax.quaxify(qm_det)(A) assert isinstance(result, u.AbstractQuantity) def test_value_2x2_diagonal(self): """Numeric value equals jnp.linalg.det of the value array.""" - A = QMat(jnp.array([[2.0, 0.0], [0.0, 3.0]]), unit=((_m, _m), (_m, _m))) + A = QMat(jnp.array([[2, 0], [0, 3]]), unit=((_m, _m), (_m, _m))) result = quax.quaxify(qm_det)(A) - assert jnp.allclose(result.value, 6.0) + assert jnp.allclose(result.value, 6) def test_unit_product_of_diagonal(self): """Unit is the product of the main-diagonal units: m·m = m².""" - A = QMat(jnp.array([[2.0, 0.0], [0.0, 3.0]]), unit=((_m, _m), (_m, _m))) + A = QMat(jnp.array([[2, 0], [0, 3]]), unit=((_m, _m), (_m, _m))) result = quax.quaxify(qm_det)(A) assert result.unit == u.unit("m2") @@ -1664,14 +1652,14 @@ def test_unit_3x3_uniform(self): """Det of 3×3 identity with uniform unit m gives unit m³.""" A = QMat(jnp.eye(3), unit=((_m, _m, _m),) * 3) result = quax.quaxify(qm_det)(A) - assert jnp.allclose(result.value, 1.0) + assert jnp.allclose(result.value, 1) assert result.unit == u.unit("m3") def test_jit_QMatrix(self): """Det of QMatrix works under jax.jit.""" A = QMat(jnp.array([[2.0, 0.0], [0.0, 3.0]]), unit=((_m, _m), (_m, _m))) result = jax.jit(quax.quaxify(qm_det))(A) - assert jnp.allclose(result.value, 6.0) + assert jnp.allclose(result.value, 6) assert result.unit == u.unit("m2") @@ -1784,20 +1772,20 @@ class TestInvQMatrix: def test_returns_QMatrix(self): """Inv of a 2×2 QMatrix returns a QMatrix.""" - A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) + A = QMat(jnp.array([[4, 0], [0, 1]]), unit=((_m, _m), (_m, _m))) result = quax.quaxify(qm_inv)(A) assert isinstance(result, QMat) def test_value_2x2_diagonal(self): """Numeric value equals jnp.linalg.inv of the value array.""" - A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) + A = QMat(jnp.array([[4, 0], [0, 1]]), unit=((_m, _m), (_m, _m))) result = quax.quaxify(qm_inv)(A) expected_val = jnp.linalg.inv(jnp.array([[4.0, 0.0], [0.0, 1.0]])) assert jnp.allclose(result.value, expected_val) def test_unit_reciprocal(self): """Unit of the inverse is the reciprocal of the original unit: 1/m.""" - A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) + A = QMat(jnp.array([[4, 0], [0, 1]]), unit=((_m, _m), (_m, _m))) result = quax.quaxify(qm_inv)(A) expected_unit = u.unit("1 / m") assert result.unit[0, 0] == expected_unit @@ -1807,7 +1795,7 @@ def test_unit_m2_per_rad2(self): """Inv of a metric with m²/rad² entries carries rad²/m² units.""" m2_r2 = u.unit("m2 / rad2") A = QMat( - jnp.array([[4.0, 0.0], [0.0, 1.0]]), + jnp.array([[4, 0], [0, 1]]), unit=((m2_r2, m2_r2), (m2_r2, m2_r2)), ) result = quax.quaxify(qm_inv)(A) @@ -1815,7 +1803,7 @@ def test_unit_m2_per_rad2(self): def test_jit_QMatrix(self): """Inv of QMatrix works under jax.jit.""" - A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) + A = QMat(jnp.array([[4, 0], [0, 1]]), unit=((_m, _m), (_m, _m))) result = jax.jit(quax.quaxify(qm_inv))(A) assert jnp.allclose(result.value, jnp.array([[0.25, 0.0], [0.0, 1.0]])) assert result.unit[0, 0] == u.unit("1 / m") @@ -1823,7 +1811,7 @@ def test_jit_QMatrix(self): def test_roundtrip_identity(self): """A @ inv(A) ≈ I for a QMatrix (value check).""" A = QMat( - jnp.array([[2.0, 1.0], [1.0, 3.0]]), + jnp.array([[2, 1], [1, 3]]), unit=((_m, _m), (_m, _m)), ) Ainv = quax.quaxify(qm_inv)(A) From 29f3994a650f2d2c96600c546b690644e0a30c33 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 19:24:28 -0400 Subject: [PATCH 14/15] =?UTF-8?q?=E2=9C=85=20test(tests):=20replace=20whol?= =?UTF-8?q?e-number=20float=20literals=20with=20ints=20in=20quantity=5Fmat?= =?UTF-8?q?rix=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace float literals like 1.0, 2.0 with int equivalents where the values are whole numbers. Preserve floats where needed for JAX grad/jvp tests, dtype-sensitive operations, and mixed-unit conversion tests that produce float results. --- tests/unit/internal/test_quantity_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/internal/test_quantity_matrix.py b/tests/unit/internal/test_quantity_matrix.py index 1c9f0584..f78ac007 100644 --- a/tests/unit/internal/test_quantity_matrix.py +++ b/tests/unit/internal/test_quantity_matrix.py @@ -1803,7 +1803,7 @@ def test_unit_m2_per_rad2(self): def test_jit_QMatrix(self): """Inv of QMatrix works under jax.jit.""" - A = QMat(jnp.array([[4, 0], [0, 1]]), unit=((_m, _m), (_m, _m))) + A = QMat(jnp.array([[4.0, 0.0], [0.0, 1.0]]), unit=((_m, _m), (_m, _m))) result = jax.jit(quax.quaxify(qm_inv))(A) assert jnp.allclose(result.value, jnp.array([[0.25, 0.0], [0.0, 1.0]])) assert result.unit[0, 0] == u.unit("1 / m") From aed8deac65f025b066965aed7e79bb2272d0f430 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 17 May 2026 19:36:24 -0400 Subject: [PATCH 15/15] =?UTF-8?q?=E2=9C=85=20test(tests):=20replace=20whol?= =?UTF-8?q?e-number=20float=20literals=20with=20ints=20in=20strategies=20a?= =?UTF-8?q?nd=20ptmap?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace float literals like 8.0, -8.0, -3.0, 3.0, -5.0, 5.0, 10.0, -10.0 with int equivalents where the values are whole numbers and dtype safety permits. Signed-off-by: nstarman --- .../coordinax.hypothesis/tests/test_composite_dispatch.py | 2 +- tests/strategies.py | 8 ++++---- tests/usage/charts/test_ptmap.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/packages/coordinax.hypothesis/tests/test_composite_dispatch.py b/packages/coordinax.hypothesis/tests/test_composite_dispatch.py index 86b3bd73..f0d336d5 100644 --- a/packages/coordinax.hypothesis/tests/test_composite_dispatch.py +++ b/packages/coordinax.hypothesis/tests/test_composite_dispatch.py @@ -143,7 +143,7 @@ def test_dispatch_two_int_args(v): assert 0 <= v <= 100 -@given(interval_value(0, 1)) +@given(interval_value(0.0, 1.0)) def test_dispatch_two_float_args(v): assert isinstance(v, float) assert 0 <= v <= 1 diff --git a/tests/strategies.py b/tests/strategies.py index 212fb254..55addabb 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -46,18 +46,18 @@ def _angle_qty(min_value: float, max_value: float): # --------------------------------------------------------------------------- # Strictly positive radial coordinate (avoids origin singularity) -pos_m = _m_qty(0.5, 8.0) +pos_m = _m_qty(0.5, 8) # Any Cartesian component value -any_m = _m_qty(-8.0, 8.0) +any_m = _m_qty(-8, 8) # Polar angle θ ∈ (0.25, 2.875) — avoids singularities at 0 and π polar_rad = _angle_qty(0.25, 2.875) # Any azimuthal / angle value -any_angle_rad = _angle_qty(-3.0, 3.0) +any_angle_rad = _angle_qty(-3, 3) # Dimensionless tangent-vector components v_elem = st.floats( - min_value=-5.0, max_value=5.0, allow_nan=False, allow_subnormal=False, width=32 + min_value=-5, max_value=5, allow_nan=False, allow_subnormal=False, width=32 ) diff --git a/tests/usage/charts/test_ptmap.py b/tests/usage/charts/test_ptmap.py index 9e882658..fed5934b 100644 --- a/tests/usage/charts/test_ptmap.py +++ b/tests/usage/charts/test_ptmap.py @@ -32,13 +32,13 @@ # Helpers # --------------------------------------------------------------------------- -_pos_m = _m_qty(0.5, 10.0) -_any_m = _m_qty(-10.0, 10.0) +_pos_m = _m_qty(0.5, 10) +_any_m = _m_qty(-10, 10) def _assert_cdict_approx(got, ref, *, rel=1e-5, abs=None) -> None: """Assert two CDicts agree component-wise, stripping each to ref's units.""" - atol = 0.0 if abs is None else abs + atol = 0 if abs is None else abs for key in ref: got_value = np.asarray(u.ustrip(ref[key].unit, got[key])) ref_value = np.asarray(u.ustrip(ref[key].unit, ref[key]))