Skip to content

Commit 3f4ca77

Browse files
committed
✨ feat: representations
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 6920650 commit 3f4ca77

11 files changed

Lines changed: 320 additions & 8 deletions

File tree

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ pip install coordinax
4949
components.
5050
- Metrics: a bilinear form _g_ on the tangent space defining inner products and
5151
norms (Euclidean, sphere intrinsic, Minkowski).
52-
- Transformations: convert between coordinate charts and roles, including
53-
velocity and acceleration transformations
52+
- Transformations: convert between coordinate charts and representations,
53+
including velocity and acceleration transformations
5454
- Operators: frame-aware vector operations (rotation, boost, translation, etc.)
5555
- `PointedVector`: combines a vector with a reference point or frame for context
5656
- Frames: reference systems (ICRS, Galactocentric, etc.) and their
@@ -62,7 +62,7 @@ pip install coordinax
6262

6363
```
6464
import coordinax.charts as cxc # Charts
65-
import coordinax.roles as cxr # Roles
65+
import coordinax.representations as cxr # Representations
6666
import coordinax.metrics as cxm # Metric
6767
import coordinax.tangents as cxt # Transformations
6868
```
@@ -126,7 +126,7 @@ frames.
126126

127127
```
128128
import jax.numpy as jnp
129-
import coordinax.roles as cxr
129+
import coordinax.representations as cxr
130130
from coordinax.objs import Vector
131131
132132
q = {"x": u.Q(1.0, "km"), "y": u.Q(2.0, "km"), "z": u.Q(3.0, "km")}
@@ -139,7 +139,7 @@ bool(jnp.allclose(u.ustrip("km/s", v_back.data["x"]),
139139
u.ustrip("km/s", vvec.data["x"])))
140140
```
141141

142-
Different vector roles transform via different mechanisms:
142+
Different vector representations transform via different mechanisms:
143143

144144
| Role | Transformation | Requires Base Point? |
145145
| ---------- | ----------------------------------- | -------------------- |
@@ -253,9 +253,9 @@ disp_from_origin = cx.as_disp(new_pos)
253253
254254
## Metrics and Representations
255255

256-
Representations are coordinate charts; roles (PhysDisp, PhysVel, PhysAcc, ...)
257-
give vectors their physical meaning. A chart’s default metric defines how
258-
physical components are interpreted.
256+
Representations are coordinate charts; representations (PhysDisp, PhysVel,
257+
PhysAcc, ...) give vectors their physical meaning. A chart’s default metric
258+
defines how physical components are interpreted.
259259

