Skip to content

Commit a5e4dfd

Browse files
timtreis and claude committed
Add qc_image for tile-based QC of histopathology images
Replaces the earlier qc_sharpness prototype with a general-purpose qc_image function that computes tile-based QC metrics on spatial images.

Compute (sq.experimental.im.qc_image):
- Tile-based metrics: sharpness (tenengrad, var_of_laplacian), intensity (brightness, entropy), staining (hematoxylin/eosin via HED deconvolution), and artifact detection (fold fraction, tissue fraction)
- QCMetric enum and registry mapping each metric to its input kind and callable
- Percentile-rank unfocus scoring within tissue tiles for outlier detection
- Preview overlay showing flagged tiles on the image
- Shared utilities in _utils.py: vectorized TileGrid (numpy + shapely.box), mask helpers, and shapes persistence (also used by make_tiles)

Plot (sq.experimental.pl.qc_image):
- Multi-panel summary: spatial view, KDE distribution (tissue vs background), and descriptive statistics per metric

Metrics use scikit-image filters (sobel_h/v, laplace) instead of hand-rolled convolutions, and thread-safe HED caching avoids redundant deconvolution.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4a018a3 commit a5e4dfd

14 files changed

+1305
-153
lines changed

src/squidpy/experimental/im/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,16 @@
77
detect_tissue,
88
)
99
from ._make_tiles import make_tiles, make_tiles_from_spots
10+
from ._qc_image import qc_image
11+
from ._qc_metrics import QCMetric
1012

1113
__all__ = [
1214
"BackgroundDetectionParams",
1315
"FelzenszwalbParams",
16+
"QCMetric",
1417
"WekaParams",
1518
"detect_tissue",
1619
"make_tiles",
1720
"make_tiles_from_spots",
21+
"qc_image",
1822
]

src/squidpy/experimental/im/_make_tiles.py

Lines changed: 17 additions & 152 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,5 @@
11
from __future__ import annotations
22

3-
import itertools
4-
from typing import Literal
5-
6-
import dask.array as da
73
import geopandas as gpd
84
import numpy as np
95
import pandas as pd
@@ -17,101 +13,16 @@
1713
from spatialdata.transformations import get_transformation, set_transformation
1814

1915
from squidpy._utils import _yx_from_shape
20-
21-
from ._utils import _get_element_data
16+
from squidpy.experimental.im._utils import (
17+
TileGrid,
18+
_get_element_data,
19+
_get_mask_materialized,
20+
_save_tile_grid_to_shapes,
21+
)
2222

2323
__all__ = ["make_tiles", "make_tiles_from_spots"]
2424

2525

