Skip to content

Commit b699860

Browse files
committed
✨ feat: manifolds
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 8ac8499 commit b699860

7 files changed

Lines changed: 142 additions & 0 deletions

File tree

packages/coordinax-api/src/coordinax_api/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
"project_tangent",
1717
# metrics
1818
"metric_of",
19+
# manifolds
20+
"manifold_of",
1921
# frames
2022
"frame_of",
2123
"frame_transform_op",
@@ -55,6 +57,7 @@
5557
)
5658
from ._embeddings import embed_point, embed_tangent, project_point, project_tangent
5759
from ._frames import frame_of, frame_transform_op
60+
from ._manifolds import manifold_of
5861
from ._metrics import metric_of
5962
from ._objs import vconvert
6063
from ._operators import apply_op, simplify
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Manifold API for coordinax."""
2+
3+
__all__ = ("manifold_of",)
4+
5+
from typing import Any
6+
7+
import plum
8+
9+
10+
@plum.dispatch.abstract
11+
def manifold_of(*args: Any) -> Any:
12+
"""Return the default manifold associated with the input."""
13+
raise NotImplementedError # pragma: no cover

src/coordinax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"distances",
99
"embeddings",
1010
"frames",
11+
"manifolds",
1112
"metrics",
1213
"objs",
1314
"ops",
@@ -36,6 +37,7 @@
3637
distances,
3738
embeddings,
3839
frames,
40+
manifolds,
3941
metrics,
4042
objs,
4143
ops,
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""`coordinax.manifolds` module."""
2+
3+
__all__ = (
4+
"AbstractManifold",
5+
"Atlas",
6+
"Euclidean",
7+
"EuclideanAtlas",
8+
"manifold_of",
9+
)
10+
11+
from coordinax import setup_package
12+
13+
with setup_package.install_import_hook("coordinax.manifolds"):
14+
from ._src import AbstractManifold, Atlas, Euclidean, EuclideanAtlas
15+
from coordinax.api import manifold_of
16+
17+
18+
del setup_package
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Manifolds in coordinax."""
2+
3+
from .base import *
4+
from .euclidean import *
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Manifold definitions and manifold inference helpers."""
2+
3+
__all__ = (
4+
"AbstractManifold",
5+
"Atlas",
6+
)
7+
8+
import abc
9+
10+
from typing import Any, Protocol
11+
12+
import coordinax.charts as cxc
13+
import coordinax.metrics as cxm
14+
15+
16+
class Atlas(Protocol):
17+
"""Atlas protocol for manifolds."""
18+
19+
def default_chart(self) -> cxc.AbstractChart[Any, Any]:
20+
"""Return a default chart from the atlas."""
21+
...
22+
23+
# def supports(self, chart: cxc.AbstractChart[Any, Any]) -> bool:
24+
# """Return whether the atlas supports the given chart."""
25+
# ...
26+
27+
28+
class AbstractManifold(metaclass=abc.ABCMeta):
29+
"""Abstract manifold interface."""
30+
31+
dim: int
32+
"""Intrinsic dimension of the manifold."""
33+
34+
metric: cxm.AbstractMetric
35+
"""(Pseudo-)Riemannian metric on the manifold."""
36+
37+
atlas: Atlas
38+
"""Charts compatible with this manifold."""
39+
40+
@property
41+
def default_chart(self) -> cxc.AbstractChart[Any, Any]:
42+
"""Return a default chart from the atlas."""
43+
return self.atlas.default_chart()
44+
45+
# def has_chart(self, chart: cxc.AbstractChart[Any, Any], /) -> bool:
46+
# """Return whether ``chart`` belongs to this manifold atlas."""
47+
# return self.atlas.supports(chart)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Euclidean manifolds."""
2+
3+
__all__ = ("Euclidean", "EuclideanAtlas")
4+
5+
from dataclasses import dataclass
6+
7+
from typing import Any, final
8+
9+
import coordinax.charts as cxc
10+
import coordinax.metrics as cxm
11+
from .base import AbstractManifold, Atlas
12+
13+
14+
@dataclass(frozen=True, slots=True)
15+
class EuclideanAtlas(Atlas):
16+
"""Atlas for Euclidean manifolds."""
17+
18+
dim: int
19+
"""Dimension of the Euclidean manifold."""
20+
21+
def default_chart(self) -> cxc.AbstractChart[Any, Any]:
22+
# TODO: make this autodetect
23+
chart: cxc.AbstractChart[Any, Any]
24+
match self.dim:
25+
case 0:
26+
chart = cxc.cart0d
27+
case 1:
28+
chart = cxc.cart1d
29+
case 2:
30+
chart = cxc.cart2d
31+
case 3:
32+
chart = cxc.cart3d
33+
case 6:
34+
chart = cxc.poincarepolar6d
35+
case _:
36+
msg = f"Euclidean({self.dim}) is unsupported for now."
37+
raise ValueError(msg)
38+
return chart
39+
40+
# def supports(self, chart: cxc.AbstractChart[Any, Any]) -> bool:
41+
# return any(c == chart for c in self._charts)
42+
43+
44+
@final
45+
@dataclass(frozen=True, slots=True)
46+
class Euclidean(AbstractManifold):
47+
"""Euclidean manifold with identity metric."""
48+
49+
dim: int
50+
"""Intrinsic dimension of the manifold."""
51+
52+
def __init__(self, dim: int, /) -> None:
53+
object.__setattr__(self, "dim", dim)
54+
object.__setattr__(self, "metric", cxm.EuclideanMetric(dim))
55+
object.__setattr__(self, "atlas", EuclideanAtlas(self.dim))

0 commit comments

Comments
 (0)