Skip to content

Commit 1bcd264

Browse files
timtreis and claude committed
Address review: use scikit-image, vectorize TileGrid, remove double ensure_f32_2d
- Replace numba @njit convolutions with skimage.filters.sobel_h/sobel_v (Tenengrad), skimage.filters.laplace (Laplacian variance), and np.var (population variance). This removes the numba dependency from the sharpness metrics.
- Vectorize TileGrid: replace itertools.product loops with NumPy broadcasting for indices/bounds, and use shapely.box for batch polygon creation.
- Remove redundant _ensure_f32_2d calls from the metric-registry lambdas (each metric function already handles its own input validation).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 192ef2f commit 1bcd264

4 files changed

Lines changed: 48 additions & 108 deletions

File tree

src/squidpy/experimental/im/_qc_metrics.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88

99
from squidpy.experimental.im._sharpness_metrics import (
10-
_ensure_f32_2d,
1110
_fft_high_freq_energy,
1211
_haar_wavelet_energy,
1312
_laplacian_variance,
@@ -186,9 +185,9 @@ def _tissue_fraction(block: np.ndarray) -> np.ndarray:
186185

187186
_METRIC_REGISTRY: dict[QCMetric, tuple[InputKind, MetricFn]] = {
188187
# Sharpness (grayscale)
189-
QCMetric.TENENGRAD: (InputKind.GRAYSCALE, lambda a: _tenengrad_mean(_ensure_f32_2d(a))),
190-
QCMetric.VAR_OF_LAPLACIAN: (InputKind.GRAYSCALE, lambda a: _laplacian_variance(_ensure_f32_2d(a))),
191-
QCMetric.VARIANCE: (InputKind.GRAYSCALE, lambda a: _pop_variance(_ensure_f32_2d(a))),
188+
QCMetric.TENENGRAD: (InputKind.GRAYSCALE, _tenengrad_mean),
189+
QCMetric.VAR_OF_LAPLACIAN: (InputKind.GRAYSCALE, _laplacian_variance),
190+
QCMetric.VARIANCE: (InputKind.GRAYSCALE, _pop_variance),
192191
QCMetric.FFT_HIGH_FREQ_ENERGY: (InputKind.GRAYSCALE, _fft_high_freq_energy),
193192
QCMetric.HAAR_WAVELET_ENERGY: (InputKind.GRAYSCALE, _haar_wavelet_energy),
194193
# Intensity (grayscale)

src/squidpy/experimental/im/_sharpness_metrics.py

Lines changed: 14 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
import numpy as np
4-
from numba import njit
54
from scipy.fft import fft2, fftfreq
5+
from skimage.filters import laplace, sobel_h, sobel_v
66

77

88
def _ensure_f32_2d(x: np.ndarray) -> np.ndarray:
@@ -11,79 +11,25 @@ def _ensure_f32_2d(x: np.ndarray) -> np.ndarray:
1111
return np.ascontiguousarray(x.astype(np.float32, copy=False))
1212

1313

14-
@njit(cache=True, fastmath=True)
15-
def _clamp(v: int, lo: int, hi: int) -> int:
16-
if v < lo:
17-
return lo
18-
if v > hi:
19-
return hi
20-
return v
14+
def _tenengrad_mean(block: np.ndarray) -> np.ndarray:
15+
"""Mean Tenengrad energy (sum of squared Sobel gradients)."""
16+
b = _ensure_f32_2d(block)
17+
energy = sobel_h(b) ** 2 + sobel_v(b) ** 2
18+
return np.array([[float(energy.mean())]], dtype=np.float32)
2119

2220

23-
@njit(cache=True, fastmath=True)
24-
def _tenengrad_mean(block: np.ndarray) -> np.ndarray:
25-
"""Mean Tenengrad energy using Sobel 3x3."""
26-
h, w = block.shape
27-
gxk = np.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], dtype=np.float32)
28-
gyk = np.array([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]], dtype=np.float32)
29-
s = 0.0
30-
for i in range(h):
31-
for j in range(w):
32-
gx = 0.0
33-
gy = 0.0
34-
for di in range(-1, 2):
35-
for dj in range(-1, 2):
36-
ii = _clamp(i + di, 0, h - 1)
37-
jj = _clamp(j + dj, 0, w - 1)
38-
v = block[ii, jj]
39-
gx += gxk[di + 1, dj + 1] * v
40-
gy += gyk[di + 1, dj + 1] * v
41-
s += gx * gx + gy * gy
42-
mean_val = s / (h * w)
43-
return np.array([[mean_val]], dtype=np.float32)
44-
45-
46-
@njit(cache=True, fastmath=True)
4721
def _laplacian_variance(block: np.ndarray) -> np.ndarray:
4822
"""Population variance of Laplacian response."""
49-
h, w = block.shape
50-
lk = np.array([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]], dtype=np.float32)
51-
n = h * w
52-
s = 0.0
53-
s2 = 0.0
54-
for i in range(h):
55-
for j in range(w):
56-
y = 0.0
57-
for di in range(-1, 2):
58-
for dj in range(-1, 2):
59-
ii = _clamp(i + di, 0, h - 1)
60-
jj = _clamp(j + dj, 0, w - 1)
61-
y += lk[di + 1, dj + 1] * block[ii, jj]
62-
s += y
63-
s2 += y * y
64-
mean = s / n
65-
# var = E[y^2] - (E[y])^2
66-
var = (s2 / n) - (mean * mean)
67-
var_val = var if var > 0.0 else 0.0
68-
return np.array([[var_val]], dtype=np.float32)
69-
70-
71-
@njit(cache=True, fastmath=True)
23+
b = _ensure_f32_2d(block)
24+
lap = laplace(b)
25+
var_val = float(np.var(lap))
26+
return np.array([[max(var_val, 0.0)]], dtype=np.float32)
27+
28+
7229
def _pop_variance(block: np.ndarray) -> np.ndarray:
7330
"""Population variance of pixel intensities."""
74-
h, w = block.shape
75-
n = h * w
76-
s = 0.0
77-
s2 = 0.0
78-
for i in range(h):
79-
for j in range(w):
80-
v = block[i, j]
81-
s += v
82-
s2 += v * v
83-
mean = s / n
84-
var = (s2 / n) - (mean * mean)
85-
var_val = var if var > 0.0 else 0.0
86-
return np.array([[var_val]], dtype=np.float32)
31+
b = _ensure_f32_2d(block)
32+
return np.array([[float(np.var(b))]], dtype=np.float32)
8733

