1313 remove_small_holes ,
1414)
1515
16+ import cc3d
17+
1618from .bbox_processor import BBoxInstanceProcessor , BBoxProcessorConfig
1719from .quantize import energy_quantize
1820
2123 "edt_instance" ,
2224 "distance_transform" ,
2325 "skeleton_aware_distance_transform" ,
26+ "precompute_sdt_volume" ,
2427 "smooth_edge" ,
2528 "signed_distance_transform" ,
2629]
@@ -283,11 +286,11 @@ def signed_distance_transform(
283286def skeleton_aware_distance_transform (
284287 label : np .ndarray ,
285288 bg_value : float = - 1.0 ,
286- relabel : bool = True ,
289+ relabel : bool = False ,
287290 padding : bool = False ,
288291 resolution : Tuple [float ] = (1.0 , 1.0 , 1.0 ),
289292 alpha : float = 0.8 ,
290- smooth : bool = True ,
293+ smooth : bool = False ,
291294 smooth_skeleton_only : bool = True ,
292295):
293296 """Skeleton-based distance transform (SDT).
@@ -296,8 +299,9 @@ def skeleton_aware_distance_transform(
296299 Distance Transform." International Conference on Medical Image Computing and
297300 Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.
298301
299- Refactored to use BBoxInstanceProcessor for cleaner code and consistency.
300- Uses kimimaro for fast skeletonization (10-100x faster than scikit-image).
302+ Uses batch kimimaro skeletonization: all instances are skeletonized in a single
303+ call with automatic parallelism, then per-instance EDT is computed via
304+ BBoxInstanceProcessor.
301305
302306 Args:
303307 label: Instance segmentation (H, W) or (D, H, W)
@@ -306,7 +310,8 @@ def skeleton_aware_distance_transform(
306310 padding: Whether to pad before computing distance
307311 resolution: Voxel resolution for anisotropic data (z, y, x)
308312 alpha: Skeleton influence exponent (higher = stronger skeleton influence)
309- smooth: Whether to smooth edges before skeletonization
313+ smooth: Whether to smooth edges before skeletonization (default False;
314+ adds ~20% overhead with marginal quality impact when using kimimaro)
310315 smooth_skeleton_only: Only smooth skeleton mask (not entire object)
311316
312317 Returns:
@@ -318,24 +323,33 @@ def skeleton_aware_distance_transform(
318323 if np .sum (label > 0 ) == 0 :
319324 return np .full (label .shape , bg_value , dtype = np .float32 )
320325
321- # Configure bbox processor
326+ # 1. Relabel outside processor so we can batch-skeletonize.
327+ if relabel :
328+ label = cc3d .connected_components (label , connectivity = 6 )
329+
330+ # 2. Batch skeletonize all instances in one call (parallel across instances).
331+ skeleton_vertices = _batch_skeletonize (label , resolution )
332+
333+ # 3. Per-instance EDT using BBoxProcessor (skeletons already computed).
334+ # Padding coordinate offset: if padding is enabled, the processor pads the
335+ # label internally, shifting coordinates by pad_size. We account for this
336+ # when translating skeleton vertices to bbox-local coordinates.
337+ pad_offset = 2 if padding else 0
338+
322339 config = BBoxProcessorConfig (
323340 bg_value = bg_value ,
324- relabel = relabel ,
341+ relabel = False , # already relabeled above
325342 padding = padding ,
326343 pad_size = 2 ,
327344 bbox_relax = 2 ,
328345 combine_mode = "max" ,
329346 )
330347
331- # Define per-instance skeleton EDT computation
332348 def compute_skeleton_edt (
333349 label_crop : np .ndarray , instance_id : int , bbox : Tuple [slice , ...], context : Dict
334350 ) -> Optional [np .ndarray ]:
335351 """Compute skeleton-aware EDT for a single instance within bbox."""
336- # Extract and clean mask
337352 temp2 = remove_small_holes (label_crop == instance_id , 16 , connectivity = 1 )
338-
339353 if not temp2 .any ():
340354 return None
341355
@@ -351,10 +365,15 @@ def compute_skeleton_edt(
351365 binary = binary_smooth .astype (bool )
352366 temp2 = binary
353367
354- # Skeletonize using kimimaro
355- skeleton_mask = _skeletonize_instance (label_crop , instance_id , context ["resolution" ])
368+ # Look up pre-computed skeleton and translate to bbox-local coordinates.
369+ skeleton_mask = _skeleton_vertices_to_mask (
370+ context ["skeleton_vertices" ].get (instance_id ),
371+ label_crop .shape ,
372+ bbox ,
373+ context ["pad_offset" ],
374+ )
356375
357- # Fallback to regular EDT if skeletonization fails
376+ # Fallback to regular EDT if skeletonization failed for this instance.
358377 if skeleton_mask is None or not skeleton_mask .any ():
359378 boundary_edt = distance_transform_edt (temp2 , context ["resolution" ])
360379 edt_max = boundary_edt .max ()
@@ -367,70 +386,143 @@ def compute_skeleton_edt(
367386 skeleton_edt = distance_transform_edt (~ skeleton_mask , context ["resolution" ])
368387 boundary_edt = distance_transform_edt (temp2 , context ["resolution" ])
369388
370- # Normalized energy
371389 energy = boundary_edt / (skeleton_edt + boundary_edt + eps )
372390 energy = energy ** context ["alpha" ]
373391
374392 return energy * temp2 .astype (np .float32 )
375393
376- # Process all instances
377394 processor = BBoxInstanceProcessor (config )
378395 return processor .process (
379396 label ,
380397 compute_skeleton_edt ,
398+ skeleton_vertices = skeleton_vertices ,
399+ pad_offset = pad_offset ,
381400 resolution = resolution ,
382401 alpha = alpha ,
383402 smooth = smooth ,
384403 smooth_skeleton_only = smooth_skeleton_only ,
385404 )
386405
387406
388- def _skeletonize_instance (
389- label_crop : np .ndarray , instance_id : int , resolution : Tuple [float , ...]
390- ) -> Optional [np .ndarray ]:
391- """Helper function to skeletonize a single instance using kimimaro.
392-
393- Args:
394- label_crop: Cropped label array containing the instance
395- instance_id: ID of the instance to skeletonize
396- resolution: Voxel resolution for anisotropic data
407+ def _batch_skeletonize (
408+ label : np .ndarray , resolution : Tuple [float , ...]
409+ ) -> Dict [int , np .ndarray ]:
410+ """Skeletonize all instances in one kimimaro call.
397411
398412 Returns:
399- Binary skeleton mask, or None if skeletonization fails
413+ Dict mapping instance_id → (N, ndim) int array of vertex coordinates
414+ in the input label's coordinate system.
400415 """
401- instance_label = np .where (label_crop == instance_id , 1 , 0 ).astype (np .uint32 )
402-
403416 try :
404417 skeletons = kimimaro .skeletonize (
405- instance_label ,
418+ label . astype ( np . uint32 ) ,
406419 anisotropy = resolution ,
407420 fix_branching = False ,
408421 fix_borders = False ,
409422 dust_threshold = 5 ,
410- parallel = 1 ,
423+ parallel = 0 , # auto-detect cores
411424 progress = False ,
412425 )
426+ except Exception :
427+ return {}
413428
414- if 1 in skeletons and len (skeletons [1 ].vertices ) > 0 :
415- skeleton_mask = np .zeros (label_crop .shape , dtype = bool )
416- vertices = skeletons [1 ].vertices .astype (int )
417-
418- # Filter valid vertices
419- valid_mask = np .all (
420- (vertices >= 0 ) & (vertices < np .array (skeleton_mask .shape )), axis = 1
421- )
422- valid_vertices = vertices [valid_mask ]
423-
424- if len (valid_vertices ) > 0 :
425- if label_crop .ndim == 3 :
426- skeleton_mask [
427- valid_vertices [:, 0 ], valid_vertices [:, 1 ], valid_vertices [:, 2 ]
428- ] = True
429- else :
430- skeleton_mask [valid_vertices [:, 0 ], valid_vertices [:, 1 ]] = True
431- return skeleton_mask
429+ result = {}
430+ for inst_id , skel in skeletons .items ():
431+ if len (skel .vertices ) > 0 :
432+ result [inst_id ] = skel .vertices .astype (int )
433+ return result
432434
433- except Exception :
434- pass
435435
436- return None
436+ def _skeleton_vertices_to_mask (
437+ vertices : Optional [np .ndarray ],
438+ crop_shape : Tuple [int , ...],
439+ bbox : Tuple [slice , ...],
440+ pad_offset : int ,
441+ ) -> Optional [np .ndarray ]:
442+ """Convert skeleton vertices (full-volume coords) to a binary mask in bbox-local coords.
443+
444+ Args:
445+ vertices: (N, ndim) vertex coordinates in the original (unpadded) label space,
446+ or None if this instance had no skeleton.
447+ crop_shape: Shape of the bbox crop.
448+ bbox: Tuple of slices defining the bbox in the (possibly padded) label.
449+ pad_offset: Coordinate offset added by padding (0 if no padding).
450+ """
451+ if vertices is None or len (vertices ) == 0 :
452+ return None
453+
454+ # Translate: original-label coords → padded-label coords → bbox-local coords.
455+ bbox_origin = np .array ([s .start for s in bbox ])
456+ local_verts = vertices + pad_offset - bbox_origin
457+
458+ # Filter to valid range.
459+ valid = np .all ((local_verts >= 0 ) & (local_verts < np .array (crop_shape )), axis = 1 )
460+ local_verts = local_verts [valid ]
461+
462+ if len (local_verts ) == 0 :
463+ return None
464+
465+ mask = np .zeros (crop_shape , dtype = bool )
466+ if len (crop_shape ) == 3 :
467+ mask [local_verts [:, 0 ], local_verts [:, 1 ], local_verts [:, 2 ]] = True
468+ else :
469+ mask [local_verts [:, 0 ], local_verts [:, 1 ]] = True
470+ return mask
471+
472+
473+ def precompute_sdt_volume (
474+ label_path : str ,
475+ output_path : str ,
476+ resolution : Tuple [float , ...] = (1.0 , 1.0 , 1.0 ),
477+ alpha : float = 0.8 ,
478+ bg_value : float = - 1.0 ,
479+ ) -> str :
480+ """Precompute skeleton-aware distance transform on a full label volume.
481+
482+ Computes the SDT once on the entire volume and saves to HDF5.
483+ Subsequent training runs load the precomputed result, avoiding
484+ the expensive per-crop skeletonization.
485+
486+ Args:
487+ label_path: Path to the instance segmentation label volume.
488+ output_path: Path to save the precomputed SDT (HDF5).
489+ resolution: Voxel resolution (z, y, x) for anisotropic data.
490+ alpha: Skeleton influence exponent.
491+ bg_value: Background value for non-instance regions.
492+
493+ Returns:
494+ The output_path (for chaining).
495+ """
496+ import logging
497+ import time
498+
499+ from ..io .io import read_volume , save_volume
500+
501+ logger = logging .getLogger (__name__ )
502+ logger .info (f"Precomputing SDT: { label_path } → { output_path } " )
503+
504+ label = read_volume (label_path )
505+ logger .info (f" Label shape: { label .shape } , unique instances: { len (np .unique (label )) - 1 } " )
506+
507+ t0 = time .time ()
508+ sdt = skeleton_aware_distance_transform (
509+ label , resolution = resolution , alpha = alpha , bg_value = bg_value
510+ )
511+ elapsed = time .time () - t0
512+ logger .info (f" SDT computed in { elapsed :.1f} s, range: [{ sdt .min ():.3f} , { sdt .max ():.3f} ]" )
513+
514+ save_volume (output_path , sdt )
515+ logger .info (f" Saved to { output_path } " )
516+
517+ return output_path
518+
519+
520+ def sdt_path_for_label (label_path : str ) -> str :
521+ """Derive the SDT cache path from a label file path.
522+
523+ Example: ``datasets/SNEMI/train-labels.tif`` → ``datasets/SNEMI/train-labels_sdt.h5``
524+ """
525+ import os
526+
527+ base , _ = os .path .splitext (label_path )
528+ return base + "_sdt.h5"
0 commit comments