26-
class _TileGrid:
27-
"""Immutable tile grid definition with cached bounds and centroids."""
28-
29-
def __init__(
30-
self,
31-
H: int,
32-
W: int,
33-
tile_size: Literal["auto"] | tuple[int, int] = "auto",
34-
target_tiles: int = 100,
35-
offset_y: int = 0,
36-
offset_x: int = 0,
37-
):
38-
self.H = H
39-
self.W = W
40-
if tile_size == "auto":
41-
size = max(min(self.H // target_tiles, self.W // target_tiles), 100)
42-
self.ty = int(size)
43-
self.tx = int(size)
44-
else:
45-
self.ty = int(tile_size[0])
46-
self.tx = int(tile_size[1])
47-
self.offset_y = offset_y
48-
self.offset_x = offset_x
49-
# Calculate number of tiles needed to cover entire image, accounting for offset
50-
# The grid starts at offset_y, offset_x (can be negative)
51-
# We need tiles from min(0, offset_y) to at least H
52-
# So total coverage needed is from min(0, offset_y) to H
53-
grid_start_y = min(0, self.offset_y)
54-
grid_start_x = min(0, self.offset_x)
55-
total_h_needed = self.H - grid_start_y
56-
total_w_needed = self.W - grid_start_x
57-
self.tiles_y = (total_h_needed + self.ty - 1) // self.ty
58-
self.tiles_x = (total_w_needed + self.tx - 1) // self.tx
59-
# Cache immutable derived values
60-
self._indices = np.array([[iy, ix] for iy in range(self.tiles_y) for ix in range(self.tiles_x)], dtype=int)
61-
self._names = [f"tile_x{ix}_y{iy}" for iy in range(self.tiles_y) for ix in range(self.tiles_x)]
62-
self._bounds = self._compute_bounds()
63-
self._centroids_polys = self._compute_centroids_and_polygons()
64-
65-
def indices(self) -> np.ndarray:
66-
return self._indices
67-
68-
def names(self) -> list[str]:
69-
return self._names
70-
71-
def bounds(self) -> np.ndarray:
72-
return self._bounds
73-
74-
def _compute_bounds(self) -> np.ndarray:
75-
b: list[list[int]] = []
76-
for iy, ix in itertools.product(range(self.tiles_y), range(self.tiles_x)):
77-
y0 = iy * self.ty + self.offset_y
78-
x0 = ix * self.tx + self.offset_x
79-
y1 = ((iy + 1) * self.ty + self.offset_y) if iy < self.tiles_y - 1 else self.H
80-
x1 = ((ix + 1) * self.tx + self.offset_x) if ix < self.tiles_x - 1 else self.W
81-
# Clamp bounds to image dimensions
82-
y0 = max(0, min(y0, self.H))
83-
x0 = max(0, min(x0, self.W))
84-
y1 = max(0, min(y1, self.H))
85-
x1 = max(0, min(x1, self.W))
86-
b.append([y0, x0, y1, x1])
87-
return np.array(b, dtype=int)
88-
89-
def centroids_and_polygons(self) -> tuple[np.ndarray, list[Polygon]]:
90-
return self._centroids_polys
91-
92-
def _compute_centroids_and_polygons(self) -> tuple[np.ndarray, list[Polygon]]:
93-
cents: list[list[float]] = []
94-
polys: list[Polygon] = []
95-
for y0, x0, y1, x1 in self._bounds:
96-
cy = (y0 + y1) / 2
97-
cx = (x0 + x1) / 2
98-
cents.append([cy, cx])
99-
polys.append(Polygon([(x0, y0), (x1, y0), (x1, y1), (x0, y1), (x0, y0)]))
100-
return np.array(cents, dtype=float), polys
101-
102-
def rechunk_and_pad(self, arr_yx: da.Array) -> da.Array:
103-
if arr_yx.ndim != 2:
104-
raise ValueError("Expected a 2D array shaped (y, x).")
105-
pad_y = self.tiles_y * self.ty - int(arr_yx.shape[0])
106-
pad_x = self.tiles_x * self.tx - int(arr_yx.shape[1])
107-
a = arr_yx.rechunk((self.ty, self.tx))
108-
return da.pad(a, ((0, pad_y), (0, pad_x)), mode="edge") if (pad_y > 0 or pad_x > 0) else a
109-
110-
def coarsen(self, arr_yx: da.Array, reduce: Literal["mean", "sum"] = "mean") -> da.Array:
111-
reducer = np.mean if reduce == "mean" else np.sum
112-
return da.coarsen(reducer, arr_yx, {0: self.ty, 1: self.tx}, trim_excess=False)
113-
114-
11526
class _SpotTileGrid:
11627
"""Tile container for Visium spots, used with ``_filter_tiles``."""
11728

@@ -204,34 +115,12 @@ def _choose_label_scale_for_image(label_node: Labels2DModel, target_hw: tuple[in
204115

205116
def _save_tiles_to_shapes(
206117
sdata: sd.SpatialData,
207-
tg: _TileGrid,
118+
tg: TileGrid,
208119
image_key: str,
209120
shapes_key: str,
210121
) -> None:
211122
"""Save a TileGrid to sdata.shapes as a GeoDataFrame."""
212-
tile_indices = tg.indices()
213-
pixel_bounds = tg.bounds()
214-
_, polys = tg.centroids_and_polygons()
215-
216-
tile_gdf = gpd.GeoDataFrame(
217-
{
218-
"tile_id": tg.names(),
219-
"tile_y": tile_indices[:, 0],
220-
"tile_x": tile_indices[:, 1],
221-
"pixel_y0": pixel_bounds[:, 0],
222-
"pixel_x0": pixel_bounds[:, 1],
223-
"pixel_y1": pixel_bounds[:, 2],
224-
"pixel_x1": pixel_bounds[:, 3],
225-
"geometry": polys,
226-
},
227-
geometry="geometry",
228-
)
229-
230-
sdata.shapes[shapes_key] = ShapesModel.parse(tile_gdf)
231-
# we know that a) the element exists and b) it has at least an Identity transformation
232-
transformations = get_transformation(sdata.images[image_key], get_all=True)
233-
set_transformation(sdata.shapes[shapes_key], transformations, set_all=True)
234-
logger.info(f"Saved tile grid as 'sdata.shapes[\"{shapes_key}\"]'")
123+
_save_tile_grid_to_shapes(sdata, tg, shapes_key, copy_transforms_from_key=image_key)
235124

236125

237126
def _save_spot_tiles_to_shapes(
@@ -366,7 +255,7 @@ def make_tiles(
366255
mask_key_for_grid = default_mask_key
367256
else:
368257
try:
369-
from ._detect_tissue import detect_tissue
258+
from squidpy.experimental.im._detect_tissue import detect_tissue
370259

371260
detect_tissue(
372261
sdata,
@@ -411,7 +300,7 @@ def make_tiles(
411300
classification_mask_key,
412301
)
413302
try:
414-
from ._detect_tissue import detect_tissue
303+
from squidpy.experimental.im._detect_tissue import detect_tissue
415304

416305
detect_tissue(
417306
sdata,
@@ -558,7 +447,7 @@ def make_tiles_from_spots(
558447
classification_mask_key,
559448
)
560449
try:
561-
from ._detect_tissue import detect_tissue
450+
from squidpy.experimental.im._detect_tissue import detect_tissue
562451

563452
detect_tissue(
564453
sdata,
@@ -633,7 +522,7 @@ def make_tiles_from_spots(
633522

634523
def _filter_tiles(
635524
sdata: sd.SpatialData,
636-
tg: _TileGrid,
525+
tg: TileGrid,
637526
image_key: str | None,
638527
*,
639528
tissue_mask_key: str | None = None,
@@ -686,7 +575,7 @@ def _filter_tiles(
686575
raise ValueError("tissue_mask_key must be provided when image_key is None.")
687576
if mask_key not in sdata.labels:
688577
raise KeyError(f"Tissue mask '{mask_key}' not found in sdata.labels.")
689-
mask = _get_mask_from_labels(sdata, mask_key, scale)
578+
mask = _get_mask_materialized(sdata, mask_key, scale)
690579
H_mask, W_mask = mask.shape
691580

692581
# Check tissue coverage for each tile
@@ -751,7 +640,7 @@ def _make_tiles(
751640
tile_size: tuple[int, int] = (224, 224),
752641
center_grid_on_tissue: bool = False,
753642
scale: str = "auto",
754-
) -> _TileGrid:
643+
) -> TileGrid:
755644
"""Construct a tile grid for an image, optionally centered on a tissue mask."""
756645
# Validate image key
757646
if image_key not in sdata.images:
@@ -764,7 +653,7 @@ def _make_tiles(
764653

765654
# Path 1: Regular grid starting from top-left
766655
if not center_grid_on_tissue or image_mask_key is None:
767-
return _TileGrid(H, W, tile_size=tile_size)
656+
return TileGrid(H, W, tile_size=tile_size)
768657

769658
# Path 2: Center grid on tissue mask centroid
770659
if image_mask_key not in sdata.labels:
@@ -806,7 +695,7 @@ def _make_tiles(
806695
mask_bool = mask > 0
807696
if not mask_bool.any():
808697
logger.warning("Mask is empty. Using regular grid starting from top-left.")
809-
return _TileGrid(H, W, tile_size=tile_size)
698+
return TileGrid(H, W, tile_size=tile_size)
810699

811700
# Calculate centroid using center of mass
812701
y_coords, x_coords = np.where(mask_bool)
@@ -821,7 +710,7 @@ def _make_tiles(
821710
offset_y = int(round(centroid_y - tile_center_y_standard))
822711
offset_x = int(round(centroid_x - tile_center_x_standard))
823712

824-
return _TileGrid(H, W, tile_size=tile_size, offset_y=offset_y, offset_x=offset_x)
713+
return TileGrid(H, W, tile_size=tile_size, offset_y=offset_y, offset_x=offset_x)
825714

826715

827716
def _get_spot_coordinates(
@@ -877,27 +766,3 @@ def _derive_tile_size_from_spots(coords: np.ndarray) -> tuple[int, int]:
877766
)
878767
side = max(1, int(np.floor(row_spacing)))
879768
return side, side
880-
881-
882-
def _get_mask_from_labels(sdata: sd.SpatialData, mask_key: str, scale: str) -> np.ndarray:
883-
"""Extract a 2D mask array from ``sdata.labels`` at the requested scale."""
884-
if mask_key not in sdata.labels:
885-
raise KeyError(f"Mask key '{mask_key}' not found in sdata.labels")
886-
887-
label_node = sdata.labels[mask_key]
888-
mask_da = _get_element_data(label_node, scale, "label", mask_key)
889-
890-
if is_dask_collection(mask_da):
891-
mask_da = mask_da.compute()
892-
893-
if isinstance(mask_da, xr.DataArray):
894-
mask = np.asarray(mask_da.data)
895-
else:
896-
mask = np.asarray(mask_da)
897-
898-
if mask.ndim > 2:
899-
mask = mask.squeeze()
900-
if mask.ndim != 2:
901-
raise ValueError(f"Expected 2D mask with shape (y, x), got shape {mask.shape}")
902-
903-
return mask

0 commit comments

Comments
 (0)