8834

8935
def _fft_high_freq_energy(block: np.ndarray) -> np.ndarray:

src/squidpy/experimental/im/_utils.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from __future__ import annotations
22

3-
import itertools
43
from typing import TYPE_CHECKING, Any, Literal
54

65
import dask.array as da
76
import geopandas as gpd
87
import numpy as np
98
import xarray as xr
10-
from shapely.geometry import Polygon
9+
from shapely import box
1110
from spatialdata._logging import logger
1211
from spatialdata.models import ShapesModel
1312
from spatialdata.transformations import get_transformation, set_transformation
@@ -46,13 +45,13 @@ def __init__(
4645
total_w_needed = self.W - grid_start_x
4746
self.tiles_y = (total_h_needed + self.ty - 1) // self.ty
4847
self.tiles_x = (total_w_needed + self.tx - 1) // self.tx
49-
# Cache immutable derived values
50-
self._indices = np.array(
51-
[[iy, ix] for iy, ix in itertools.product(range(self.tiles_y), range(self.tiles_x))], dtype=int
52-
)
53-
self._names = [f"tile_x{ix}_y{iy}" for iy, ix in itertools.product(range(self.tiles_y), range(self.tiles_x))]
54-
self._bounds = self._compute_bounds()
55-
self._centroids_polys = self._compute_centroids_and_polygons()
48+
# Cache immutable derived values (vectorized)
49+
iy = np.repeat(np.arange(self.tiles_y), self.tiles_x)
50+
ix = np.tile(np.arange(self.tiles_x), self.tiles_y)
51+
self._indices = np.column_stack([iy, ix])
52+
self._names = [f"tile_x{x}_y{y}" for y, x in zip(iy, ix, strict=True)]
53+
self._bounds = self._compute_bounds(iy, ix)
54+
self._centroids, self._polys = self._compute_centroids_and_polygons()
5655

5756
def indices(self) -> np.ndarray:
5857
return self._indices
@@ -63,33 +62,29 @@ def names(self) -> list[str]:
6362
def bounds(self) -> np.ndarray:
6463
return self._bounds
6564

66-
def _compute_bounds(self) -> np.ndarray:
67-
b: list[list[int]] = []
68-
for iy, ix in itertools.product(range(self.tiles_y), range(self.tiles_x)):
69-
y0 = iy * self.ty + self.offset_y
70-
x0 = ix * self.tx + self.offset_x
71-
y1 = ((iy + 1) * self.ty + self.offset_y) if iy < self.tiles_y - 1 else self.H
72-
x1 = ((ix + 1) * self.tx + self.offset_x) if ix < self.tiles_x - 1 else self.W
73-
# Clamp bounds to image dimensions
74-
y0 = max(0, min(y0, self.H))
75-
x0 = max(0, min(x0, self.W))
76-
y1 = max(0, min(y1, self.H))
77-
x1 = max(0, min(x1, self.W))
78-
b.append([y0, x0, y1, x1])
79-
return np.array(b, dtype=int)
80-
81-
def centroids_and_polygons(self) -> tuple[np.ndarray, list[Polygon]]:
82-
return self._centroids_polys
83-
84-
def _compute_centroids_and_polygons(self) -> tuple[np.ndarray, list[Polygon]]:
85-
cents: list[list[float]] = []
86-
polys: list[Polygon] = []
87-
for y0, x0, y1, x1 in self._bounds:
88-
cy = (y0 + y1) / 2
89-
cx = (x0 + x1) / 2
90-
cents.append([cy, cx])
91-
polys.append(Polygon([(x0, y0), (x1, y0), (x1, y1), (x0, y1), (x0, y0)]))
92-
return np.array(cents, dtype=float), polys
65+
def _compute_bounds(self, iy: np.ndarray, ix: np.ndarray) -> np.ndarray:
66+
y0 = iy * self.ty + self.offset_y
67+
x0 = ix * self.tx + self.offset_x
68+
y1 = (iy + 1) * self.ty + self.offset_y
69+
x1 = (ix + 1) * self.tx + self.offset_x
70+
# Last row/column extends to image edge
71+
y1[iy == self.tiles_y - 1] = self.H
72+
x1[ix == self.tiles_x - 1] = self.W
73+
# Clamp to image dimensions
74+
y0 = np.clip(y0, 0, self.H)
75+
x0 = np.clip(x0, 0, self.W)
76+
y1 = np.clip(y1, 0, self.H)
77+
x1 = np.clip(x1, 0, self.W)
78+
return np.column_stack([y0, x0, y1, x1]).astype(int)
79+
80+
def centroids_and_polygons(self) -> tuple[np.ndarray, list]:
81+
return self._centroids, self._polys
82+
83+
def _compute_centroids_and_polygons(self) -> tuple[np.ndarray, list]:
84+
y0, x0, y1, x1 = self._bounds[:, 0], self._bounds[:, 1], self._bounds[:, 2], self._bounds[:, 3]
85+
centroids = np.column_stack([(y0 + y1) / 2.0, (x0 + x1) / 2.0])
86+
polys = list(box(x0, y0, x1, y1))
87+
return centroids, polys
9388

9489
def rechunk_and_pad(self, arr_yx: da.Array) -> da.Array:
9590
if arr_yx.ndim != 2:
-432 Bytes
Loading

0 commit comments

Comments
 (0)