diff --git a/docs/spec.md b/docs/spec.md index 03e01471..66d9054d 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -946,7 +946,7 @@ A non-exhaustive table of exported objects are: | `coordinax.angles` | `AbstractAngle`, `Angle`, `wrap_to` | | `coordinax.distances` | `AbstractDistance`, `Distance` | | `coordinax.charts` | `CartesianProductChart`,
`cartesian_chart`, `guess_chart`, `cdict`, `pt_map`, `jac_pt_map`,
`cart0d`,
`cart1d`, `radial1d`, `time1d`,
`cart2d`, `polar2d`,
`cart3d`, `cyl3d`, `sph3d`, `lonlat_sph3d`, `loncoslat_sph3d`, `math_sph3d`,
`cartnd`,
`spacetimect` | -| `coordinax.representations` | `cconvert`, `change_basis`,
`Representation`, `point`, `coord_disp`, `coord_vel`, `coord_acc`, `phys_disp`, `phys_vel`, `phys_acc`,
`PointGeometry`, `point_geom`, `TangentGeometry`, `tangent_geom`,
`NoBasis`, `no_basis`, `CoordinateBasis`, `coord_basis`, `PhysicalBasis`, `phys_basis`,
`Location`, `loc`, `Displacement`, `dpl`, `Velocity`, `vel`, `Acceleration`, `acc`,
`guess_geometry_kind`, `guess_semantic_kind`, `guess_rep` | +| `coordinax.representations` | `cconvert`, `change_basis`, `tangent_map`,
`Representation`, `point`, `coord_disp`, `coord_vel`, `coord_acc`, `phys_disp`, `phys_vel`, `phys_acc`,
`PointGeometry`, `point_geom`, `TangentGeometry`, `tangent_geom`,
`NoBasis`, `no_basis`, `CoordinateBasis`, `coord_basis`, `PhysicalBasis`, `phys_basis`,
`Location`, `loc`, `Displacement`, `dpl`, `Velocity`, `vel`, `Acceleration`, `acc`,
`guess_geometry_kind`, `guess_semantic_kind`, `guess_rep` | | `coordinax.vectors` | `Point`, `ToUnitsOptions` | | `coordinax.manifolds` | `guess_manifold`, `scale_factors`, `angle_between`,
`EuclideanManifold`, `EuclideanMetric`, `euclidean3d`,
`EmbeddedManifold`, `EmbeddedChart`
`twosphere`, `embedded_twosphere`,
`CustomManifold`,`CustomAtlas`, | | `coordinax.transforms` | `act`, `simplify`, `compose`, `materialize_transform`,
`AbstractTransform`, `Identity`, `Composed`, `Translate`, `Rotate`, `Reflect`, `Scale`, `Shear`, `identity`,
`AbstractTransformGroup`, `IdentityGroup`, `DiffeomorphismGroup`, `AffineGroup`, `EuclideanGroup`, `OrthogonalGroup`, `SpecialOrthogonalGroup`, `PoincareGroup`, `LorentzGroup`, `ProperOrthochronousLorentzGroup` | @@ -1728,6 +1728,44 @@ A representation is therefore **not** the same thing as a chart: the chart deter - Inherits all failure semantics from `guess_geometry_kind`. +!!! info `tangent_map` + + Transform a tangent vector from one chart to another. + + **Signature:** + + ```text + tangent_map(v, from_chart, from_geom, from_rep, to_chart, to_geom, to_rep, /, *, at) -> CDict + ``` + + A 4-argument shorthand form is also supported: + + ```text + tangent_map(v, from_chart, from_rep, to_chart, /, *, at) -> CDict + ``` + + **Arguments:** + + - `v`: `CDict` — tangent vector components in `from_chart` with basis `from_rep.basis`. + - `from_chart`: source chart. + - `from_geom`: source geometry (e.g. `TangentGeometry`). + - `from_rep`: source `Representation` (must have `TangentGeometry`). + - `to_chart`: target chart. + - `to_geom`: target geometry. + - `to_rep`: target `Representation`. + - `at`: `CDict` — base point in `from_chart` coordinates at which the tangent space is attached. Required for non-Cartesian charts (since the Jacobian depends on the base point). + + **Semantics by basis:** + + - **`CoordinateBasis`**: delegates to `jac_pt_map(at, from_chart, to_chart)` to obtain the Jacobian $J^j{}_i = \partial\tilde{q}^j/\partial q^i$, then applies $\tilde{v}^j = J^j{}_i v^i$. + - **`PhysicalBasis`**: fetch the orthonormal frame matrices $B_{\rm from}$ (columns = physical basis vectors in Cartesian) and $B_{\rm to}$ via `frame_cart`, compute $R = B_{\rm to}^T B_{\rm from}$, apply $\hat{v}' = R \hat{v}$. + + **Same-chart optimisation:** when `from_chart is to_chart`, returns `v` unchanged. + + **`cconvert` integration:** `cconvert` dispatches to `tangent_map` when the source representation has `TangentGeometry`, passing `at` through the `at` keyword argument. + + **Same-chart basis conversion:** `change_basis(v, chart, from_basis, to_basis, /, *, at)` changes tangent component conventions without changing charts. In v1 it is defined only for Cartesian charts and `CoordinateBasis` $ +
### Geometric Kind diff --git a/packages/coordinax.api/src/coordinax/api/representations.py b/packages/coordinax.api/src/coordinax/api/representations.py index ddea6243..7e366024 100644 --- a/packages/coordinax.api/src/coordinax/api/representations.py +++ b/packages/coordinax.api/src/coordinax/api/representations.py @@ -32,6 +32,23 @@ def change_basis(*args: Any, **kwargs: Any) -> Any: raise NotImplementedError # pragma: no cover +@plum.dispatch.abstract +def tangent_map(*args: Any, **kwargs: Any) -> Any: + """Compute the tangent map (Jacobian) of a chart transition. + + This is an abstract API definition. See the main coordinax package for + concrete implementations. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.representations as cxr + + """ + raise NotImplementedError # pragma: no cover + + @plum.dispatch.abstract def cconvert(*args: Any, **kwargs: Any) -> Any: """Transform the current vector to the target chart. diff --git a/src/coordinax/main/__init__.py b/src/coordinax/main/__init__.py index ee48ec31..386625c7 100644 --- a/src/coordinax/main/__init__.py +++ b/src/coordinax/main/__init__.py @@ -78,6 +78,7 @@ "phys_disp", "phys_vel", "phys_acc", + "tangent_map", "change_basis", # vectors "Point", @@ -151,6 +152,7 @@ point_geom, subtract, tangent_geom, + tangent_map, vel, ) from coordinax.transforms import ( diff --git a/src/coordinax/representations/__init__.py b/src/coordinax/representations/__init__.py index e5eadf45..88fb0da5 100644 --- a/src/coordinax/representations/__init__.py +++ b/src/coordinax/representations/__init__.py @@ -96,6 +96,7 @@ "guess_rep", "guess_semantic_kind", "subtract", + "tangent_map", "change_basis", # Representations "Representation", @@ -171,6 +172,7 @@ point, point_geom, tangent_geom, + tangent_map, vel, ) from coordinax.api.representations import ( diff --git a/src/coordinax/representations/_src/__init__.py b/src/coordinax/representations/_src/__init__.py index ebb4dc0d..f702597b 100644 --- a/src/coordinax/representations/_src/__init__.py +++ b/src/coordinax/representations/_src/__init__.py @@ -8,3 +8,4 @@ from .register_cx import * from .rep import * from .semantics import * +from .tangent_map import * diff --git a/src/coordinax/representations/_src/basis_change.py b/src/coordinax/representations/_src/basis_change.py index 21e1d13a..8e724dda 100644 --- a/src/coordinax/representations/_src/basis_change.py +++ b/src/coordinax/representations/_src/basis_change.py @@ -3,7 +3,7 @@ __all__ = ("change_basis",) from jaxtyping import ArrayLike -from typing import Any +from typing import Any, TypeVar import jax import jax.scipy.linalg @@ -28,6 +28,8 @@ from coordinax.internal import QuantityMatrix, UnitsMatrix from coordinax.internal.custom_types import CDict, OptUSys +T = TypeVar("T", bound=u.Quantity) + _RAD = u.unit("rad") diff --git a/src/coordinax/representations/_src/core.py b/src/coordinax/representations/_src/core.py index be04395b..9cba4587 100644 --- a/src/coordinax/representations/_src/core.py +++ b/src/coordinax/representations/_src/core.py @@ -330,19 +330,14 @@ def cconvert( at: CDict | None = None, usys: OptUSys = None, ) -> Any: - r"""Convert tangent data between basis conventions in the same chart. - - Tangent conversions are basis changes when source and target charts are - identical. In this case, `cconvert` redispatches to `change_basis`. + r"""Convert tangent data between charts via Jacobian pushforward. Examples -------- - Convert tangent data between coordinate and physical basis in the same - chart: - >>> import jax.numpy as jnp >>> import coordinax.charts as cxc >>> import coordinax.representations as cxr + >>> v = {"r": jnp.array(5.0), "theta": jnp.array(1.0), "phi": jnp.array(2.0)} >>> at = {"r": jnp.array(3.0), "theta": jnp.array(0.5), "phi": jnp.array(0.0)} >>> cxr.cconvert(v, cxc.sph3d, cxr.tangent_geom, cxr.coord_disp, @@ -351,28 +346,15 @@ def cconvert( 'theta': Array(3., dtype=float64, ...), 'phi': Array(..., dtype=float64, ...)} - Tangent conversion across different charts is not implemented by this - dispatch: - - >>> cxr.cconvert(v, cxc.sph3d, cxr.tangent_geom, cxr.coord_disp, - ... cxc.cart3d, cxr.tangent_geom, cxr.coord_disp, at=at) - Traceback (most recent call last): - ... - NotImplementedError: Tangent cconvert between different charts is not implemented; - use the same chart for basis changes. + >>> v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + >>> at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + >>> cxr.cconvert(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, cxr.coord_disp, at=at) + {'r': Array(1., ...), 'theta': Array(0., ...)} """ - del from_geom, to_geom # represented by dispatch signature - - if from_chart != to_chart: - msg = ( - "Tangent cconvert between different charts is not implemented; " - "use the same chart for basis changes." - ) - raise NotImplementedError(msg) - - return cxrapi.change_basis( - x, from_chart, from_rep.basis, to_rep.basis, at=at, usys=usys + del from_geom, to_geom + return cxrapi.tangent_map( + x, from_chart, from_rep, to_chart, to_rep, at=at, usys=usys ) diff --git a/src/coordinax/representations/_src/tangent_map.py b/src/coordinax/representations/_src/tangent_map.py new file mode 100644 index 00000000..71177e44 --- /dev/null +++ b/src/coordinax/representations/_src/tangent_map.py @@ -0,0 +1,276 @@ +"""Tangent map (Jacobian pushforward) between coordinate charts.""" + +__all__ = ("tangent_map",) + +from jaxtyping import Array +from typing import Any + +import jax.numpy as jnp +import plum + +import quaxed.numpy as qnp +import unxt as u + +import coordinax.api.representations as cxrapi +import coordinax.charts as cxc +from .basis import CoordinateBasis, PhysicalBasis, coord_basis +from .geom import TangentGeometry +from .rep import Representation +from coordinax.internal import QuantityMatrix, pack_nonuniform_unit +from coordinax.internal.custom_types import CDict, OptUSys + +# --------------------------------------------------------------------------- +# Validation helpers +# --------------------------------------------------------------------------- + + +def _check_tangent_geom(geom: object, label: str) -> None: + if not isinstance(geom, TangentGeometry): + raise TypeError( + f"tangent_map requires TangentGeometry for {label}, got {geom!r}" + ) + + +def _check_linear_basis(rep: Representation, label: str) -> None: + if not isinstance(rep.basis, CoordinateBasis | PhysicalBasis): + raise TypeError( + "tangent_map requires CoordinateBasis or PhysicalBasis for " + f"{label}, got {rep.basis!r}" + ) + + +# --------------------------------------------------------------------------- +# Shared helper: apply a QuantityMatrix Jacobian to a tangent vector CDict +# --------------------------------------------------------------------------- + + +def _apply_jac( + J: Array | QuantityMatrix, + from_components: tuple[str, ...], + to_components: tuple[str, ...], + v: CDict, +) -> CDict: + """Apply a 2-D QuantityMatrix Jacobian to a tangent CDict. + + If the components of ``v`` are plain arrays, the output is a plain-array + CDict (using ``J.value @ v_arr``). If any component of ``v`` is a + {class}`~unxt.AbstractQuantity`, ``v`` is packed into a 1-D + {class}`~coordinax.internal.QuantityMatrix` and the result is computed + via ``qnp.matmul(J, v_qm)``, which handles per-element unit conversion. + + Parameters + ---------- + J + QuantityMatrix of shape ``(n_out, n_in)`` returned by ``jac_pt_map``. + from_components + Ordered component names for the input chart (columns of J). + to_components + Ordered component names for the output chart (rows of J). + v + Tangent vector components. Values may be plain JAX arrays or + {class}`~unxt.AbstractQuantity` objects. + + Returns + ------- + CDict + Tangent vector components in the output chart. + + """ + if isinstance(v[from_components[0]], u.AbstractQuantity): + v_arr, v_units = pack_nonuniform_unit(v, keys=from_components) + v_qm = QuantityMatrix(v_arr, unit=v_units) + w = qnp.matmul(J, v_qm) # (n_out,) QuantityMatrix + return {key: u.Q(w.value[i], w.unit[i]) for i, key in enumerate(to_components)} # ty: ignore[unresolved-attribute] + + v_arr = jnp.stack([jnp.asarray(v[k]) for k in from_components]) + # When J is a QuantityMatrix, use J.value to avoid the Quax fallback path + # that returns a QuantityMatrix with J's own 2D unit structure (wrong). + # Plain-array velocity is dimensionless, so numeric-only application is correct. + j_arr = J.value if isinstance(J, QuantityMatrix) else J + result = j_arr @ v_arr + return {key: result[i] for i, key in enumerate(to_components)} + + +# --------------------------------------------------------------------------- +# Same-Representation dispatch +# --------------------------------------------------------------------------- + + +@plum.dispatch +def tangent_map( + v: Any, + from_chart: cxc.AbstractChart, + basis: CoordinateBasis, + to_chart: cxc.AbstractChart, + /, + *, + at: CDict | None = None, + usys: OptUSys = None, +) -> CDict: + r"""Push a tangent vector forward from one chart to another. + + Applies the Jacobian of the chart transition map to the tangent vector + components ``v``, evaluated at the base point ``at``. + + Examples + -------- + Convert a tangent vector from Cartesian to polar 2D at the point (1, 0): + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.representations as cxr + + >>> v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + >>> at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + >>> cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=at) + {'r': Array(1., dtype=float64), 'theta': Array(0., dtype=float64)} + + """ + # Same-chart optimization: identity transform + if from_chart is to_chart: + return v + + J = cxc.jac_pt_map(at, from_chart, to_chart, usys=usys) + return _apply_jac(J, from_chart.components, to_chart.components, v) + + +@plum.dispatch +def tangent_map( + v: CDict, + from_chart: cxc.AbstractChart, + basis: PhysicalBasis, + to_chart: cxc.AbstractChart, + /, + *, + at: CDict | None = None, + usys: OptUSys = None, +) -> CDict: + r"""Push a tangent vector forward in physical-basis components. + + This dispatch applies the tangent-map pushforward while preserving the + physical-basis convention by composing three steps: + + 1. convert source components from physical basis to coordinate basis, + 2. apply the chart Jacobian pushforward, + 3. convert target components back to physical basis. + + Examples + -------- + Convert a physical-basis tangent vector from Cartesian to spherical 3D: + + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.representations as cxr + + >>> v = {"x": u.Q(1, "m/s"), "y": u.Q(0, "m/s"), "z": u.Q(0, "m/s")} + >>> at = {"x": u.Q(1, "m"), "y": u.Q(0, "m"), "z": u.Q(0, "m")} + >>> cxr.tangent_map(v, cxc.cart3d, cxr.phys_basis, cxc.sph3d, at=at) + {'r': Q(1., 'm / s'), 'theta': Q(-0., 'm / s'), 'phi': Q(0., 'm / s')} + + The same call can be made using a physical representation: + + >>> v = {"x": 1, "y": 0, "z": 0} + >>> at = {"x": 1, "y": 0, "z": 0} + >>> usys = u.unitsystems.si + >>> cxr.tangent_map(v, cxc.cart3d, cxr.phys_disp, cxc.sph3d, at=at, usys=usys) + {'r': Array(1., dtype=float64), 'theta': Array(0., dtype=float64), + 'phi': Array(0., dtype=float64)} + + """ + # Same-chart optimization: identity transform + if from_chart is to_chart: + return v + + # TODO: direct routes + # Compute physical-basis transport by composing: + # physical -> coordinate -> Jacobian pushforward -> physical + + # Basis Change: physical to coord + v_coord = cxrapi.change_basis(v, from_chart, basis, coord_basis, at=at, usys=usys) + # Chart Jacobian pushforward in coordinate basis + v_coord_to = cxrapi.tangent_map( + v_coord, from_chart, coord_basis, to_chart, at=at, usys=usys + ) + at_to = cxc.pt_map(at, from_chart, to_chart, usys=usys) + # Basis Change: coord to physical + v_coord: CDict = cxrapi.change_basis( # ty: ignore[invalid-assignment] + v_coord_to, to_chart, coord_basis, basis, at=at_to, usys=usys + ) + return v_coord + + +@plum.dispatch +def tangent_map( + v: Any, + from_chart: cxc.AbstractChart, + from_rep: Representation, + to_chart: cxc.AbstractChart, + /, + *, + at: CDict | None = None, + usys: OptUSys = None, +) -> CDict: + r"""Push a tangent vector forward from one chart to another. + + Applies the Jacobian of the chart transition map to the tangent vector + components ``v``, evaluated at the base point ``at``. + + Examples + -------- + Convert a tangent vector from Cartesian to polar 2D at the point (1, 0): + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.representations as cxr + + >>> v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + >>> at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + >>> cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=at) + {'r': Array(1., dtype=float64), 'theta': Array(0., dtype=float64)} + + """ + _check_tangent_geom(from_rep.geom_kind, "from_rep") + + return cxrapi.tangent_map(v, from_chart, from_rep.basis, to_chart, at=at, usys=usys) # ty: ignore[invalid-return-type] + + +# --------------------------------------------------------------------------- +# Cross-Representation dispatch +# --------------------------------------------------------------------------- + + +@plum.dispatch +def tangent_map( + v: Any, + from_chart: cxc.AbstractChart, + from_rep: Representation, + to_chart: cxc.AbstractChart, + to_rep: Representation, + /, + *, + at: CDict | None = None, + usys: OptUSys = None, +) -> CDict: + r"""Push a tangent vector forward from one chart to another. + + Applies the Jacobian of the chart transition map to the tangent vector + components ``v``, evaluated at the base point ``at``. + + Examples + -------- + Convert a tangent vector from Cartesian to polar 2D at the point (1, 0): + + >>> import jax.numpy as jnp + >>> import coordinax.charts as cxc + >>> import coordinax.representations as cxr + + >>> v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + >>> at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + >>> cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, + ... cxc.polar2d, cxr.coord_disp, at=at) + {'r': Array(1., dtype=float64), 'theta': Array(0., dtype=float64)} + + """ + v = cxrapi.tangent_map(v, from_chart, from_rep, to_chart, at=at, usys=usys) + v = cxrapi.change_basis(v, to_chart, from_rep.basis, to_rep.basis, at=at, usys=usys) + return v # noqa: RET504 # ty: ignore[invalid-return-type] diff --git a/src/coordinax/transforms/_src/actions/rotate.py b/src/coordinax/transforms/_src/actions/rotate.py index 7fc73d62..d6e0b3ae 100644 --- a/src/coordinax/transforms/_src/actions/rotate.py +++ b/src/coordinax/transforms/_src/actions/rotate.py @@ -633,6 +633,70 @@ def act( return cast("CDict", out) +@plum.dispatch +def act( + op: Rotate, + tau: Any, + x: CDict, + chart: cxc.AbstractChart, + geom: cxr.TangentGeometry, + rep: cxr.Representation, + /, + *, + usys: OptUSys = None, + **kw: Any, +) -> CDict: + """Apply a spatial rotation to a TangentGeometry coordinate dictionary. + + Rotation is linear, so tangent vectors (for example velocity or + acceleration) transform with the same rotation matrix as point + coordinates. + + The implementation: + + 1. Converts ``x`` to the chart canonical Cartesian chart. + 2. Packs Cartesian components to a common unit. + 3. Applies ``R`` via ``einsum`` in a batch-safe way. + 4. Converts the rotated result back to the original chart. + + Examples + -------- + Rotate a Cartesian velocity vector by +90 degrees about ``z``: + + >>> import quaxed.numpy as jnp + >>> import unxt as u + >>> import coordinax.charts as cxc + >>> import coordinax.representations as cxr + >>> import coordinax.transforms as cxfm + + >>> op = cxfm.Rotate.from_euler("z", u.Q(90, "deg")) + >>> x = {"x": u.Q(1, "m/s"), "y": u.Q(0, "m/s"), "z": u.Q(0, "m/s")} + >>> out = cxfm.act(op, None, x, cxc.cart3d, cxr.tangent_geom, cxr.coord_vel) + >>> jnp.stack([out[c].to_value("m/s") for c in ("x", "y", "z")]).round(3) + Array([0., 1., 0.], dtype=float64) + + """ + del geom, rep, kw # Rotation acts identically on tangent vectors. + + cart = chart.cartesian + comps_cart = cart.components + + op_eval = materialize_transform(op, tau) + R = op_eval._get_R(cart) + + # Convert to canonical Cartesian chart. + p_cart = cxc.pt_map(x, chart, cart, usys=usys) + + # Pack -> rotate -> unpack (batch-safe) + v, unit = pack_uniform_unit(p_cart, keys=comps_cart) # ty: ignore[no-matching-overload] + v_rot = jnp.einsum("ij,...j->...i", R, v) # (..., n) + p_cart_rot = cxc.cdict(v_rot, unit, comps_cart) + + # Convert back to original chart. + out = cxc.pt_map(p_cart_rot, cart, chart, usys=usys) + return cast("CDict", out) + + # ----------------------------------------------- # On CDict with Cartesian-product charts diff --git a/src/coordinax/transforms/_src/actions/translate.py b/src/coordinax/transforms/_src/actions/translate.py index 56a79e67..17dcc53b 100644 --- a/src/coordinax/transforms/_src/actions/translate.py +++ b/src/coordinax/transforms/_src/actions/translate.py @@ -89,6 +89,11 @@ class Translate(AbstractAdd): """ + semantic_kind: cxr.AbstractTangentSemanticKind = eqx.field( + static=True, default=cxr.dpl + ) + """Semantic kind of tangent data this operator acts on. Default: Displacement.""" + # delta, chart, and right_add inherited from AbstractAdd @classmethod def groups(cls) -> frozenset[type]: @@ -96,9 +101,6 @@ def groups(cls) -> frozenset[type]: del cls return frozenset((groups.EuclideanGroup, groups.DiffeomorphismGroup)) - semantic_kind: None = eqx.field(static=True, default=None) - """Semantic kind of tangent data this operator acts on. Default: Displacement.""" - def __add__(self, other: object, /) -> Union["Translate", Composed]: """Combine two Translate operators with matching semantic kinds. diff --git a/tests/unit/representations/test_cconvert_tangent.py b/tests/unit/representations/test_cconvert_tangent.py index 1dcb3231..324972d1 100644 --- a/tests/unit/representations/test_cconvert_tangent.py +++ b/tests/unit/representations/test_cconvert_tangent.py @@ -1,122 +1,157 @@ -"""Tests for tangent `cconvert` dispatch behavior.""" +"""Tests for cconvert dispatching to tangent_map when source is TangentGeometry.""" -from typing import cast +__all__: tuple[str, ...] = () import jax import jax.numpy as jnp import numpy as np import pytest +import unxt as u + import coordinax.charts as cxc import coordinax.representations as cxr +usys = u.unitsystems.si -class TestCConvertTangentDispatch: - """Dispatch behavior for tangent-geometry `cconvert` conversions.""" - - def test_same_chart_dispatches_to_change_basis(self): - v = { - "r": jnp.array(5.0), - "theta": jnp.array(1.0), - "phi": jnp.array(2.0), - } - at = { - "r": jnp.array(3.0), - "theta": jnp.array(0.5), - "phi": jnp.array(0.0), - } - - out = cast( - "dict[str, object]", - cxr.cconvert(v, cxc.sph3d, cxr.coord_disp, cxc.sph3d, cxr.phys_disp, at=at), - ) - expected = cast( - "dict[str, object]", - cxr.change_basis(v, cxc.sph3d, cxr.coord_basis, cxr.phys_basis, at=at), - ) - np.testing.assert_allclose(np.asarray(out["r"]), np.asarray(expected["r"])) - np.testing.assert_allclose( - np.asarray(out["theta"]), np.asarray(expected["theta"]) - ) - np.testing.assert_allclose(np.asarray(out["phi"]), np.asarray(expected["phi"])) +class TestCconvertTangentGeometry: + """cconvert with TangentGeometry representation dispatches to tangent_map.""" - def test_same_chart_cartesian_without_at(self): - v = {"x": jnp.array(1.0), "y": jnp.array(2.0)} + def test_same_chart_noncartesian_matches_change_basis(self) -> None: + """Same-chart tangent conversion should reduce to basis conversion.""" + v = {"r": jnp.array(5.0), "theta": jnp.array(1.0), "phi": jnp.array(2.0)} + at = {"r": jnp.array(3.0), "theta": jnp.array(0.5), "phi": jnp.array(0.0)} - out = cast( - "dict[str, object]", - cxr.cconvert(v, cxc.cart2d, cxr.coord_disp, cxc.cart2d, cxr.phys_disp), + result = cxr.cconvert( + v, cxc.sph3d, cxr.coord_disp, cxc.sph3d, cxr.phys_disp, at=at, usys=usys + ) + expected = cxr.change_basis( + v, cxc.sph3d, cxr.coord_basis, cxr.phys_basis, at=at, usys=usys ) - np.testing.assert_allclose(np.asarray(out["x"]), np.asarray(v["x"])) - np.testing.assert_allclose(np.asarray(out["y"]), np.asarray(v["y"])) - - def test_different_chart_not_implemented(self): - v = { - "r": jnp.array(5.0), - "theta": jnp.array(1.0), - "phi": jnp.array(2.0), - } - at = { - "r": jnp.array(3.0), - "theta": jnp.array(0.5), - "phi": jnp.array(0.0), - } - - with pytest.raises(NotImplementedError, match="different charts"): - cxr.cconvert( - v, cxc.sph3d, cxr.coord_disp, cxc.cart3d, cxr.coord_disp, at=at - ) - - def test_same_chart_non_cartesian_missing_at_raises(self): - v = { - "r": jnp.array(5.0), - "theta": jnp.array(1.0), - "phi": jnp.array(2.0), - } + np.testing.assert_allclose(result["r"], expected["r"]) + np.testing.assert_allclose(result["theta"], expected["theta"]) + np.testing.assert_allclose(result["phi"], expected["phi"]) - with pytest.raises((TypeError, ValueError)): - cxr.cconvert(v, cxc.sph3d, cxr.coord_disp, cxc.sph3d, cxr.phys_disp) - - def test_same_chart_respects_tangent_semantic_kind(self): - v = { - "r": jnp.array(5.0), - "theta": jnp.array(1.0), - "phi": jnp.array(2.0), - } - at = { - "r": jnp.array(3.0), - "theta": jnp.array(0.5), - "phi": jnp.array(0.0), - } - - out_disp = cast( - "dict[str, object]", - cxr.cconvert(v, cxc.sph3d, cxr.coord_disp, cxc.sph3d, cxr.phys_disp, at=at), + def test_cart2d_to_polar2d_coord_disp(self) -> None: + """Cconvert with coord_disp routes through tangent_map (Jacobian).""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + result = cxr.cconvert( + v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, cxr.coord_disp, at=at, usys=usys ) - out_vel = cast( - "dict[str, object]", - cxr.cconvert(v, cxc.sph3d, cxr.coord_vel, cxc.sph3d, cxr.phys_vel, at=at), + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], 0.0, atol=1e-6) + + def test_same_chart_identity(self) -> None: + """Cconvert with same chart + TangentGeometry returns input unchanged.""" + v = {"x": jnp.array(2.0), "y": jnp.array(3.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + result = cxr.cconvert( + v, cxc.cart2d, cxr.coord_disp, cxc.cart2d, cxr.coord_disp, at=at, usys=usys ) + np.testing.assert_allclose(result["x"], 2.0) + np.testing.assert_allclose(result["y"], 3.0) - np.testing.assert_allclose(np.asarray(out_disp["r"]), np.asarray(out_vel["r"])) - np.testing.assert_allclose( - np.asarray(out_disp["theta"]), np.asarray(out_vel["theta"]) + def test_same_chart_cartesian_without_at(self) -> None: + """Cartesian same-chart basis conversion should not require `at`.""" + v = {"x": jnp.array(1.0), "y": jnp.array(2.0)} + result = cxr.cconvert( + v, cxc.cart2d, cxr.coord_disp, cxc.cart2d, cxr.phys_disp, usys=usys + ) + np.testing.assert_allclose(result["x"], v["x"]) + np.testing.assert_allclose(result["y"], v["y"]) + + def test_cart3d_to_sph3d_coord_vel(self) -> None: + """Cconvert with coord_vel representation uses tangent_map semantics.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + result = cxr.cconvert( + v, cxc.cart3d, cxr.coord_vel, cxc.sph3d, cxr.coord_vel, at=at, usys=usys + ) + # Purely radial result + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], 0.0, atol=1e-6) + + def test_same_chart_respects_tangent_semantic_kind(self) -> None: + """Displacement and velocity variants should follow the same basis map.""" + v = {"r": jnp.array(5.0), "theta": jnp.array(1.0), "phi": jnp.array(2.0)} + at = {"r": jnp.array(3.0), "theta": jnp.array(0.5), "phi": jnp.array(0.0)} + + out_disp = cxr.cconvert( + v, cxc.sph3d, cxr.coord_disp, cxc.sph3d, cxr.phys_disp, at=at, usys=usys ) - np.testing.assert_allclose( - np.asarray(out_disp["phi"]), np.asarray(out_vel["phi"]) + out_vel = cxr.cconvert( + v, cxc.sph3d, cxr.coord_vel, cxc.sph3d, cxr.phys_vel, at=at, usys=usys ) - def test_jit_same_chart_tangent(self): - v = {"x": jnp.array(1.0), "y": jnp.array(2.0)} + np.testing.assert_allclose(out_disp["r"], out_vel["r"]) + np.testing.assert_allclose(out_disp["theta"], out_vel["theta"]) + np.testing.assert_allclose(out_disp["phi"], out_vel["phi"]) + + def test_jit_compatible(self) -> None: + """Cconvert with TangentGeometry is JIT-compatible.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} @jax.jit - def run(data): + def run(v, at): return cxr.cconvert( - data, cxc.cart2d, cxr.coord_disp, cxc.cart2d, cxr.phys_disp + v, + cxc.cart2d, + cxr.coord_disp, + cxc.polar2d, + cxr.coord_disp, + at=at, + usys=usys, ) - out = cast("dict[str, object]", run(v)) - np.testing.assert_allclose(np.asarray(out["x"]), np.asarray(v["x"])) - np.testing.assert_allclose(np.asarray(out["y"]), np.asarray(v["y"])) + result = run(v, at) + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + + def test_round_trip(self) -> None: + """Cconvert tangent round trip: cart2d → polar2d → cart2d is identity.""" + v_cart = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at_cart = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + + # cart → polar + v_polar = cxr.cconvert( + v_cart, + cxc.cart2d, + cxr.coord_disp, + cxc.polar2d, + cxr.coord_disp, + at=at_cart, + usys=usys, + ) + + # at in polar coords + at_polar = cxr.cconvert(at_cart, cxc.cart2d, cxr.point, cxc.polar2d, usys=usys) + + # polar → cart + v_cart_back = cxr.cconvert( + v_polar, + cxc.polar2d, + cxr.coord_disp, + cxc.cart2d, + cxr.coord_disp, + at=at_polar, + usys=usys, + ) + + np.testing.assert_allclose(v_cart_back["x"], v_cart["x"], atol=1e-6) + np.testing.assert_allclose(v_cart_back["y"], v_cart["y"], atol=1e-6) + + +class TestCconvertAtRequired: + """cconvert with TangentGeometry requires the `at` keyword argument.""" + + def test_at_required_for_nonlinear_charts(self) -> None: + """Missing `at` raises informative error for non-Cartesian charts.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + with pytest.raises((TypeError, ValueError)): + cxr.cconvert( + v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, cxr.coord_disp, usys=usys + ) diff --git a/tests/unit/representations/test_change_basis.py b/tests/unit/representations/test_change_basis.py index b24eef82..42915e5c 100644 --- a/tests/unit/representations/test_change_basis.py +++ b/tests/unit/representations/test_change_basis.py @@ -27,6 +27,9 @@ def tree_equal(lhs: Any, rhs: Any) -> bool: return bool(jnp.all(jnp.stack(reduced))) +usys = u.unitsystems.si + + class TestChangeBasisExistence: """Import-surface checks for the public API.""" diff --git a/tests/unit/representations/test_tangent_map.py b/tests/unit/representations/test_tangent_map.py new file mode 100644 index 00000000..45589d06 --- /dev/null +++ b/tests/unit/representations/test_tangent_map.py @@ -0,0 +1,305 @@ +"""Tests for tangent_map function. + +Tests that tangent_map correctly transforms tangent vectors between charts using +the Jacobian pushforward (CoordinateBasis) or frame matrix rotation (PhysicalBasis). +""" + +__all__: tuple[str, ...] = () + +from typing import Any, cast + +import jax +import jax.numpy as jnp +import numpy as np + +import unxt as u + +import coordinax.charts as cxc +import coordinax.main as cx +import coordinax.representations as cxr + +usys = u.unitsystems.si + + +class TestTangentMapExistence: + """tangent_map is importable and callable.""" + + def test_importable_from_representations(self) -> None: + """tangent_map is in coordinax.representations.""" + assert hasattr(cxr, "tangent_map") + assert callable(cxr.tangent_map) + + def test_importable_from_main(self) -> None: + """tangent_map is in coordinax.main.""" + assert hasattr(cx, "tangent_map") + + +class TestTangentMapSameChart: + """Same-chart optimisation: tangent_map returns v unchanged.""" + + def test_cart3d_to_cart3d(self) -> None: + """Cart3D → Cart3D with CoordinateBasis returns input unchanged.""" + v = {"x": jnp.array(1.0), "y": jnp.array(2.0), "z": jnp.array(3.0)} + at = {"x": jnp.array(0.5), "y": jnp.array(0.5), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_disp, cxc.cart3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["x"], v["x"]) + np.testing.assert_allclose(result["y"], v["y"]) + np.testing.assert_allclose(result["z"], v["z"]) + + def test_same_chart_identity(self) -> None: + """Any same chart passes through unchanged.""" + v = {"r": jnp.array(1.0), "theta": jnp.array(0.0), "phi": jnp.array(0.0)} + at = {"r": jnp.array(2.0), "theta": jnp.array(0.5), "phi": jnp.array(0.5)} + result = cxr.tangent_map(v, cxc.sph3d, cxr.coord_disp, cxc.sph3d, at=at) + np.testing.assert_allclose(result["r"], v["r"]) + np.testing.assert_allclose(result["theta"], v["theta"]) + np.testing.assert_allclose(result["phi"], v["phi"]) + + +class TestTangentMapCart3dToSph3d: + """Cart3D → Sph3D CoordinateBasis: Jacobian pushforward at (x=1, y=0, z=0). + + Uses physics spherical conventions: theta=polar, phi=azimuthal. + At (x=1, y=0, z=0), the base point is (r=1, theta=pi/2, phi=0). + + The Jacobian J = d(r,theta,phi)/d(x,y,z) at this point is:: + + J = [[1, 0, 0], + [0, 0, -1], + [0, 1, 0]] + + Resulting pushforwards: + - (dx,dy,dz)=(1,0,0) → (dr,dtheta,dphi)=(1,0,0) [radial] + - (dx,dy,dz)=(0,1,0) → (dr,dtheta,dphi)=(0,0,1) [phi direction] + - (dx,dy,dz)=(0,0,1) → (dr,dtheta,dphi)=(0,-1,0) [minus theta direction] + """ + + def test_radial_vector(self) -> None: + """Purely x-direction at (1,0,0) maps to purely radial (dr=1).""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_disp, cxc.sph3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], 0.0, atol=1e-6) + + def test_phi_direction_vector(self) -> None: + """y-direction at (1,0,0) maps to phi direction (dphi=1).""" + v = {"x": jnp.array(0.0), "y": jnp.array(1.0), "z": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_disp, cxc.sph3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["r"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], 1.0, atol=1e-6) + + def test_z_direction_vector(self) -> None: + """z-direction at (1,0,0) maps to -theta direction (dtheta=-1).""" + v = {"x": jnp.array(0.0), "y": jnp.array(0.0), "z": jnp.array(1.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_disp, cxc.sph3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["r"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], -1.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], 0.0, atol=1e-6) + + +class TestTangentMapPhysicalBasis: + """PhysicalBasis transformations are supported for tangent vectors.""" + + def test_cart3d_to_sph3d_radial_direction(self) -> None: + """Cartesian x-direction maps to spherical radial in physical basis.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + + result = cxr.tangent_map( + v, cxc.cart3d, cxr.phys_basis, cxc.sph3d, at=at, usys=usys + ) + + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], 0.0, atol=1e-6) + + def test_representation_dispatch_with_phys_disp(self) -> None: + """Representation overload supports physical-basis representations.""" + v = {"x": jnp.array(0.0), "y": jnp.array(0.0), "z": jnp.array(1.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + + result = cxr.tangent_map( + v, cxc.cart3d, cxr.phys_disp, cxc.sph3d, at=at, usys=usys + ) + + np.testing.assert_allclose(result["r"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], -1.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], 0.0, atol=1e-6) + + def test_same_chart_identity_phys_basis(self) -> None: + """Same-chart optimisation also holds for PhysicalBasis inputs.""" + v = { + "r": jnp.array(2.0), + "theta": jnp.array(-1.5), + "phi": jnp.array(0.25), + } + at = {"r": jnp.array(3.0), "theta": jnp.array(0.5), "phi": jnp.array(0.1)} + + result = cxr.tangent_map(v, cxc.sph3d, cxr.phys_basis, cxc.sph3d, at=at) + + np.testing.assert_allclose(result["r"], v["r"]) + np.testing.assert_allclose(result["theta"], v["theta"]) + np.testing.assert_allclose(result["phi"], v["phi"]) + + def test_roundtrip_cart3d_sph3d_phys_disp(self) -> None: + """Physical-basis round-trip Cart3D→Sph3D→Cart3D preserves components.""" + v_cart = { + "x": u.Q(jnp.array(2.0), "m/s"), + "y": u.Q(jnp.array(-1.0), "m/s"), + "z": u.Q(jnp.array(0.5), "m/s"), + } + at_cart = { + "x": u.Q(jnp.array(2.0), "m"), + "y": u.Q(jnp.array(1.0), "m"), + "z": u.Q(jnp.array(0.5), "m"), + } + + v_sph = cxr.tangent_map( + v_cart, cxc.cart3d, cxr.phys_disp, cxc.sph3d, at=at_cart, usys=usys + ) + at_sph = cxc.pt_map(at_cart, cxc.cart3d, cxc.sph3d, usys=usys) + v_back = cxr.tangent_map( + v_sph, cxc.sph3d, cxr.phys_disp, cxc.cart3d, at=at_sph, usys=usys + ) + + np.testing.assert_allclose(v_back["x"].value, v_cart["x"].value, atol=1e-6) + np.testing.assert_allclose(v_back["y"].value, v_cart["y"].value, atol=1e-6) + np.testing.assert_allclose(v_back["z"].value, v_cart["z"].value, atol=1e-6) + + def test_cconvert_7arg_with_phys_disp(self) -> None: + """7-arg cconvert path works for physical-basis tangent representations.""" + v = { + "x": u.Q(jnp.array(1.0), "m/s"), + "y": u.Q(jnp.array(0.0), "m/s"), + "z": u.Q(jnp.array(0.0), "m/s"), + } + at = { + "x": u.Q(jnp.array(1.0), "m"), + "y": u.Q(jnp.array(0.0), "m"), + "z": u.Q(jnp.array(0.0), "m"), + } + + direct = cxr.tangent_map(v, cxc.cart3d, cxr.phys_disp, cxc.sph3d, at=at) + via_cc = cxr.cconvert( + v, + cxc.cart3d, + cxr.tangent_geom, + cxr.phys_disp, + cxc.sph3d, + cxr.tangent_geom, + cxr.phys_disp, + at=at, + ) + + via_cc_cdict = cast("dict[str, Any]", via_cc) + np.testing.assert_allclose( + via_cc_cdict["r"].value, direct["r"].value, atol=1e-6 + ) + np.testing.assert_allclose( + via_cc_cdict["theta"].value, direct["theta"].value, atol=1e-6 + ) + np.testing.assert_allclose( + via_cc_cdict["phi"].value, direct["phi"].value, atol=1e-6 + ) + + +class TestTangentMapCart2dToPolar2d: + """Cart2D → Polar2D CoordinateBasis: Jacobian pushforward at (x=1, y=0). + + At (x=1, y=0), the base point is (r=1, theta=0). + + The Jacobian J = d(r,theta)/d(x,y) at this point is:: + + J = [[1, 0], + [0, 1]] + + Resulting pushforwards: + - (dx,dy)=(1,0) → (dr,dtheta)=(1,0) [radial] + - (dx,dy)=(0,1) → (dr,dtheta)=(0,1) [angular] + """ + + def test_x_direction(self) -> None: + """x-direction at (1,0) maps to radial (dr=1).""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + result = cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=at) + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], 0.0, atol=1e-6) + + def test_y_direction(self) -> None: + """y-direction at (1,0) maps to angular (dtheta=1).""" + v = {"x": jnp.array(0.0), "y": jnp.array(1.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + result = cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=at) + np.testing.assert_allclose(result["r"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], 1.0, atol=1e-6) + + def test_at_45_deg(self) -> None: + """At (x=1,y=1): x-hat component check (dr/dx = x/r = 1/sqrt2).""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(1.0)} + result = cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=at) + np.testing.assert_allclose(result["r"], 1.0 / jnp.sqrt(2.0), atol=1e-6) + np.testing.assert_allclose(result["theta"], -1.0 / 2.0, atol=1e-6) + + +class TestTangentMapJAXCompatibility: + """tangent_map is compatible with jax.jit and jax.vmap.""" + + def test_jit(self) -> None: + """tangent_map can be JIT-compiled.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + + @jax.jit + def jitted_tangent_map(v, at): + return cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=at) + + result = jitted_tangent_map(v, at) + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + + def test_vmap(self) -> None: + """tangent_map can be vmap-ped over a batch of base points.""" + vs = {"x": jnp.ones(3), "y": jnp.zeros(3)} + ats = { + "x": jnp.array([1.0, 2.0, 3.0]), + "y": jnp.zeros(3), + } + + def single_map(v: dict[str, Any], at: dict[str, Any]) -> dict[str, Any]: + return cxr.tangent_map(v, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=at) + + batched = jax.vmap(single_map)(vs, ats) + # At y=0, any x>0: dr/dx = x/r = 1, so dr = 1 always + np.testing.assert_allclose(batched["r"], 1.0, atol=1e-6) + + +class TestTangentMapSemanticPreservation: + """tangent_map works with vel and acc representations too.""" + + def test_vel_rep(self) -> None: + """tangent_map works with coord_vel representation.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + result = cxr.tangent_map(v, cxc.cart2d, cxr.coord_vel, cxc.polar2d, at=at) + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + + def test_acc_rep(self) -> None: + """tangent_map works with coord_acc representation.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + result = cxr.tangent_map(v, cxc.cart2d, cxr.coord_acc, cxc.polar2d, at=at) + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) diff --git a/tests/unit/representations/test_tangent_map_properties.py b/tests/unit/representations/test_tangent_map_properties.py new file mode 100644 index 00000000..6bc5454a --- /dev/null +++ b/tests/unit/representations/test_tangent_map_properties.py @@ -0,0 +1,453 @@ +"""Tests for ``tangent_map``.""" + +__all__: tuple[str, ...] = () + +import jax.numpy as jnp +import numpy as np +from hypothesis import given, settings, strategies as st +from strategies import ( + any_angle_rad as _any_angle_rad, + any_m as _any_m, + polar_rad as _polar_rad, + pos_m as _pos_m, + v_elem as _v_elem, +) + +import unxt as u + +import coordinax.charts as cxc +import coordinax.representations as cxr + +usys = u.unitsystems.si + + +def _assert_cdict_close(got, ref, *, atol=1e-5, rtol=1e-5): + """Assert two CDicts agree component-wise (strips units if present).""" + for key in ref: + g = got[key] + r = ref[key] + # Strip Quantity wrappers to plain arrays for comparison + g_val = g.value if hasattr(g, "value") else jnp.asarray(g) + r_val = r.value if hasattr(r, "value") else jnp.asarray(r) + np.testing.assert_allclose( + np.asarray(g_val), + np.asarray(r_val), + atol=atol, + rtol=rtol, + err_msg=f"component '{key}' differs", + ) + + +# =========================================================================== +# 1. Round-trip: Cart2D ↔ Polar2D +# =========================================================================== + + +class TestTangentMapRoundTripCart2dPolar2d: + r"""Round-trip invariant: Cart2D → Polar2D → Cart2D ≈ identity. + + For any non-zero base point p and tangent vector v: + + J_{polar→cart}(p_polar) @ J_{cart→polar}(p_cart) @ v ≈ v + + Here we check this end-to-end via two successive ``tangent_map`` calls. + """ + + @given( + r_=_pos_m, + theta=_any_angle_rad, + vx=_v_elem, + vy=_v_elem, + ) + @settings(deadline=None) + def test_round_trip(self, r_, theta, vx, vy) -> None: + """Tangent map round-trip Cart2D → Polar2D → Cart2D recovers original v.""" + # Base point starting in polar (guaranteed r > 0) + p_polar = {"r": r_, "theta": theta} + p_cart = cxc.pt_map(p_polar, cxc.polar2d, cxc.cart2d) + + v_cart = {"x": jnp.array(vx), "y": jnp.array(vy)} + + v_polar = cxr.tangent_map( + v_cart, cxc.cart2d, cxr.coord_disp, cxc.polar2d, at=p_cart + ) + v_back = cxr.tangent_map( + v_polar, cxc.polar2d, cxr.coord_disp, cxc.cart2d, at=p_polar + ) + + np.testing.assert_allclose( + np.asarray(v_back["x"]), + vx, + atol=1e-4, + rtol=1e-4, + ) + np.testing.assert_allclose( + np.asarray(v_back["y"]), + vy, + atol=1e-4, + rtol=1e-4, + ) + + +# =========================================================================== +# 2. Round-trip: Cart3D ↔ Sph3D +# =========================================================================== + + +class TestTangentMapRoundTripCart3dSph3d: + r"""Round-trip invariant: Cart3D → Sph3D → Cart3D ≈ identity.""" + + @given( + r_=_pos_m, + theta=_polar_rad, # away from poles + phi=_any_angle_rad, + vx=_v_elem, + vy=_v_elem, + vz=_v_elem, + ) + @settings(deadline=None) + def test_round_trip(self, r_, theta, phi, vx, vy, vz) -> None: + """Tangent map round-trip Cart3D → Sph3D → Cart3D recovers v.""" + p_sph = {"r": r_, "theta": theta, "phi": phi} + p_cart = cxc.pt_map(p_sph, cxc.sph3d, cxc.cart3d) + + v_cart = {"x": jnp.array(vx), "y": jnp.array(vy), "z": jnp.array(vz)} + + v_sph = cxr.tangent_map( + v_cart, cxc.cart3d, cxr.coord_disp, cxc.sph3d, at=p_cart + ) + v_back = cxr.tangent_map(v_sph, cxc.sph3d, cxr.coord_disp, cxc.cart3d, at=p_sph) + + np.testing.assert_allclose(np.asarray(v_back["x"]), vx, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(np.asarray(v_back["y"]), vy, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(np.asarray(v_back["z"]), vz, atol=1e-4, rtol=1e-4) + + +# =========================================================================== +# 3. Round-trip: Cart3D ↔ Cyl3D +# =========================================================================== + + +class TestTangentMapRoundTripCart3dCyl3d: + r"""Round-trip invariant: Cart3D → Cyl3D → Cart3D ≈ identity.""" + + @given( + rho=_pos_m, # cylindrical radius — must be positive + phi=_any_angle_rad, + z=_any_m, + vx=_v_elem, + vy=_v_elem, + vz=_v_elem, + ) + @settings(deadline=None) + def test_round_trip(self, rho, phi, z, vx, vy, vz) -> None: + """Tangent map round-trip Cart3D → Cyl3D → Cart3D recovers v.""" + p_cyl = {"rho": rho, "phi": phi, "z": z} + p_cart = cxc.pt_map(p_cyl, cxc.cyl3d, cxc.cart3d) + + v_cart = {"x": jnp.array(vx), "y": jnp.array(vy), "z": jnp.array(vz)} + + v_cyl = cxr.tangent_map( + v_cart, cxc.cart3d, cxr.coord_disp, cxc.cyl3d, at=p_cart + ) + v_back = cxr.tangent_map(v_cyl, cxc.cyl3d, cxr.coord_disp, cxc.cart3d, at=p_cyl) + + np.testing.assert_allclose(np.asarray(v_back["x"]), vx, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(np.asarray(v_back["y"]), vy, atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(np.asarray(v_back["z"]), vz, atol=1e-4, rtol=1e-4) + + +# =========================================================================== +# 4. Known Cyl3D examples (hand-verified) +# =========================================================================== + + +class TestTangentMapCyl3dKnownExamples: + r"""Specific known-value tests for Cyl3D that aren't in the existing suite. + + Analytical pushforward review: + + At (x=1, y=0, z=0) → (ρ=1, φ=0, z=0), J_{cart→cyl} = I, so: + v_cart → v_cyl component-wise identical. + + At (x=0, y=1, z=0) → (ρ=1, φ=π/2, z=0): + J = [[0, 1, 0], rows: (ρ, φ, z) + [-1, 0, 0], cols: (x, y, z) + [0, 0, 1]] + v_cart = (1, 0, 0) → v_cyl = (0, -1, 0) (azimuthal component flips sign) + v_cart = (0, 1, 0) → v_cyl = (1, 0, 0) (becomes radial) + v_cart = (0, 0, 1) → v_cyl = (0, 0, 1) (z is unchanged) + """ + + def test_at_x1_y0_z0_identity(self) -> None: + """At (1,0,0) the Jacobian is identity: every component is preserved.""" + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + for vx, vy, vz in [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]: + v = {"x": jnp.array(vx), "y": jnp.array(vy), "z": jnp.array(vz)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_disp, cxc.cyl3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["rho"], vx, atol=1e-6) + np.testing.assert_allclose(result["phi"], vy, atol=1e-6) + np.testing.assert_allclose(result["z"], vz, atol=1e-6) + + def test_x_hat_at_0_1_0_maps_to_minus_phi_hat(self) -> None: + """At (0,1,0): x̂ → (ρ=0, φ=-1, z=0) — negative azimuthal direction.""" + at = {"x": jnp.array(0.0), "y": jnp.array(1.0), "z": jnp.array(0.0)} + v = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_disp, cxc.cyl3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["rho"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], -1.0, atol=1e-6) + np.testing.assert_allclose(result["z"], 0.0, atol=1e-6) + + def test_y_hat_at_0_1_0_maps_to_rho_hat(self) -> None: + """At (0,1,0): ŷ → (ρ=1, φ=0, z=0) — becomes radial.""" + at = {"x": jnp.array(0.0), "y": jnp.array(1.0), "z": jnp.array(0.0)} + v = {"x": jnp.array(0.0), "y": jnp.array(1.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_disp, cxc.cyl3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["rho"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["z"], 0.0, atol=1e-6) + + def test_cyl_to_cart_rho_hat_at_phi0(self) -> None: + """At (ρ=1, φ=0, z=0): ρ̂ → (x=1, y=0, z=0) — becomes x̂.""" + at = {"rho": jnp.array(1.0), "phi": jnp.array(0.0), "z": jnp.array(0.0)} + v = {"rho": jnp.array(1.0), "phi": jnp.array(0.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cyl3d, cxr.coord_disp, cxc.cart3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["x"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["y"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["z"], 0.0, atol=1e-6) + + def test_cyl_phi_hat_at_phi0_maps_to_y_hat(self) -> None: + """At (R=1, φ=0, z=0): φ̂ in Cartesian is ŷ. + + ∂x/∂φ = -R sinφ = 0, ∂y/∂φ = R cosφ = 1, ∂z/∂φ = 0. + So φ̂ (i.e. v_phi=1, others=0) → (x=0, y=1, z=0). + """ + at = {"rho": jnp.array(1.0), "phi": jnp.array(0.0), "z": jnp.array(0.0)} + v = {"rho": jnp.array(0.0), "phi": jnp.array(1.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cyl3d, cxr.coord_disp, cxc.cart3d, at=at, usys=usys + ) + np.testing.assert_allclose(result["x"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["y"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["z"], 0.0, atol=1e-6) + + +# =========================================================================== +# 5. Linearity property +# =========================================================================== + + +class TestTangentMapLinearity: + r"""Linearity. + + tangent_map(a·u + b·w, ...) ≈ a·tangent_map(u,...) + b·tangent_map(w,...). + + The Jacobian is a linear map, so this is a sanity-check that the + implementation doesn't introduce nonlinear artefacts. + """ + + @given( + r_=_pos_m, + theta=_polar_rad, + phi=_any_angle_rad, + ux=_v_elem, + uy=_v_elem, + uz=_v_elem, + wx=_v_elem, + wy=_v_elem, + wz=_v_elem, + a=st.floats(min_value=-3.0, max_value=3.0, allow_nan=False, width=32), + b=st.floats(min_value=-3.0, max_value=3.0, allow_nan=False, width=32), + ) + @settings(deadline=None) + def test_linearity_cart3d_to_sph3d( + self, r_, theta, phi, ux, uy, uz, wx, wy, wz, a, b + ) -> None: + """J(a·u + b·w) ≈ a·J(u) + b·J(w) for Cart3D → Sph3D.""" + p_sph = {"r": r_, "theta": theta, "phi": phi} + p_cart = cxc.pt_map(p_sph, cxc.sph3d, cxc.cart3d) + + u_ = {"x": jnp.array(ux), "y": jnp.array(uy), "z": jnp.array(uz)} + w_ = {"x": jnp.array(wx), "y": jnp.array(wy), "z": jnp.array(wz)} + comb = { + "x": jnp.array(a * ux + b * wx), + "y": jnp.array(a * uy + b * wy), + "z": jnp.array(a * uz + b * wz), + } + + J_u = cxr.tangent_map(u_, cxc.cart3d, cxr.coord_disp, cxc.sph3d, at=p_cart) + J_w = cxr.tangent_map(w_, cxc.cart3d, cxr.coord_disp, cxc.sph3d, at=p_cart) + J_comb = cxr.tangent_map(comb, cxc.cart3d, cxr.coord_disp, cxc.sph3d, at=p_cart) + + for key in ("r", "theta", "phi"): + expected = float(a * J_u[key] + b * J_w[key]) + np.testing.assert_allclose( + float(J_comb[key]), + expected, + atol=1e-4, + rtol=1e-4, + err_msg=f"linearity failed for component '{key}'", + ) + + +# =========================================================================== +# 6. Unit-tracking: at with Quantity values +# =========================================================================== + + +class TestTangentMapWithQuantityAt: + r"""When ``at`` contains Quantity values, coordinate units are tracked correctly. + + The Jacobian entries J[j, i] carry units to_dim[j] / from_dim[i]. + When applied to a tangent vector v whose components carry units (e.g. m/s), + the result components carry the correct physical units. + + For Cart3D → Sph3D: + v["x"] in m/s → result["r"] in m/s (dimensionless x m/s) + v["x"] in m/s → result["theta"] in rad/s (rad/m x m/s) + v["x"] in m/s → result["phi"] in rad/s (rad/m x m/s) + """ + + def test_result_r_unit_matches_input_unit(self) -> None: + """J[r, *] is dimensionless, so result['r'] has same unit as v['x'].""" + at = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} + v = {"x": u.Q(1.0, "m/s"), "y": u.Q(0.0, "m/s"), "z": u.Q(0.0, "m/s")} + result = cxr.tangent_map(v, cxc.cart3d, cxr.coord_vel, cxc.sph3d, at=at) + # At (1, 0, 0): only r has non-zero result: r̂ component = 1 m/s + r_result = result["r"] + assert hasattr(r_result, "unit"), "result['r'] should be a Quantity" + assert u.dimension_of(r_result) == u.dimension("speed"), ( + f"result['r'] should have speed dimensions, got {u.dimension_of(r_result)}" + ) + + def test_result_theta_unit_is_angular_velocity(self) -> None: + """J[θ, *] is rad/m, so result['theta'] has rad * (m/s) / m = rad/s.""" + at = {"x": u.Q(1.0, "m"), "y": u.Q(0.0, "m"), "z": u.Q(0.0, "m")} + # Use ŷ input which has non-zero dφ component at (1,0,0) + # At (1,0,0): dθ/dy = 0, dφ/dy = 1 rad/m → result phi = 1 rad/s for vy=1 m/s + v = {"x": u.Q(0.0, "m/s"), "y": u.Q(1.0, "m/s"), "z": u.Q(0.0, "m/s")} + result = cxr.tangent_map(v, cxc.cart3d, cxr.coord_vel, cxc.sph3d, at=at) + phi_result = result["phi"] + assert hasattr(phi_result, "unit"), "result['phi'] should be a Quantity" + assert u.dimension_of(phi_result) == u.dimension("angular frequency"), ( + "result['phi'] should have angular-velocity dimensions" + ) + + +# =========================================================================== +# 7. Integration via cconvert (tests the 7-arg tangent_map path) +# RED: cconvert calls api.tangent_map with 7 positional args but the current +# dispatch only has 4 — this will fail until tangent_map is updated. +# =========================================================================== + + +class TestTangentMapViaCconvert: + """tangent_map must be reachable via cconvert with TangentGeometry. + + ``cconvert`` internally calls + api.tangent_map(v, from_chart, from_geom, from_rep, + to_chart, to_geom, to_rep, at=at) + which is the 7-argument form. These tests verify end-to-end correctness + of that path, which requires the new dispatch signature. + """ + + def test_cart2d_polar2d_via_cconvert(self) -> None: + """Cconvert with TangentGeometry dispatches through to tangent_map.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0)} + result = cxr.cconvert( + v, + cxc.cart2d, + cxr.tangent_geom, + cxr.coord_disp, + cxc.polar2d, + cxr.tangent_geom, + cxr.coord_disp, + at=at, + ) + np.testing.assert_allclose(result["r"], 1.0, atol=1e-6) + np.testing.assert_allclose(result["theta"], 0.0, atol=1e-6) + + def test_cart3d_sph3d_via_cconvert(self) -> None: + """Cconvert for tangent gives same result as tangent_map.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + + direct = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_disp, cxc.sph3d, at=at, usys=usys + ) + via_cc = cxr.cconvert( + v, + cxc.cart3d, + cxr.tangent_geom, + cxr.coord_disp, + cxc.sph3d, + cxr.tangent_geom, + cxr.coord_disp, + at=at, + usys=usys, + ) + + for key in ("r", "theta", "phi"): + np.testing.assert_allclose( + float(via_cc[key]), float(direct[key]), atol=1e-6 + ) + + def test_cart3d_cyl3d_via_cconvert(self) -> None: + """cconvert(cart3d→cyl3d) tangent: x̂ at (0,1,0) → (ρ=0, φ=-1, z=0).""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + at = {"x": jnp.array(0.0), "y": jnp.array(1.0), "z": jnp.array(0.0)} + result = cxr.cconvert( + v, + cxc.cart3d, + cxr.tangent_geom, + cxr.coord_disp, + cxc.cyl3d, + cxr.tangent_geom, + cxr.coord_disp, + at=at, + usys=usys, + ) + np.testing.assert_allclose(result["rho"], 0.0, atol=1e-6) + np.testing.assert_allclose(result["phi"], -1.0, atol=1e-6) + np.testing.assert_allclose(result["z"], 0.0, atol=1e-6) + + +# =========================================================================== +# 8. Semantic preservation (vel / acc representations) +# =========================================================================== + + +class TestTangentMapSemanticPreservationCyl3d: + """Semantic kind (vel, acc) is preserved through Cyl3D transformations. + + The tangent_map result keys should match to_chart.components regardless + of semantic kind. This extends the existing vel/acc tests to the Cyl3D pair. + """ + + def test_coord_vel_cart3d_to_cyl3d(self) -> None: + """coord_vel converts Cart3D → Cyl3D correctly.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_vel, cxc.cyl3d, at=at, usys=usys + ) + assert set(result.keys()) == {"rho", "phi", "z"} + np.testing.assert_allclose(result["rho"], 1.0, atol=1e-6) + + def test_coord_acc_cart3d_to_cyl3d(self) -> None: + """coord_acc converts Cart3D → Cyl3D correctly.""" + v = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + at = {"x": jnp.array(1.0), "y": jnp.array(0.0), "z": jnp.array(0.0)} + result = cxr.tangent_map( + v, cxc.cart3d, cxr.coord_acc, cxc.cyl3d, at=at, usys=usys + ) + assert set(result.keys()) == {"rho", "phi", "z"}