Skip to content

Commit 42d74ab

Browse files
timtreisclaude
andcommitted
Add alignment skeleton under sq.experimental.tl
Introduces a backend-agnostic alignment API (align_obs, align_images, align_by_landmarks) with STalign and landmark backends, lazy JAX imports, and e2e tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 373228d commit 42d74ab

21 files changed

Lines changed: 2930 additions & 2 deletions

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ optional-dependencies.docs = [
102102
"sphinxcontrib-bibtex>=2.3",
103103
"sphinxcontrib-spelling>=7.6.2",
104104
]
105+
optional-dependencies.jax = [
106+
"jax",
107+
]
105108
optional-dependencies.leiden = [
106109
"leidenalg",
107110
"spatialleiden>=0.4",

src/squidpy/experimental/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66

77
from __future__ import annotations
88

9-
from . import im, pl
9+
from . import im, pl, tl
1010

11-
__all__ = ["im", "pl"]
11+
__all__ = ["im", "pl", "tl"]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any
4+
5+
from squidpy.experimental.tl._align import align_by_landmarks, align_images, align_obs
6+
7+
if TYPE_CHECKING:
8+
from squidpy.experimental.tl._align._backends._stalign_tools import (
9+
STalignConfig,
10+
STalignPreprocessConfig,
11+
STalignPreprocessResult,
12+
STalignRegistrationConfig,
13+
STalignResult,
14+
)
15+
16+
__all__ = [
17+
"STalignConfig",
18+
"STalignPreprocessConfig",
19+
"STalignPreprocessResult",
20+
"STalignRegistrationConfig",
21+
"STalignResult",
22+
"align_by_landmarks",
23+
"align_images",
24+
"align_obs",
25+
]
26+
27+
_STALIGN_REEXPORTS = frozenset(
28+
{
29+
"STalignConfig",
30+
"STalignPreprocessConfig",
31+
"STalignPreprocessResult",
32+
"STalignRegistrationConfig",
33+
"STalignResult",
34+
}
35+
)
36+
37+
38+
def __getattr__(name: str) -> Any:
39+
"""Lazy access to the JAX-only STalign config dataclasses.
40+
41+
Importing :mod:`squidpy.experimental.tl._align._backends._stalign_tools` pulls in
42+
:mod:`jax` at module-load time, so we defer the import until the first
43+
attribute access. This preserves the lazy-import contract pinned by
44+
``test_optional_deps_not_imported_at_import_time``.
45+
"""
46+
if name in _STALIGN_REEXPORTS:
47+
try:
48+
from squidpy.experimental.tl._align._backends import _stalign_tools as _tools
49+
except ModuleNotFoundError as e:
50+
if e.name == "jax":
51+
raise ImportError(
52+
'STalign requires the optional dependency `jax`. Install it with `pip install "squidpy[jax]"`.'
53+
) from e
54+
raise
55+
return getattr(_tools, name)
56+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Alignment skeleton under :mod:`squidpy.experimental.tl`.
2+
3+
Public surface:
4+
5+
- :func:`align_obs` — align two ``obs``-level point clouds (cells / spots).
6+
- :func:`align_images` — align two raster images in :class:`spatialdata.SpatialData`.
7+
- :func:`align_by_landmarks` — closed-form fit from user-provided landmarks.
8+
9+
Optional backends (``stalign``, ``moscot``) and JAX are imported lazily — only
10+
the function call that needs them pulls them in.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
from squidpy.experimental.tl._align._api import (
16+
align_by_landmarks,
17+
align_images,
18+
align_obs,
19+
)
20+
21+
__all__ = ["align_by_landmarks", "align_images", "align_obs"]
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
"""Public ``align_*`` orchestrators.
2+
3+
Each function is intentionally thin: resolve inputs, validate, dispatch to a
4+
backend, write the result back. All branching on argument shape lives in
5+
:mod:`._io`; all backend selection lives in :mod:`._backends`; all validation
6+
of "passed-but-unneeded" combinations lives in :mod:`._validation`.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from typing import TYPE_CHECKING, Any, Literal
12+
13+
from squidpy.experimental.tl._align._backends import get_backend
14+
from squidpy.experimental.tl._align._io import (
15+
apply_affine_to_cs,
16+
materialise_obs,
17+
resolve_element_pair,
18+
resolve_image_pair,
19+
resolve_obs_pair,
20+
)
21+
from squidpy.experimental.tl._align._types import AffineTransform, AlignPair, AlignResult
22+
from squidpy.experimental.tl._align._validation import (
23+
ALLOWED_FLAVOURS_IMAGES,
24+
ALLOWED_FLAVOURS_OBS,
25+
ALLOWED_OUTPUT_MODES_NONOBS,
26+
ALLOWED_OUTPUT_MODES_OBS,
27+
validate_flavour,
28+
validate_key_added,
29+
validate_landmark_model,
30+
validate_landmarks,
31+
validate_output_mode,
32+
validate_required,
33+
)
34+
35+
if TYPE_CHECKING:
36+
from anndata import AnnData
37+
from spatialdata import SpatialData
38+
39+
40+
def align_obs(
41+
data_ref: AnnData | SpatialData,
42+
data_query: AnnData | SpatialData | None = None,
43+
adata_ref_name: str | None = None,
44+
adata_query_name: str | None = None,
45+
flavour: Literal["stalign", "moscot"] = "stalign",
46+
*,
47+
output_mode: Literal["affine", "obs", "return"] = "affine",
48+
key_added: str | None = None,
49+
device: Literal["cpu", "gpu"] | None = None,
50+
inplace: bool = True,
51+
**flavour_kwargs: Any,
52+
) -> AnnData | SpatialData | AlignResult | None:
53+
"""Align two ``obs``-level point clouds (cells / spots).
54+
55+
Parameters
56+
----------
57+
data_ref, data_query
58+
Either both :class:`anndata.AnnData`, both :class:`spatialdata.SpatialData`,
59+
or ``data_ref`` a SpatialData and ``data_query=None`` (in which case
60+
``adata_ref_name`` and ``adata_query_name`` select two different
61+
tables of the same SpatialData object).
62+
adata_ref_name, adata_query_name
63+
Required only when SpatialData inputs are used. Passing them with
64+
AnnData inputs raises an educational :class:`ValueError`.
65+
flavour
66+
Backend to use. ``'stalign'`` is the default LDDMM-based fit;
67+
``'moscot'`` is OT-based.
68+
output_mode
69+
How to deliver the result:
70+
71+
- ``'affine'`` — register the fitted affine on the query element via
72+
:func:`spatialdata.transformations.set_transformation`, so every
73+
element in the query coordinate system inherits the alignment.
74+
Requires the backend to produce an affine transform.
75+
- ``'obs'`` — bake the (possibly non-affine) fit into a new AnnData
76+
whose ``obsm['spatial']`` already lives in the reference coordinate
77+
system; for SpatialData inputs the new table is stored under
78+
``key_added``.
79+
- ``'return'`` — return the raw :class:`AlignResult`; no writeback.
80+
key_added
81+
Required when ``output_mode='obs'`` and inputs are SpatialData.
82+
Rejected with any other ``output_mode``.
83+
device
84+
``'cpu'``/``'gpu'`` to force a JAX device, or ``None`` to let JAX
85+
pick the default. Only consulted by JAX-backed flavours.
86+
inplace
87+
If ``True``, mutate the query container; otherwise return a copy.
88+
**flavour_kwargs
89+
Backend-specific knobs forwarded as-is to the chosen backend.
90+
"""
91+
validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_OBS, op="align_obs")
92+
validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_OBS, op="align_obs")
93+
validate_key_added(key_added, output_mode)
94+
95+
pair = resolve_obs_pair(data_ref, data_query, adata_ref_name, adata_query_name)
96+
backend = get_backend(flavour)
97+
result = backend.align_obs(pair, device=device, **flavour_kwargs)
98+
99+
return _writeback(pair, result, output_mode=output_mode, key_added=key_added, inplace=inplace)
100+
101+
102+
def align_images(
103+
sdata_ref: SpatialData,
104+
sdata_query: SpatialData | None = None,
105+
img_ref_name: str | None = None,
106+
img_query_name: str | None = None,
107+
flavour: Literal["stalign"] = "stalign",
108+
*,
109+
scale_ref: str | Literal["auto"] = "auto",
110+
scale_query: str | Literal["auto"] = "auto",
111+
output_mode: Literal["affine", "return"] = "affine",
112+
device: Literal["cpu", "gpu"] | None = None,
113+
inplace: bool = True,
114+
**flavour_kwargs: Any,
115+
) -> SpatialData | AlignResult | None:
116+
"""Align two raster images living inside :class:`spatialdata.SpatialData`.
117+
118+
Parameters
119+
----------
120+
sdata_ref, sdata_query
121+
SpatialData containers. Pass ``sdata_query=None`` to align two
122+
images of the same SpatialData against each other.
123+
img_ref_name, img_query_name
124+
Image element keys.
125+
flavour
126+
Only ``'stalign'`` is currently supported.
127+
scale_ref, scale_query
128+
Scale level for multi-scale image elements. ``'auto'`` picks the
129+
coarsest level. Single-scale images ignore this parameter.
130+
output_mode
131+
``'affine'`` registers the fit on the query image element so all of
132+
its scales inherit the transformation; ``'return'`` returns the raw
133+
:class:`AlignResult`.
134+
device, inplace, flavour_kwargs
135+
See :func:`align_obs`.
136+
"""
137+
validate_required(name="img_ref_name", value=img_ref_name, when="calling `align_images`")
138+
validate_required(name="img_query_name", value=img_query_name, when="calling `align_images`")
139+
validate_flavour(flavour, allowed=ALLOWED_FLAVOURS_IMAGES, op="align_images")
140+
validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_NONOBS, op="align_images")
141+
142+
pair = resolve_image_pair(
143+
sdata_ref,
144+
sdata_query,
145+
img_ref_name,
146+
img_query_name,
147+
scale_ref=scale_ref,
148+
scale_query=scale_query,
149+
)
150+
backend = get_backend(flavour)
151+
result = backend.align_images(pair, device=device, **flavour_kwargs)
152+
153+
return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace)
154+
155+
156+
def align_by_landmarks(
157+
sdata_ref: SpatialData,
158+
sdata_query: SpatialData | None = None,
159+
cs_name_ref: str | None = None,
160+
cs_name_query: str | None = None,
161+
scale_ref: str | None = None,
162+
scale_query: str | None = None,
163+
landmarks_ref: tuple[tuple[float, float], ...] | None = None,
164+
landmarks_query: tuple[tuple[float, float], ...] | None = None,
165+
*,
166+
model: Literal["similarity", "affine"] = "similarity",
167+
output_mode: Literal["affine", "return"] = "affine",
168+
inplace: bool = True,
169+
) -> SpatialData | AlignResult | None:
170+
"""Align by a closed-form fit on user-provided landmarks.
171+
172+
Pure NumPy under the hood — JAX is **not** required for this path.
173+
174+
Parameters
175+
----------
176+
sdata_ref, sdata_query
177+
SpatialData containers. Pass ``sdata_query=None`` to align two
178+
coordinate systems of the same SpatialData against each other.
179+
cs_name_ref, cs_name_query
180+
Coordinate system names.
181+
scale_ref, scale_query
182+
Optional scale identifiers used purely for landmark-extent
183+
validation: if you extracted your landmarks at a particular scale,
184+
passing the same scale here lets us catch the "wrong scale" footgun
185+
early.
186+
landmarks_ref, landmarks_query
187+
Equal-length sequences of ``(y, x)`` tuples. ``model='similarity'``
188+
needs ≥ 2 pairs, ``model='affine'`` needs ≥ 3.
189+
model
190+
``'similarity'`` (rotation + uniform scale + translation) or
191+
``'affine'`` (full 6-parameter linear).
192+
output_mode, inplace
193+
See :func:`align_obs`.
194+
"""
195+
validate_required(name="cs_name_ref", value=cs_name_ref, when="calling `align_by_landmarks`")
196+
validate_required(name="cs_name_query", value=cs_name_query, when="calling `align_by_landmarks`")
197+
validate_required(name="landmarks_ref", value=landmarks_ref, when="calling `align_by_landmarks`")
198+
validate_required(name="landmarks_query", value=landmarks_query, when="calling `align_by_landmarks`")
199+
200+
validate_output_mode(output_mode, allowed=ALLOWED_OUTPUT_MODES_NONOBS, op="align_by_landmarks")
201+
validate_landmark_model(model)
202+
203+
# We don't materialise extents here in the skeleton; backends / a future
204+
# PR can fill in the cs-extent lookup once we wire spatialdata.get_extent.
205+
ref_arr, query_arr = validate_landmarks(landmarks_ref, landmarks_query, model=model)
206+
207+
pair = resolve_element_pair(sdata_ref, sdata_query, cs_name_ref, cs_name_query)
208+
209+
from squidpy.experimental.tl._align._backends._landmark import fit_landmark_affine
210+
211+
affine = fit_landmark_affine(
212+
ref_arr,
213+
query_arr,
214+
model=model,
215+
source_cs=cs_name_query,
216+
target_cs=cs_name_ref,
217+
)
218+
result = AlignResult(transform=affine, metadata={"flavour": "landmark", "model": model})
219+
220+
return _writeback(pair, result, output_mode=output_mode, key_added=None, inplace=inplace)
221+
222+
223+
# ---------------------------------------------------------------------------
224+
# Internal: writeback dispatch
225+
# ---------------------------------------------------------------------------
226+
227+
228+
def _writeback(
229+
pair: AlignPair,
230+
result: AlignResult,
231+
*,
232+
output_mode: str,
233+
key_added: str | None,
234+
inplace: bool,
235+
) -> AnnData | SpatialData | AlignResult | None:
236+
if output_mode == "return":
237+
return result
238+
239+
if output_mode == "affine":
240+
if not isinstance(result.transform, AffineTransform):
241+
raise TypeError(
242+
f"`output_mode='affine'` requires the backend to return an AffineTransform, "
243+
f"got {type(result.transform).__name__}. Use `output_mode='obs'` (for "
244+
f"`align_obs`) or `output_mode='return'` to access non-affine fits."
245+
)
246+
return apply_affine_to_cs(pair, result.transform, inplace=inplace)
247+
248+
if output_mode == "obs":
249+
return materialise_obs(pair, result, key_added=key_added, inplace=inplace)
250+
251+
raise ValueError(f"Unknown output_mode {output_mode!r}.")
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Backend dispatch for the alignment skeleton.
2+
3+
Imports of individual backends happen *inside* the dispatch branches so that
4+
``import squidpy.experimental.tl`` never pulls in ``stalign``, ``moscot``, or
5+
``jax`` transitively.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from typing import TYPE_CHECKING
11+
12+
if TYPE_CHECKING:
13+
from squidpy.experimental.tl._align._backends._base import AlignBackend
14+
15+
16+
def get_backend(flavour: str) -> AlignBackend:
17+
"""Return a backend instance for the requested ``flavour``."""
18+
if flavour == "stalign":
19+
from squidpy.experimental.tl._align._backends._stalign import StAlignBackend
20+
21+
return StAlignBackend()
22+
if flavour == "moscot":
23+
from squidpy.experimental.tl._align._backends._moscot import MoscotBackend
24+
25+
return MoscotBackend()
26+
raise ValueError(f"Unknown alignment flavour {flavour!r}; expected 'stalign' or 'moscot'.")

0 commit comments

Comments
 (0)