Skip to content

Commit 8ceff40

Browse files
authored
Merge pull request #212 from PytorchConnectomics/worktree-bridge-cse_01VmZ5vBAYFT9vQkZ5YsMSky
lsd: bounding-box accumulation for ~19x faster LSD target computation
2 parents 9eb6ac2 + 7be7853 commit 8ceff40

2 files changed

Lines changed: 366 additions & 56 deletions

File tree

connectomics/data/processing/lsd.py

Lines changed: 139 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@
2121

2222
from __future__ import annotations
2323

24-
from typing import Iterable, Optional, Sequence, Union
24+
from typing import Any, Iterable, Optional, Sequence, Union, cast
2525

2626
import numpy as np
2727
from numpy.lib.stride_tricks import as_strided
28-
from scipy.ndimage import convolve, gaussian_filter
28+
from scipy.ndimage import convolve, find_objects, gaussian_filter
2929

3030
__all__ = ["LsdExtractor", "seg_to_lsd"]
3131

32+
TRUNCATE = 3.0
33+
3234

3335
def seg_to_lsd(
3436
label: np.ndarray,
@@ -67,12 +69,10 @@ def seg_to_lsd(
6769
def _coerce_sigma(sigma: Union[float, Sequence[float]], ndim: int) -> tuple:
6870
"""Broadcast a scalar sigma into per-axis tuple matching ``ndim``."""
6971
if np.isscalar(sigma):
70-
return tuple(float(sigma) for _ in range(ndim))
71-
sigma_tuple = tuple(float(v) for v in sigma)
72+
return tuple(float(cast(Any, sigma)) for _ in range(ndim))
73+
sigma_tuple = tuple(float(v) for v in cast(Sequence[float], sigma))
7274
if len(sigma_tuple) != ndim:
73-
raise ValueError(
74-
f"sigma length {len(sigma_tuple)} does not match label dim {ndim}"
75-
)
75+
raise ValueError(f"sigma length {len(sigma_tuple)} does not match label dim {ndim}")
7676
return sigma_tuple
7777

7878

@@ -113,11 +113,13 @@ def get_descriptors(
113113
# Trim to the 2D sigma if a 3D one was supplied.
114114
self.sigma = self.sigma[:2]
115115

116-
voxel_size_t = tuple(1 for _ in range(dims)) if voxel_size is None else tuple(int(v) for v in voxel_size)
116+
voxel_size_t = (
117+
tuple(1 for _ in range(dims))
118+
if voxel_size is None
119+
else tuple(int(v) for v in voxel_size)
120+
)
117121
if len(voxel_size_t) != dims:
118-
raise ValueError(
119-
f"voxel_size length {len(voxel_size_t)} != label dim {dims}"
120-
)
122+
raise ValueError(f"voxel_size length {len(voxel_size_t)} != label dim {dims}")
121123

122124
if labels is None:
123125
labels_arr = np.unique(segmentation)
@@ -137,10 +139,61 @@ def get_descriptors(
137139
f"segmentation shape {segmentation.shape} is not divisible by "
138140
f"downsample factor {df}"
139141
)
140-
sub_shape = tuple(s // df for s in segmentation.shape)
141142
sub_voxel_size = tuple(v * df for v in voxel_size_t)
142143
sub_sigma_voxel = tuple(s / v for s, v in zip(self.sigma, sub_voxel_size))
143144

145+
if df == 1:
146+
self._accumulate_bbox(
147+
descriptors,
148+
segmentation,
149+
labels_arr,
150+
sub_sigma_voxel,
151+
sub_voxel_size,
152+
components,
153+
dims,
154+
)
155+
else:
156+
self._accumulate_full(
157+
descriptors,
158+
segmentation,
159+
labels_arr,
160+
sub_sigma_voxel,
161+
sub_voxel_size,
162+
components,
163+
df,
164+
dims,
165+
)
166+
167+
# Normalize to [0, 1]: mean offsets and Pearson coefficients have signed
168+
# ranges that we shift into [0, 1] for prediction.
169+
if self.mode == "gaussian":
170+
# Farthest weighted voxel ≈ sigma (3-sigma cap is rarely reached).
171+
max_distance = np.asarray(self.sigma, dtype=np.float32)
172+
else: # sphere
173+
max_distance = np.asarray([0.5 * s for s in self.sigma], dtype=np.float32)
174+
175+
seg_mask = (segmentation != 0).astype(np.float32)
176+
177+
if dims == 3:
178+
self._normalize_3d(descriptors, max_distance, seg_mask, components)
179+
else:
180+
self._normalize_2d(descriptors, max_distance, seg_mask, components)
181+
182+
np.clip(descriptors, 0.0, 1.0, out=descriptors)
183+
return descriptors
184+
185+
def _accumulate_full(
186+
self,
187+
descriptors: np.ndarray,
188+
segmentation: np.ndarray,
189+
labels_arr: np.ndarray,
190+
sub_sigma_voxel: tuple,
191+
sub_voxel_size: tuple,
192+
components: Optional[str],
193+
df: int,
194+
dims: int,
195+
) -> None:
196+
sub_shape = tuple(s // df for s in segmentation.shape)
144197
coords = self._get_or_build_coords(sub_shape, sub_voxel_size)
145198

146199
for raw_label in labels_arr:
@@ -162,23 +215,69 @@ def get_descriptors(
162215
descriptor = self._upsample(sub_descriptor, df)
163216
descriptors += descriptor * mask
164217

165-
# Normalize to [0, 1]: mean offsets and Pearson coefficients have signed
166-
# ranges that we shift into [0, 1] for prediction.
167-
if self.mode == "gaussian":
168-
# Farthest weighted voxel ≈ sigma (3-sigma cap is rarely reached).
169-
max_distance = np.asarray(self.sigma, dtype=np.float32)
170-
else: # sphere
171-
max_distance = np.asarray([0.5 * s for s in self.sigma], dtype=np.float32)
172-
173-
seg_mask = (segmentation != 0).astype(np.float32)
174-
175-
if dims == 3:
176-
self._normalize_3d(descriptors, max_distance, seg_mask, components)
177-
else:
178-
self._normalize_2d(descriptors, max_distance, seg_mask, components)
218+
def _accumulate_bbox(
219+
self,
220+
descriptors: np.ndarray,
221+
segmentation: np.ndarray,
222+
labels_arr: np.ndarray,
223+
sub_sigma_voxel: tuple,
224+
sub_voxel_size: tuple,
225+
components: Optional[str],
226+
dims: int,
227+
) -> None:
228+
present = [int(raw_label) for raw_label in labels_arr if int(raw_label) != 0]
229+
if not present:
230+
return
179231

180-
np.clip(descriptors, 0.0, 1.0, out=descriptors)
181-
return descriptors
232+
radius = tuple(int(np.ceil(TRUNCATE * sigma)) for sigma in sub_sigma_voxel)
233+
max_label = int(segmentation.max())
234+
use_find_objects = (
235+
np.issubdtype(segmentation.dtype, np.integer)
236+
and max_label >= 1
237+
and max_label <= max(64, 8 * len(present))
238+
)
239+
objects = find_objects(segmentation) if use_find_objects else None
240+
241+
for label in present:
242+
bbox = None
243+
if objects is not None:
244+
if 1 <= label <= len(objects):
245+
bbox = objects[label - 1]
246+
if bbox is None:
247+
continue
248+
else:
249+
eq = segmentation == label
250+
if not np.any(eq):
251+
continue
252+
slices: list[slice] = []
253+
for axis in range(dims):
254+
other_axes = tuple(d for d in range(dims) if d != axis)
255+
occupied = np.where(eq.any(axis=other_axes))[0]
256+
if occupied.size == 0:
257+
slices = []
258+
break
259+
slices.append(slice(int(occupied[0]), int(occupied[-1]) + 1))
260+
if not slices:
261+
continue
262+
bbox = tuple(slices)
263+
264+
crop = tuple(
265+
slice(
266+
max(0, bbox[d].start - radius[d]),
267+
min(segmentation.shape[d], bbox[d].stop + radius[d]),
268+
)
269+
for d in range(dims)
270+
)
271+
sub = segmentation[crop]
272+
mask = (sub == label).astype(np.float32)
273+
coords_local = self._get_or_build_coords(mask.shape, sub_voxel_size)
274+
offset = np.asarray(
275+
[crop[d].start * sub_voxel_size[d] for d in range(dims)],
276+
dtype=np.float32,
277+
).reshape((dims,) + (1,) * dims)
278+
coords_local = coords_local + offset
279+
desc = np.concatenate(self._get_stats(coords_local, mask, sub_sigma_voxel, components))
280+
descriptors[(slice(None),) + crop] += desc * mask[None]
182281

183282
def _get_or_build_coords(self, sub_shape: tuple, sub_voxel_size: tuple) -> np.ndarray:
184283
key = (sub_shape, sub_voxel_size)
@@ -209,14 +308,10 @@ def _get_stats(
209308
count = np.where(count == 0, 1.0, count)
210309

211310
# Mean (center-of-mass per voxel) along each axis.
212-
mean = np.stack(
213-
[self._aggregate(masked_coords[d], sigma_voxel) for d in range(count_len)]
214-
)
311+
mean = np.stack([self._aggregate(masked_coords[d], sigma_voxel) for d in range(count_len)])
215312
mean = mean / count
216313

217-
need_mean_offset = components is None or any(
218-
str(c) in components for c in range(count_len)
219-
)
314+
need_mean_offset = components is None or any(str(c) in components for c in range(count_len))
220315
need_cov = components is None or any(
221316
str(c) in components for c in range(count_len, 4 * count_len - 3)
222317
)
@@ -229,9 +324,7 @@ def _get_stats(
229324
if need_cov:
230325
coords_outer = self._outer_product(masked_coords)
231326
entries = [0, 4, 8, 1, 2, 5] if count_len == 3 else [0, 3, 1]
232-
covariance = np.stack(
233-
[self._aggregate(coords_outer[d], sigma_voxel) for d in entries]
234-
)
327+
covariance = np.stack([self._aggregate(coords_outer[d], sigma_voxel) for d in entries])
235328
covariance = covariance / count
236329
covariance -= self._outer_product(mean)[entries]
237330

@@ -275,9 +368,7 @@ def _get_stats(
275368
elif i == 9:
276369
ret.append(count[None, :])
277370
else:
278-
raise ValueError(
279-
f"3D LSD components must be in 0..9, got {i}"
280-
)
371+
raise ValueError(f"3D LSD components must be in 0..9, got {i}")
281372
else: # 2D
282373
if 0 <= i < 2:
283374
ret.append(mean_offset[[i]])
@@ -288,16 +379,12 @@ def _get_stats(
288379
elif i == 5:
289380
ret.append(count[None, :])
290381
else:
291-
raise ValueError(
292-
f"2D LSD components must be in 0..5, got {i}"
293-
)
382+
raise ValueError(f"2D LSD components must be in 0..5, got {i}")
294383
return tuple(ret)
295384

296385
def _aggregate(self, array: np.ndarray, sigma: tuple) -> np.ndarray:
297386
if self.mode == "gaussian":
298-
return gaussian_filter(
299-
array, sigma=sigma, mode="constant", cval=0.0, truncate=3.0
300-
)
387+
return gaussian_filter(array, sigma=sigma, mode="constant", cval=0.0, truncate=TRUNCATE)
301388
radius = sigma[0]
302389
if any(s != radius for s in sigma):
303390
raise ValueError("mode='sphere' requires isotropic sigma")
@@ -306,7 +393,7 @@ def _aggregate(self, array: np.ndarray, sigma: tuple) -> np.ndarray:
306393

307394
@staticmethod
308395
def _make_sphere(radius: int) -> np.ndarray:
309-
r2 = np.arange(-radius, radius) ** 2
396+
r2: np.ndarray = np.arange(-radius, radius) ** 2
310397
dist2 = r2[:, None, None] + r2[:, None] + r2
311398
return (dist2 <= radius**2).astype(np.float32)
312399

@@ -323,6 +410,8 @@ def _upsample(array: np.ndarray, factor: int) -> np.ndarray:
323410
return array
324411
shape = array.shape
325412
stride = array.strides
413+
sh: tuple[int, ...]
414+
st: tuple[int, ...]
326415
if array.ndim == 4:
327416
sh = (shape[0], shape[1], factor, shape[2], factor, shape[3], factor)
328417
st = (stride[0], stride[1], 0, stride[2], 0, stride[3], 0)
@@ -350,9 +439,7 @@ def _normalize_3d(
350439
for slot, token in enumerate(components):
351440
c = int(token)
352441
if 0 <= c < 3:
353-
descriptors[slot] = (
354-
descriptors[slot] / max_distance[c] * 0.5 + 0.5
355-
) * seg_mask
442+
descriptors[slot] = (descriptors[slot] / max_distance[c] * 0.5 + 0.5) * seg_mask
356443
elif 6 <= c < 9:
357444
descriptors[slot] = (descriptors[slot] * 0.5 + 0.5) * seg_mask
358445

@@ -364,17 +451,13 @@ def _normalize_2d(
364451
components: Optional[str],
365452
) -> None:
366453
if components is None:
367-
descriptors[[0, 1]] = (
368-
descriptors[[0, 1]] / max_distance[:, None, None] * 0.5 + 0.5
369-
)
454+
descriptors[[0, 1]] = descriptors[[0, 1]] / max_distance[:, None, None] * 0.5 + 0.5
370455
descriptors[[4]] = descriptors[[4]] * 0.5 + 0.5
371456
descriptors[[0, 1, 4]] *= seg_mask
372457
return
373458
for slot, token in enumerate(components):
374459
c = int(token)
375460
if 0 <= c < 2:
376-
descriptors[slot] = (
377-
descriptors[slot] / max_distance[c] * 0.5 + 0.5
378-
) * seg_mask
461+
descriptors[slot] = (descriptors[slot] / max_distance[c] * 0.5 + 0.5) * seg_mask
379462
elif c == 4:
380463
descriptors[slot] = (descriptors[slot] * 0.5 + 0.5) * seg_mask

0 commit comments

Comments
 (0)