2121
2222from __future__ import annotations
2323
24- from typing import Iterable , Optional , Sequence , Union
24+ from typing import Any , Iterable , Optional , Sequence , Union , cast
2525
2626import numpy as np
2727from numpy .lib .stride_tricks import as_strided
28- from scipy .ndimage import convolve , gaussian_filter
28+ from scipy .ndimage import convolve , find_objects , gaussian_filter
2929
3030__all__ = ["LsdExtractor" , "seg_to_lsd" ]
3131
32+ TRUNCATE = 3.0
33+
3234
3335def seg_to_lsd (
3436 label : np .ndarray ,
@@ -67,12 +69,10 @@ def seg_to_lsd(
6769def _coerce_sigma (sigma : Union [float , Sequence [float ]], ndim : int ) -> tuple :
6870 """Broadcast a scalar sigma into per-axis tuple matching ``ndim``."""
6971 if np .isscalar (sigma ):
70- return tuple (float (sigma ) for _ in range (ndim ))
71- sigma_tuple = tuple (float (v ) for v in sigma )
72+ return tuple (float (cast ( Any , sigma ) ) for _ in range (ndim ))
73+ sigma_tuple = tuple (float (v ) for v in cast ( Sequence [ float ], sigma ) )
7274 if len (sigma_tuple ) != ndim :
73- raise ValueError (
74- f"sigma length { len (sigma_tuple )} does not match label dim { ndim } "
75- )
75+ raise ValueError (f"sigma length { len (sigma_tuple )} does not match label dim { ndim } " )
7676 return sigma_tuple
7777
7878
@@ -113,11 +113,13 @@ def get_descriptors(
113113 # Trim to the 2D sigma if a 3D one was supplied.
114114 self .sigma = self .sigma [:2 ]
115115
116- voxel_size_t = tuple (1 for _ in range (dims )) if voxel_size is None else tuple (int (v ) for v in voxel_size )
116+ voxel_size_t = (
117+ tuple (1 for _ in range (dims ))
118+ if voxel_size is None
119+ else tuple (int (v ) for v in voxel_size )
120+ )
117121 if len (voxel_size_t ) != dims :
118- raise ValueError (
119- f"voxel_size length { len (voxel_size_t )} != label dim { dims } "
120- )
122+ raise ValueError (f"voxel_size length { len (voxel_size_t )} != label dim { dims } " )
121123
122124 if labels is None :
123125 labels_arr = np .unique (segmentation )
@@ -137,10 +139,61 @@ def get_descriptors(
137139 f"segmentation shape { segmentation .shape } is not divisible by "
138140 f"downsample factor { df } "
139141 )
140- sub_shape = tuple (s // df for s in segmentation .shape )
141142 sub_voxel_size = tuple (v * df for v in voxel_size_t )
142143 sub_sigma_voxel = tuple (s / v for s , v in zip (self .sigma , sub_voxel_size ))
143144
145+ if df == 1 :
146+ self ._accumulate_bbox (
147+ descriptors ,
148+ segmentation ,
149+ labels_arr ,
150+ sub_sigma_voxel ,
151+ sub_voxel_size ,
152+ components ,
153+ dims ,
154+ )
155+ else :
156+ self ._accumulate_full (
157+ descriptors ,
158+ segmentation ,
159+ labels_arr ,
160+ sub_sigma_voxel ,
161+ sub_voxel_size ,
162+ components ,
163+ df ,
164+ dims ,
165+ )
166+
167+ # Normalize to [0, 1]: mean offsets and Pearson coefficients have signed
168+ # ranges that we shift into [0, 1] for prediction.
169+ if self .mode == "gaussian" :
170+ # Farthest weighted voxel ≈ sigma (3-sigma cap is rarely reached).
171+ max_distance = np .asarray (self .sigma , dtype = np .float32 )
172+ else : # sphere
173+ max_distance = np .asarray ([0.5 * s for s in self .sigma ], dtype = np .float32 )
174+
175+ seg_mask = (segmentation != 0 ).astype (np .float32 )
176+
177+ if dims == 3 :
178+ self ._normalize_3d (descriptors , max_distance , seg_mask , components )
179+ else :
180+ self ._normalize_2d (descriptors , max_distance , seg_mask , components )
181+
182+ np .clip (descriptors , 0.0 , 1.0 , out = descriptors )
183+ return descriptors
184+
185+ def _accumulate_full (
186+ self ,
187+ descriptors : np .ndarray ,
188+ segmentation : np .ndarray ,
189+ labels_arr : np .ndarray ,
190+ sub_sigma_voxel : tuple ,
191+ sub_voxel_size : tuple ,
192+ components : Optional [str ],
193+ df : int ,
194+ dims : int ,
195+ ) -> None :
196+ sub_shape = tuple (s // df for s in segmentation .shape )
144197 coords = self ._get_or_build_coords (sub_shape , sub_voxel_size )
145198
146199 for raw_label in labels_arr :
@@ -162,23 +215,69 @@ def get_descriptors(
162215 descriptor = self ._upsample (sub_descriptor , df )
163216 descriptors += descriptor * mask
164217
165- # Normalize to [0, 1]: mean offsets and Pearson coefficients have signed
166- # ranges that we shift into [0, 1] for prediction.
167- if self .mode == "gaussian" :
168- # Farthest weighted voxel ≈ sigma (3-sigma cap is rarely reached).
169- max_distance = np .asarray (self .sigma , dtype = np .float32 )
170- else : # sphere
171- max_distance = np .asarray ([0.5 * s for s in self .sigma ], dtype = np .float32 )
172-
173- seg_mask = (segmentation != 0 ).astype (np .float32 )
174-
175- if dims == 3 :
176- self ._normalize_3d (descriptors , max_distance , seg_mask , components )
177- else :
178- self ._normalize_2d (descriptors , max_distance , seg_mask , components )
218+ def _accumulate_bbox (
219+ self ,
220+ descriptors : np .ndarray ,
221+ segmentation : np .ndarray ,
222+ labels_arr : np .ndarray ,
223+ sub_sigma_voxel : tuple ,
224+ sub_voxel_size : tuple ,
225+ components : Optional [str ],
226+ dims : int ,
227+ ) -> None :
228+ present = [int (raw_label ) for raw_label in labels_arr if int (raw_label ) != 0 ]
229+ if not present :
230+ return
179231
180- np .clip (descriptors , 0.0 , 1.0 , out = descriptors )
181- return descriptors
232+ radius = tuple (int (np .ceil (TRUNCATE * sigma )) for sigma in sub_sigma_voxel )
233+ max_label = int (segmentation .max ())
234+ use_find_objects = (
235+ np .issubdtype (segmentation .dtype , np .integer )
236+ and max_label >= 1
237+ and max_label <= max (64 , 8 * len (present ))
238+ )
239+ objects = find_objects (segmentation ) if use_find_objects else None
240+
241+ for label in present :
242+ bbox = None
243+ if objects is not None :
244+ if 1 <= label <= len (objects ):
245+ bbox = objects [label - 1 ]
246+ if bbox is None :
247+ continue
248+ else :
249+ eq = segmentation == label
250+ if not np .any (eq ):
251+ continue
252+ slices : list [slice ] = []
253+ for axis in range (dims ):
254+ other_axes = tuple (d for d in range (dims ) if d != axis )
255+ occupied = np .where (eq .any (axis = other_axes ))[0 ]
256+ if occupied .size == 0 :
257+ slices = []
258+ break
259+ slices .append (slice (int (occupied [0 ]), int (occupied [- 1 ]) + 1 ))
260+ if not slices :
261+ continue
262+ bbox = tuple (slices )
263+
264+ crop = tuple (
265+ slice (
266+ max (0 , bbox [d ].start - radius [d ]),
267+ min (segmentation .shape [d ], bbox [d ].stop + radius [d ]),
268+ )
269+ for d in range (dims )
270+ )
271+ sub = segmentation [crop ]
272+ mask = (sub == label ).astype (np .float32 )
273+ coords_local = self ._get_or_build_coords (mask .shape , sub_voxel_size )
274+ offset = np .asarray (
275+ [crop [d ].start * sub_voxel_size [d ] for d in range (dims )],
276+ dtype = np .float32 ,
277+ ).reshape ((dims ,) + (1 ,) * dims )
278+ coords_local = coords_local + offset
279+ desc = np .concatenate (self ._get_stats (coords_local , mask , sub_sigma_voxel , components ))
280+ descriptors [(slice (None ),) + crop ] += desc * mask [None ]
182281
183282 def _get_or_build_coords (self , sub_shape : tuple , sub_voxel_size : tuple ) -> np .ndarray :
184283 key = (sub_shape , sub_voxel_size )
@@ -209,14 +308,10 @@ def _get_stats(
209308 count = np .where (count == 0 , 1.0 , count )
210309
211310 # Mean (center-of-mass per voxel) along each axis.
212- mean = np .stack (
213- [self ._aggregate (masked_coords [d ], sigma_voxel ) for d in range (count_len )]
214- )
311+ mean = np .stack ([self ._aggregate (masked_coords [d ], sigma_voxel ) for d in range (count_len )])
215312 mean = mean / count
216313
217- need_mean_offset = components is None or any (
218- str (c ) in components for c in range (count_len )
219- )
314+ need_mean_offset = components is None or any (str (c ) in components for c in range (count_len ))
220315 need_cov = components is None or any (
221316 str (c ) in components for c in range (count_len , 4 * count_len - 3 )
222317 )
@@ -229,9 +324,7 @@ def _get_stats(
229324 if need_cov :
230325 coords_outer = self ._outer_product (masked_coords )
231326 entries = [0 , 4 , 8 , 1 , 2 , 5 ] if count_len == 3 else [0 , 3 , 1 ]
232- covariance = np .stack (
233- [self ._aggregate (coords_outer [d ], sigma_voxel ) for d in entries ]
234- )
327+ covariance = np .stack ([self ._aggregate (coords_outer [d ], sigma_voxel ) for d in entries ])
235328 covariance = covariance / count
236329 covariance -= self ._outer_product (mean )[entries ]
237330
@@ -275,9 +368,7 @@ def _get_stats(
275368 elif i == 9 :
276369 ret .append (count [None , :])
277370 else :
278- raise ValueError (
279- f"3D LSD components must be in 0..9, got { i } "
280- )
371+ raise ValueError (f"3D LSD components must be in 0..9, got { i } " )
281372 else : # 2D
282373 if 0 <= i < 2 :
283374 ret .append (mean_offset [[i ]])
@@ -288,16 +379,12 @@ def _get_stats(
288379 elif i == 5 :
289380 ret .append (count [None , :])
290381 else :
291- raise ValueError (
292- f"2D LSD components must be in 0..5, got { i } "
293- )
382+ raise ValueError (f"2D LSD components must be in 0..5, got { i } " )
294383 return tuple (ret )
295384
296385 def _aggregate (self , array : np .ndarray , sigma : tuple ) -> np .ndarray :
297386 if self .mode == "gaussian" :
298- return gaussian_filter (
299- array , sigma = sigma , mode = "constant" , cval = 0.0 , truncate = 3.0
300- )
387+ return gaussian_filter (array , sigma = sigma , mode = "constant" , cval = 0.0 , truncate = TRUNCATE )
301388 radius = sigma [0 ]
302389 if any (s != radius for s in sigma ):
303390 raise ValueError ("mode='sphere' requires isotropic sigma" )
@@ -306,7 +393,7 @@ def _aggregate(self, array: np.ndarray, sigma: tuple) -> np.ndarray:
306393
307394 @staticmethod
308395 def _make_sphere (radius : int ) -> np .ndarray :
309- r2 = np .arange (- radius , radius ) ** 2
396+ r2 : np . ndarray = np .arange (- radius , radius ) ** 2
310397 dist2 = r2 [:, None , None ] + r2 [:, None ] + r2
311398 return (dist2 <= radius ** 2 ).astype (np .float32 )
312399
@@ -323,6 +410,8 @@ def _upsample(array: np.ndarray, factor: int) -> np.ndarray:
323410 return array
324411 shape = array .shape
325412 stride = array .strides
413+ sh : tuple [int , ...]
414+ st : tuple [int , ...]
326415 if array .ndim == 4 :
327416 sh = (shape [0 ], shape [1 ], factor , shape [2 ], factor , shape [3 ], factor )
328417 st = (stride [0 ], stride [1 ], 0 , stride [2 ], 0 , stride [3 ], 0 )
@@ -350,9 +439,7 @@ def _normalize_3d(
350439 for slot , token in enumerate (components ):
351440 c = int (token )
352441 if 0 <= c < 3 :
353- descriptors [slot ] = (
354- descriptors [slot ] / max_distance [c ] * 0.5 + 0.5
355- ) * seg_mask
442+ descriptors [slot ] = (descriptors [slot ] / max_distance [c ] * 0.5 + 0.5 ) * seg_mask
356443 elif 6 <= c < 9 :
357444 descriptors [slot ] = (descriptors [slot ] * 0.5 + 0.5 ) * seg_mask
358445
@@ -364,17 +451,13 @@ def _normalize_2d(
364451 components : Optional [str ],
365452 ) -> None :
366453 if components is None :
367- descriptors [[0 , 1 ]] = (
368- descriptors [[0 , 1 ]] / max_distance [:, None , None ] * 0.5 + 0.5
369- )
454+ descriptors [[0 , 1 ]] = descriptors [[0 , 1 ]] / max_distance [:, None , None ] * 0.5 + 0.5
370455 descriptors [[4 ]] = descriptors [[4 ]] * 0.5 + 0.5
371456 descriptors [[0 , 1 , 4 ]] *= seg_mask
372457 return
373458 for slot , token in enumerate (components ):
374459 c = int (token )
375460 if 0 <= c < 2 :
376- descriptors [slot ] = (
377- descriptors [slot ] / max_distance [c ] * 0.5 + 0.5
378- ) * seg_mask
461+ descriptors [slot ] = (descriptors [slot ] / max_distance [c ] * 0.5 + 0.5 ) * seg_mask
379462 elif c == 4 :
380463 descriptors [slot ] = (descriptors [slot ] * 0.5 + 0.5 ) * seg_mask
0 commit comments