Skip to content

Commit 16c576b

Browse files
committed
Flatten stalign subpackage, switch to xy convention, wire landmark backend, add blobs e2e tests
Four user-driven changes on top of the previous integration commit: 1. Flatten src/squidpy/experimental/tl/_align/_stalign/ into siblings under _align/_backends/ (_stalign_core, _stalign_helpers, _stalign_tools). The lifted Selman code now lives next to its only consumer (_backends/_stalign.py) and there's no extra package layer. 2. Switch AffineTransform / ObsDisplacement / to_spatialdata() to xy convention to match spatialdata's points API (which asserts axes == ('x', 'y') in get_transformation_between_landmarks) and obsm['spatial']. Drops the latent yx-vs-xy gotcha from materialise_obs. 3. Wire fit_landmark_affine to real solvers: - model='similarity' -> spatialdata.transformations.get_transformation_between_landmarks - model='affine' -> skimage.transform.estimate_transform('affine', ...) Both produce 3x3 homogeneous matrices in xy. apply_affine_to_cs grows a third writeback path (cs-keyed: walk every element registered in the moving cs and set_transformation on each) so align_by_landmarks can register the fit without an explicit element key. 4. Add tests/experimental/tl/test_align_blobs_e2e.py: four end-to-end tests on the spatialdata blobs() fixture (200 points). The closed-form landmark fits recover a known 12 deg rotation to 1e-6 max residual; the stalign LDDMM run reduces the residual below the no-op baseline. Tests: 33/33 passing. Lint clean.
1 parent 872dea6 commit 16c576b

12 files changed

Lines changed: 425 additions & 83 deletions

File tree

src/squidpy/experimental/tl/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from squidpy.experimental.tl._tiling_qc import calculate_tiling_qc
77