260260
## Citation
261261

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Representations."""
2+
3+
__all__ = ("vconvert",)
4+
5+
from typing import Any
6+
7+
import plum
8+
9+
10+
@plum.dispatch.abstract
11+
def vconvert(*args: Any, **kwargs: Any) -> Any:
12+
"""Transform the current vector to the target chart.
13+
14+
This is an abstract API definition. See the main coordinax package for
15+
concrete implementations and usage examples.
16+
17+
"""
18+
raise NotImplementedError # pragma: no cover

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ ignore = [
297297
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
298298
"ARG001", # Unused function argument # TODO: resolve
299299
"B008", # Do not perform function calls in argument defaults
300+
"B024", # X is an ABC, but it has no abstract methods or properties
300301
"COM812", # <- for ruff.format
301302
"D103", # Missing docstring in public function # TODO: resolve
302303
"D105", # Missing docstring in magic method
@@ -339,6 +340,7 @@ ignore = [
339340
"coordinax.core" = "cx"
340341
"coordinax.distances" = "cxd"
341342
"coordinax.hypothesis.core" = "cxst"
343+
"coordinax.representations" = "cxr"
342344
equinox = "eqx"
343345
functools = "ft"
344346
unxt = "u"
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""`coordinax.representations` Module."""
2+
3+
__all__ = (
4+
"vconvert",
5+
# Representations
6+
"Representation",
7+
"point",
8+
# Geometric Kinds
9+
"AbstractGeometry",
10+
"PointGeometry",
11+
"point_geom",
12+
# Coordinate Bases
13+
"AbstractBasis",
14+
"NoBasis",
15+
"nobasis",
16+
# Semantic Kinds
17+
"AbstractSemanticKind",
18+
"Location",
19+
"location",
20+
)
21+
22+
from ._setup_package import install_import_hook
23+
24+
with install_import_hook("coordinax.representations"):
25+
from ._src import (
26+
AbstractBasis,
27+
AbstractGeometry,
28+
AbstractSemanticKind,
29+
Location,
30+
NoBasis,
31+
PointGeometry,
32+
Representation,
33+
location,
34+
nobasis,
35+
point,
36+
point_geom,
37+
)
38+
from coordinax.api.representations import vconvert
39+
40+
41+
del install_import_hook
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Package setup information.
2+
3+
Note that this module is NOT public API nor are any of its contents.
4+
Stability is NOT guaranteed.
5+
This module exposes package setup information for the `unxt` package.
6+
7+
"""
8+
9+
__all__: tuple[str, ...] = ("RUNTIME_TYPECHECKER", "install_import_hook")
10+
11+
import contextlib
12+
import os
13+
14+
from collections.abc import Sequence
15+
from jaxtyping import install_import_hook as _install_import_hook
16+
from typing import Any, Final, Literal
17+
18+
_RUNTIME_TYPECHECKER: str | None | Literal[False]
19+
match os.getenv("COORDINAX_ENABLE_RUNTIME_TYPECHECKING", "False"):
20+
case "False":
21+
_RUNTIME_TYPECHECKER = False
22+
case "None":
23+
_RUNTIME_TYPECHECKER = None
24+
case str() as _name:
25+
_RUNTIME_TYPECHECKER = _name
26+
27+
RUNTIME_TYPECHECKER: Final[str | None | Literal[False]] = _RUNTIME_TYPECHECKER
28+
"""Runtime type checking variable "COORDINAX_ENABLE_RUNTIME_TYPECHECKING".
29+
30+
Set to "False" to disable runtime typechecking (default).
31+
Set to "None" to only enable typechecking for `@jaxtyped`-decorated functions.
32+
Set to "beartype.beartype" to enable runtime typechecking.
33+
34+
See https://docs.kidger.site/jaxtyping/api/runtime-type-checking for more
35+
information on options.
36+
37+
38+
"""
39+
40+
41+
def install_import_hook(
42+
modules: str | Sequence[str], /
43+
) -> contextlib.AbstractContextManager[Any, None]:
44+
"""Install the jaxtyping import hook for the given modules.
45+
46+
Parameters
47+
----------
48+
modules
49+
Module name or sequence of module names to install the import hook for.
50+
51+
Returns
52+
-------
53+
contextlib.AbstractContextManager
54+
Context manager that installs the import hook on entry and removes it on exit.
55+
56+
"""
57+
return (
58+
_install_import_hook(modules, RUNTIME_TYPECHECKER)
59+
if RUNTIME_TYPECHECKER is not False
60+
else contextlib.nullcontext()
61+
)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Vector."""
2+
3+
from .basis import *
4+
from .core import *
5+
from .geom import *
6+
from .representations import *
7+
from .semantics import *
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Vector."""
2+
3+
__all__ = ("AbstractBasis", "NoBasis", "nobasis")
4+
5+
6+
import abc
7+
import dataclasses
8+
9+
from typing import final
10+
11+
import jax.tree_util as jtu
12+
13+
14+
@jtu.register_static
15+
class AbstractBasis(metaclass=abc.ABCMeta):
16+
pass
17+
18+
19+
@jtu.register_static
20+
@final
21+
@dataclasses.dataclass(frozen=True, slots=True)
22+
class NoBasis(AbstractBasis):
23+
"""No basis.
24+
25+
Points, and other geometric objects that do not have a vector space
26+
structure do not have a basis.
27+
28+
"""
29+
30+
31+
nobasis = NoBasis()
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Vector Conversion."""
2+
3+
__all__ = ("vconvert",)
4+
5+
import plum
6+
7+
import coordinax.api.representations as api
8+
import coordinax.charts as cxc
9+
from .basis import AbstractBasis, NoBasis
10+
from .geom import PointGeometry
11+
from .representations import Representation
12+
from .semantics import AbstractSemanticKind, Location
13+
from coordinax.internal.custom_types import CDict
14+
15+
16+
@plum.dispatch
17+
def vconvert(
18+
rep: Representation,
19+
to_chart: cxc.AbstractChart,
20+
from_chart: cxc.AbstractChart,
21+
x: CDict,
22+
) -> CDict:
23+
# redispatch on the combination of GeometryKind, Basis, and SemanticKind
24+
return api.vconvert(
25+
rep.geom_kind, rep.basis, rep.semantic_kind, to_chart, from_chart, x
26+
)
27+
28+
29+
# =======================================================================
30+
31+
32+
@plum.dispatch
33+
def vconvert(
34+
geom_kind: PointGeometry,
35+
basis: AbstractBasis,
36+
semantic_kind: AbstractSemanticKind,
37+
to_chart: cxc.AbstractChart,
38+
from_chart: cxc.AbstractChart,
39+
x: CDict,
40+
) -> CDict:
41+
msg = (
42+
"For point-role representations, vconvert is only implemented for the "
43+
"combination of PointGeometry, NoBasis, and Location semantic kind. "
44+
)
45+
raise TypeError(msg)
46+
47+
48+
@plum.dispatch
49+
def vconvert(
50+
geom_kind: PointGeometry,
51+
basis: NoBasis,
52+
semantic_kind: Location,
53+
to_chart: cxc.AbstractChart,
54+
from_chart: cxc.AbstractChart,
55+
x: CDict,
56+
) -> CDict:
57+
"""Convert a point-role representation from one chart to another."""
58+
del geom_kind, basis, semantic_kind # only used for dispatching to here.
59+
return cxc.coord_map(to_chart, from_chart, x)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Vector."""
2+
3+
__all__ = ("AbstractGeometry", "PointGeometry", "point_geom")
4+
5+
6+
import abc
7+
import dataclasses
8+
9+
from typing import final
10+
11+
import jax.tree_util as jtu
12+
13+
14+
@jtu.register_static
15+
class AbstractGeometry(metaclass=abc.ABCMeta):
16+
"""Flag for geometric role (position, velocity, acceleration, etc.)."""
17+
18+
19+
@jtu.register_static
20+
@final
21+
@dataclasses.dataclass(frozen=True, slots=True)
22+
class PointGeometry(AbstractGeometry):
23+
r"""Point geometric kind.
24+
25+
Mathematical Definition:
26+
27+
A **point** is an element of a manifold or affine space $M$. Points do not
28+
form a vector space in general; e.g., on a curved manifold you cannot
29+
meaningfully add two points.
30+
31+
- In Euclidean space $\mathbb{R}^n$, points can be identified with vectors
32+
from the origin, but this identification is basis-dependent.
33+
- On manifolds (e.g., a sphere), points are elements of the manifold and
34+
have no additive structure.
35+
- Point coordinates may have mixed dimensions (e.g., spherical: length +
36+
angles).
37+
38+
"""
39+
40+
41+
point_geom = PointGeometry()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Vector."""
2+
3+
__all__ = ("AbstractGeometry",)
4+
5+
6+
import dataclasses
7+
8+
from typing import final
9+
10+
import jax.tree_util as jtu
11+
12+
from .basis import AbstractBasis, nobasis
13+
from .geom import AbstractGeometry, point_geom
14+
from .semantics import AbstractSemanticKind, location
15+
16+
17+
@jtu.register_static
18+
@final
19+
@dataclasses.dataclass(frozen=True, slots=True)
20+
class Representation:
21+
geom_kind: AbstractGeometry
22+
basis: AbstractBasis
23+
semantic_kind: AbstractSemanticKind
24+
25+
26+
point = Representation(point_geom, nobasis, location)

0 commit comments

Comments
 (0)