88
if TYPE_CHECKING:
9-
from squidpy.experimental.tl._align._stalign._tools import (
9+
from squidpy.experimental.tl._align._backends._stalign_tools import (
1010
STalignConfig,
1111
STalignPreprocessConfig,
1212
STalignPreprocessResult,
@@ -47,7 +47,7 @@ def __getattr__(name: str) -> Any:
4747
"""
4848
if name in _STALIGN_REEXPORTS:
4949
try:
50-
from squidpy.experimental.tl._align._stalign import _tools
50+
from squidpy.experimental.tl._align._backends import _stalign_tools as _tools
5151
except ModuleNotFoundError as e:
5252
if e.name == "jax":
5353
raise ImportError(

src/squidpy/experimental/tl/_align/_backends/_landmark.py

Lines changed: 80 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
"""Closed-form landmark fit.
22
3-
Pure NumPy — no JAX, no optional deps. Used directly by
4-
:func:`squidpy.experimental.tl.align_by_landmarks`; in a later PR the same
5-
fit will also serve as the initial guess for the STalign backend.
3+
Two models, both pure NumPy / no JAX:
64
7-
The actual math (Umeyama for ``model='similarity'``, linear least-squares for
8-
``model='affine'``) lands together with the STalign body in the follow-up PR.
5+
- ``"similarity"`` (4 DOF: rotation + uniform scale + translation, plus an
6+
optional reflection check) - delegated to
7+
:func:`spatialdata.transformations.get_transformation_between_landmarks`.
8+
- ``"affine"`` (6 DOF: rotation + non-uniform scale + shear + translation) -
9+
delegated to :func:`skimage.transform.estimate_transform`, the same
10+
least-squares solver spatialdata uses internally.
11+
12+
Useful as a one-shot alignment when you already have corresponding landmarks,
13+
and as a sanity-check baseline for the much heavier STalign LDDMM solver.
914
"""
1015

1116
from __future__ import annotations
@@ -28,12 +33,74 @@ def fit_landmark_affine(
2833
) -> AffineTransform:
2934
"""Fit a 2D affine that maps ``landmarks_query`` onto ``landmarks_ref``.
3035
31-
Both inputs are ``(N, 2)`` arrays in ``(y, x)`` convention. The returned
32-
:class:`~squidpy.experimental.tl._align._types.AffineTransform` is a
33-
``(3, 3)`` homogeneous matrix in the same convention.
36+
Both inputs are ``(N, 2)`` ``(x, y)`` arrays of corresponding landmarks
37+
(the ``i``-th row of ``landmarks_query`` matches the ``i``-th row of
38+
``landmarks_ref``). ``N`` must be at least 3.
39+
40+
Parameters
41+
----------
42+
landmarks_ref, landmarks_query
43+
Corresponding landmark coordinates in ``(x, y)`` convention.
44+
model
45+
``"similarity"`` (4 DOF, via spatialdata) or ``"affine"`` (6 DOF,
46+
via skimage).
47+
source_cs, target_cs
48+
Optional coordinate-system labels stamped onto the returned
49+
:class:`AffineTransform` for traceability.
3450
"""
35-
raise NotImplementedError(
36-
"Landmark fit: TODO. Skeleton landed; the closed-form Umeyama / "
37-
"least-squares solver will land in the next PR alongside the STalign "
38-
"backend body."
39-
)
51+
from squidpy.experimental.tl._align._types import AffineTransform
52+
53+
ref = np.asarray(landmarks_ref, dtype=float)
54+
query = np.asarray(landmarks_query, dtype=float)
55+
56+
if model == "similarity":
57+
matrix = _fit_similarity_via_spatialdata(ref, query)
58+
elif model == "affine":
59+
matrix = _fit_affine_via_skimage(ref, query)
60+
else:
61+
raise ValueError(f"Unknown landmark `model={model!r}`; expected 'similarity' or 'affine'.")
62+
63+
return AffineTransform(matrix=matrix, source_cs=source_cs, target_cs=target_cs)
64+
65+
66+
def _fit_similarity_via_spatialdata(ref_xy: np.ndarray, query_xy: np.ndarray) -> np.ndarray:
67+
"""4-DOF similarity fit, delegated to spatialdata."""
68+
from spatialdata.models import PointsModel
69+
from spatialdata.transformations import get_transformation_between_landmarks
70+
71+
refs_pts = PointsModel.parse(ref_xy)
72+
moving_pts = PointsModel.parse(query_xy)
73+
sd_transform = get_transformation_between_landmarks(refs_pts, moving_pts)
74+
return _extract_affine_matrix(sd_transform)
75+
76+
77+
def _fit_affine_via_skimage(ref_xy: np.ndarray, query_xy: np.ndarray) -> np.ndarray:
78+
"""Full 6-DOF affine fit, delegated to skimage's least-squares solver.
79+
80+
This is what :func:`spatialdata.transformations.get_transformation_between_landmarks`
81+
uses under the hood before collapsing to a similarity; for the affine
82+
model we keep the raw matrix instead.
83+
"""
84+
from skimage.transform import estimate_transform
85+
86+
model_obj = estimate_transform("affine", src=query_xy, dst=ref_xy)
87+
return np.asarray(model_obj.params)
88+
89+
90+
def _extract_affine_matrix(sd_transform: object) -> np.ndarray:
91+
"""Pull a ``(3, 3)`` homogeneous matrix out of a spatialdata transformation.
92+
93+
:func:`get_transformation_between_landmarks` may return either a single
94+
:class:`spatialdata.transformations.Affine` or a
95+
:class:`spatialdata.transformations.Sequence` of two affines (when a
96+
reflection is detected and rolled into the fit). Use
97+
``to_affine_matrix`` to collapse either back to a single 3x3.
98+
"""
99+
from spatialdata.transformations import Affine as SDAffine
100+
from spatialdata.transformations import Sequence as SDSequence
101+
102+
if isinstance(sd_transform, SDAffine):
103+
return np.asarray(sd_transform.matrix)
104+
if isinstance(sd_transform, SDSequence):
105+
return np.asarray(sd_transform.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")))
106+
raise TypeError(f"Unexpected transformation type from spatialdata: {type(sd_transform).__name__}.")

src/squidpy/experimental/tl/_align/_backends/_stalign.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def align_obs(
4141
# _jax.require_jax.
4242
require_jax(device)
4343

44-
from squidpy.experimental.tl._align._stalign._tools import stalign_points
44+
from squidpy.experimental.tl._align._backends._stalign_tools import stalign_points
4545
from squidpy.experimental.tl._align._types import AlignResult, ObsDisplacement
4646

4747
if not isinstance(pair.ref, AnnData) or not isinstance(pair.query, AnnData):

src/squidpy/experimental/tl/_align/_stalign/_core.py renamed to src/squidpy/experimental/tl/_align/_backends/_stalign_core.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Core JAX implementation for experimental STalign point registration.
22
33
Lifted byte-for-byte from scverse/squidpy#1150 (Selman Özleyen).
4-
See ``_align/_stalign/__init__.py`` for the lift note.
54
"""
65

76
from __future__ import annotations

src/squidpy/experimental/tl/_align/_stalign/_helpers.py renamed to src/squidpy/experimental/tl/_align/_backends/_stalign_helpers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Helpers for experimental STalign point-cloud registration.
22
33
Lifted byte-for-byte from scverse/squidpy#1150 (Selman Özleyen).
4-
See ``_align/_stalign/__init__.py`` for the lift note.
54
"""
65

76
from __future__ import annotations

src/squidpy/experimental/tl/_align/_stalign/_tools.py renamed to src/squidpy/experimental/tl/_align/_backends/_stalign_tools.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""Low-level point-cloud tools for experimental STalign.
22
33
Lifted from scverse/squidpy#1150 (Selman Özleyen). Only the two import paths
4-
on lines 12-20 were rewritten to point at the relocated submodules; the rest
5-
of the file is byte-for-byte identical to the upstream PR. See
6-
``_align/_stalign/__init__.py`` for the lift note.
4+
below were rewritten to point at the sibling lifted modules; the rest of the
5+
file is byte-for-byte identical to the upstream PR.
76
"""
87

98
from __future__ import annotations
@@ -15,8 +14,8 @@
1514
import numpy as np
1615
from anndata import AnnData
1716

18-
from squidpy.experimental.tl._align._stalign._core import JAX_DTYPE, lddmm, transform_points_row_col
19-
from squidpy.experimental.tl._align._stalign._helpers import (
17+
from squidpy.experimental.tl._align._backends._stalign_core import JAX_DTYPE, lddmm, transform_points_row_col
18+
from squidpy.experimental.tl._align._backends._stalign_helpers import (
2019
PointOrder,
2120
affine_from_points,
2221
extract_points,

src/squidpy/experimental/tl/_align/_io.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -192,23 +192,49 @@ def apply_affine_to_cs(
192192
*,
193193
inplace: bool,
194194
) -> SpatialData | AnnData | None:
195-
"""Register ``affine`` as a transformation on the query element.
196-
197-
The transformation goes onto the parent element node (so all scales of a
198-
multiscale image inherit it) under the *reference* coordinate-system name.
199-
For plain-AnnData inputs there is no SpatialData element to attach to —
200-
we apply the affine to ``query.obsm['spatial']`` in place.
195+
"""Register ``affine`` on the query side of the pair.
196+
197+
Three writeback paths, in order of specificity:
198+
199+
1. **Element-keyed**: ``pair.query_container`` and ``pair.query_element_key``
200+
are both set (e.g. ``align_obs`` / ``align_images`` resolved an explicit
201+
table or image). Register the transform on that single element so all
202+
scales / sibling tables that share its parent element node inherit it.
203+
2. **Cs-keyed**: only ``pair.query_cs`` is set (e.g. ``align_by_landmarks``
204+
resolved a coordinate system but no specific element). Walk every
205+
element that has the moving cs in its transformation graph and register
206+
the transform on each, mapping into the reference cs.
207+
3. **Plain AnnData**: no spatialdata container at all - warp
208+
``query.obsm['spatial']`` directly.
201209
"""
202-
from spatialdata.transformations import set_transformation
210+
from spatialdata.transformations import get_transformation, set_transformation
211+
212+
target_cs = affine.target_cs or pair.ref_cs or "aligned"
203213

204214
if pair.query_container is not None and pair.query_element_key is not None:
205215
sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container)
206216
element = sdata[pair.query_element_key]
207-
target_cs = affine.target_cs or pair.ref_cs or "aligned"
208217
set_transformation(element, affine.to_spatialdata(), to_coordinate_system=target_cs)
209218
return None if inplace else sdata
210219

211-
# Plain AnnData query -> warp obsm["spatial"].
220+
if pair.query_container is not None and pair.query_cs is not None:
221+
sdata = pair.query_container if inplace else _shallow_copy_sdata(pair.query_container)
222+
moving_cs = pair.query_cs
223+
sd_affine = affine.to_spatialdata()
224+
touched_any = False
225+
for _etype, _name, element in sdata._gen_elements(include_tables=False):
226+
element_transforms = get_transformation(element, get_all=True)
227+
if moving_cs not in element_transforms:
228+
continue
229+
set_transformation(element, sd_affine, to_coordinate_system=target_cs)
230+
touched_any = True
231+
if not touched_any:
232+
raise KeyError(
233+
f"No elements in the query SpatialData are registered to coordinate "
234+
f"system {moving_cs!r}; nothing to attach the alignment to."
235+
)
236+
return None if inplace else sdata
237+
212238
if isinstance(pair.query, AnnData):
213239
adata = pair.query if inplace else pair.query.copy()
214240
if "spatial" not in adata.obsm:
@@ -239,14 +265,6 @@ def materialise_obs(
239265
raise KeyError("Source AnnData has no `obsm['spatial']` to warp.")
240266

241267
src_coords = np.asarray(pair.query.obsm["spatial"])
242-
# CONVENTION NOTE: anndata's ``obsm['spatial']`` is conventionally (x, y)
243-
# in squidpy, while ``AffineTransform.matrix`` is documented as (y, x).
244-
# Today the only backend that returns a ``Transform`` for AnnData inputs
245-
# is stalign, which returns an ``ObsDisplacement`` whose ``deltas`` are
246-
# already in the same convention as ``obsm['spatial']`` - so the addition
247-
# below is correct for the displacement path. The first backend that
248-
# returns an ``AffineTransform`` for an AnnData input will need to either
249-
# swap here, or normalise the matrix to xy at the backend boundary.
250268
new_coords = result.transform.apply(src_coords)
251269

252270
# Slim copy: share X/var/obs structurally and only rewrite obsm so we

src/squidpy/experimental/tl/_align/_stalign/__init__.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

src/squidpy/experimental/tl/_align/_types.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,16 @@ class AlignPair:
4444

4545
@dataclass(frozen=True)
4646
class AffineTransform:
47-
"""A ``(3, 3)`` homogeneous affine in ``(y, x)`` convention."""
47+
"""A ``(3, 3)`` homogeneous affine in ``(x, y)`` convention.
48+
49+
This matches the coordinate axis order spatialdata uses for points -
50+
``spatialdata.transformations.get_transformation_between_landmarks``
51+
asserts ``axes == ("x", "y")`` - and the order squidpy / scanpy use for
52+
``adata.obsm["spatial"]``. Image elements are stored ``(c, y, x)`` in
53+
spatialdata, so when registering an ``AffineTransform`` on an *image*
54+
element you may need a separate matrix; this skeleton currently only
55+
deals with point coordinates.
56+
"""
4857

4958
matrix: np.ndarray
5059
source_cs: str | None = None
@@ -60,24 +69,25 @@ def to_spatialdata(self) -> Affine:
6069

6170
return Affine(
6271
self.matrix,
63-
input_axes=("y", "x"),
64-
output_axes=("y", "x"),
72+
input_axes=("x", "y"),
73+
output_axes=("x", "y"),
6574
)
6675

6776
def apply(self, coords: np.ndarray) -> np.ndarray:
68-
"""Apply the affine to an ``(N, 2)`` ``(y, x)`` coordinate array."""
77+
"""Apply the affine to an ``(N, 2)`` ``(x, y)`` coordinate array."""
6978
if coords.ndim != 2 or coords.shape[1] != 2:
7079
raise ValueError(f"Expected an (N, 2) coordinate array, got shape {coords.shape}.")
7180
return coords @ self.matrix[:2, :2].T + self.matrix[:2, 2]
7281

7382

7483
@dataclass(frozen=True)
7584
class ObsDisplacement:
76-
"""Per-obs ``(N, 2)`` displacement field.
85+
"""Per-obs ``(N, 2)`` ``(x, y)`` displacement field.
7786
7887
Used by non-affine fits (e.g. LDDMM) where a single matrix cannot
7988
represent the deformation. Displacements are added to the source
80-
observation coordinates to obtain the aligned coordinates.
89+
observation coordinates (also ``(x, y)``) to obtain the aligned
90+
coordinates.
8191
"""
8292

8393
deltas: np.ndarray

src/squidpy/experimental/tl/_align/_validation.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,14 @@ def validate_landmarks(
6262
Parameters
6363
----------
6464
landmarks_ref, landmarks_query
65-
Sequences of ``(y, x)`` tuples. Must have the same length.
65+
Sequences of ``(x, y)`` tuples. Must have the same length and at
66+
least 3 entries (the closed-form solvers under both ``model``
67+
choices need at least 3 corresponding points).
6668
model
67-
``"similarity"`` (≥ 2 landmarks) or ``"affine"`` (≥ 3 landmarks).
69+
``"similarity"`` (4 DOF, via spatialdata) or ``"affine"`` (6 DOF,
70+
via skimage's least-squares estimator).
6871
cs_ref_extent, cs_query_extent
69-
Optional ``(y_min, x_min, y_max, x_max)`` bounds of the named
72+
Optional ``(x_min, y_min, x_max, y_max)`` bounds of the named
7073
coordinate system at the requested scale. When provided, every
7174
landmark must fall inside. Catches the "I extracted these from
7275
scale0 but asked for scale2" footgun.
@@ -75,17 +78,18 @@ def validate_landmarks(
7578
query = np.asarray(landmarks_query, dtype=float)
7679

7780
if ref.ndim != 2 or ref.shape[1] != 2:
78-
raise ValueError(f"`landmarks_ref` must be a sequence of (y, x) pairs, got shape {ref.shape}.")
81+
raise ValueError(f"`landmarks_ref` must be a sequence of (x, y) pairs, got shape {ref.shape}.")
7982
if query.ndim != 2 or query.shape[1] != 2:
80-
raise ValueError(f"`landmarks_query` must be a sequence of (y, x) pairs, got shape {query.shape}.")
83+
raise ValueError(f"`landmarks_query` must be a sequence of (x, y) pairs, got shape {query.shape}.")
8184
if len(ref) != len(query):
8285
raise ValueError(
8386
f"`landmarks_ref` and `landmarks_query` must have the same length; got {len(ref)} and {len(query)}."
8487
)
8588

86-
min_landmarks = 3 if model == "affine" else 2
87-
if len(ref) < min_landmarks:
88-
raise ValueError(f"`model={model!r}` needs at least {min_landmarks} landmark pairs, got {len(ref)}.")
89+
if len(ref) < 3:
90+
raise ValueError(
91+
f"`model={model!r}` needs at least 3 landmark pairs (spatialdata requirement), got {len(ref)}."
92+
)
8993

9094
if cs_ref_extent is not None:
9195
_check_in_extent(ref, cs_ref_extent, name="landmarks_ref")
@@ -101,13 +105,13 @@ def _check_in_extent(
101105
*,
102106
name: str,
103107
) -> None:
104-
y_min, x_min, y_max, x_max = extent
105-
out_of_bounds = (points[:, 0] < y_min) | (points[:, 0] > y_max) | (points[:, 1] < x_min) | (points[:, 1] > x_max)
108+
x_min, y_min, x_max, y_max = extent
109+
out_of_bounds = (points[:, 0] < x_min) | (points[:, 0] > x_max) | (points[:, 1] < y_min) | (points[:, 1] > y_max)
106110
if out_of_bounds.any():
107111
bad = points[out_of_bounds]
108112
raise ValueError(
109113
f"{name}: {int(out_of_bounds.sum())} landmark(s) fall outside the coordinate-system "
110-
f"extent (y in [{y_min}, {y_max}], x in [{x_min}, {x_max}]). "
114+
f"extent (x in [{x_min}, {x_max}], y in [{y_min}, {y_max}]). "
111115
f"This usually means the landmarks were extracted at a different scale than the "
112116
f"one requested. First out-of-bounds point: {tuple(bad[0])}."
113117
)

0 commit comments

Comments
 (0)