@@ -282,6 +282,20 @@ def signed_distance_transform(
282282 return sdt_normalized .astype (np .float32 )
283283
284284
285+ def _thin_centerline_weight_boost (
286+ radius_phys : np .ndarray ,
287+ instance_mask : np .ndarray ,
288+ resolution : Tuple [float , ...],
289+ weight_param : float ,
290+ eps : float ,
291+ ) -> np .ndarray :
292+ """Convert local physical radius to the bounded thin-centerline boost."""
293+ voxel_size = max (float (min (resolution )), eps )
294+ r_vox = (radius_phys + eps ) / voxel_size
295+ boost = float (weight_param ) / np .maximum (1.0 , r_vox )
296+ return (boost * instance_mask .astype (np .float32 )).astype (np .float32 , copy = False )
297+
298+
285299def skeleton_aware_distance_transform (
286300 label : np .ndarray ,
287301 bg_value : float = - 1.0 ,
@@ -292,6 +306,8 @@ def skeleton_aware_distance_transform(
292306 smooth : bool = False ,
293307 smooth_skeleton_only : bool = True ,
294308 max_parallel : int = 1 ,
309+ weight_param : float = 0.0 ,
310+ w_base : float = 1.0 ,
295311):
296312 """Skeleton-based distance transform (SDT).
297313
@@ -313,21 +329,32 @@ def skeleton_aware_distance_transform(
313329 smooth: Whether to smooth edges before skeletonization (default False;
314330 adds ~20% overhead with marginal quality impact when using kimimaro)
315331 smooth_skeleton_only: Only smooth skeleton mask (not entire object)
332+ weight_param: Optional thin-centerline weight boost. Defaults to 0.0,
333+ which preserves the single-channel energy output.
334+ w_base: Base value for the optional spatial weight channel.
316335
317336 Returns:
318- Skeleton-aware distance map with same shape as input
337+ Skeleton-aware distance map with same shape as input. When
338+ ``weight_param > 0``, returns ``[energy, weight]`` stacked on a leading
339+ channel axis.
319340 """
320341 eps = 1e-6
321342
322343 # Fast-path: empty label should produce all background energy.
323344 if np .sum (label > 0 ) == 0 :
324- return np .full (label .shape , bg_value , dtype = np .float32 )
345+ energy = np .full (label .shape , bg_value , dtype = np .float32 )
346+ if weight_param <= 0 :
347+ return energy
348+ weight = np .full (label .shape , w_base , dtype = np .float32 )
349+ return np .stack ([energy , weight ], axis = 0 )
325350
326351 # 1. Relabel outside processor so we can batch-skeletonize.
327352 if relabel :
328353 label = cc3d .connected_components (label , connectivity = 6 )
329354
330355 # 2. Batch skeletonize all instances in one call (parallel across instances).
356+ # The resulting skeleton_vertices are reused by both the energy pass and
357+ # optional weight pass below; weighting does not re-skeletonize.
331358 skeleton_vertices = _batch_skeletonize (label , resolution , max_parallel = max_parallel )
332359 print (f" Skeletonization done: { len (skeleton_vertices )} skeletons extracted" )
333360
@@ -392,8 +419,56 @@ def compute_skeleton_edt(
392419
393420 return energy * temp2 .astype (np .float32 )
394421
422+ def compute_skeleton_boost (
423+ label_crop : np .ndarray , instance_id : int , bbox : Tuple [slice , ...], context : Dict
424+ ) -> Optional [np .ndarray ]:
425+ """Compute the optional thin-centerline boost for a single instance."""
426+ temp2 = remove_small_holes (label_crop == instance_id , 16 , connectivity = 1 )
427+ if not temp2 .any ():
428+ return None
429+
430+ binary = temp2
431+
432+ if context ["smooth" ]:
433+ binary_smooth = smooth_edge (binary .astype (np .uint8 ))
434+ if binary_smooth .astype (int ).sum () > 32 :
435+ if context ["smooth_skeleton_only" ]:
436+ binary = binary_smooth .astype (bool ) & temp2
437+ else :
438+ binary = binary_smooth .astype (bool )
439+ temp2 = binary
440+
441+ skeleton_mask = _skeleton_vertices_to_mask (
442+ context ["skeleton_vertices" ].get (instance_id ),
443+ label_crop .shape ,
444+ bbox ,
445+ context ["pad_offset" ],
446+ )
447+
448+ if skeleton_mask is None or not skeleton_mask .any ():
449+ boundary_edt = distance_transform_edt (temp2 , context ["resolution" ])
450+ if boundary_edt .max () > eps :
451+ return _thin_centerline_weight_boost (
452+ boundary_edt ,
453+ temp2 ,
454+ context ["resolution" ],
455+ context ["weight_param" ],
456+ eps ,
457+ )
458+ return None
459+
460+ skeleton_edt = distance_transform_edt (~ skeleton_mask , context ["resolution" ])
461+ boundary_edt = distance_transform_edt (temp2 , context ["resolution" ])
462+ return _thin_centerline_weight_boost (
463+ skeleton_edt + boundary_edt ,
464+ temp2 ,
465+ context ["resolution" ],
466+ context ["weight_param" ],
467+ eps ,
468+ )
469+
395470 processor = BBoxInstanceProcessor (config )
396- return processor .process (
471+ energy = processor .process (
397472 label ,
398473 compute_skeleton_edt ,
399474 num_workers = max_parallel ,
@@ -404,6 +479,31 @@ def compute_skeleton_edt(
404479 smooth = smooth ,
405480 smooth_skeleton_only = smooth_skeleton_only ,
406481 )
482+ if weight_param <= 0 :
483+ return energy
484+
485+ weight_config = BBoxProcessorConfig (
486+ bg_value = 0.0 ,
487+ relabel = False ,
488+ padding = padding ,
489+ pad_size = 2 ,
490+ bbox_relax = 2 ,
491+ combine_mode = "max" ,
492+ )
493+ weight_processor = BBoxInstanceProcessor (weight_config )
494+ boost = weight_processor .process (
495+ label ,
496+ compute_skeleton_boost ,
497+ num_workers = max_parallel ,
498+ skeleton_vertices = skeleton_vertices ,
499+ pad_offset = pad_offset ,
500+ resolution = resolution ,
501+ weight_param = weight_param ,
502+ smooth = smooth ,
503+ smooth_skeleton_only = smooth_skeleton_only ,
504+ )
505+ weight = w_base + boost
506+ return np .stack ([energy , weight .astype (np .float32 , copy = False )], axis = 0 )
407507
408508
409509def kimimaro_config (label : np .ndarray , resolution : Tuple [float , ...]) -> dict :
@@ -695,6 +795,8 @@ def skeleton_aware_edt_from_skeleton_vol(
695795 resolution : Tuple [float , ...] = (1.0 , 1.0 , 1.0 ),
696796 alpha : float = 0.8 ,
697797 bg_value : float = - 1.0 ,
798+ weight_param : float = 0.0 ,
799+ w_base : float = 1.0 ,
698800) -> np .ndarray :
699801 """Compute skeleton-aware EDT using a precomputed skeleton volume.
700802
@@ -709,14 +811,23 @@ def skeleton_aware_edt_from_skeleton_vol(
709811 resolution: Voxel resolution for anisotropic EDT.
710812 alpha: Skeleton influence exponent.
711813 bg_value: Background fill value.
814+ weight_param: Optional thin-centerline weight boost. Defaults to 0.0,
815+ which preserves the single-channel energy output.
816+ w_base: Base value for the optional spatial weight channel.
712817
713818 Returns:
714- Skeleton-aware distance map, same shape as label.
819+ Skeleton-aware distance map, same shape as label. When
820+ ``weight_param > 0``, returns ``[energy, weight]`` stacked on a leading
821+ channel axis.
715822 """
716823 eps = 1e-6
717824
718825 if np .sum (label > 0 ) == 0 :
719- return np .full (label .shape , bg_value , dtype = np .float32 )
826+ energy = np .full (label .shape , bg_value , dtype = np .float32 )
827+ if weight_param <= 0 :
828+ return energy
829+ weight = np .full (label .shape , w_base , dtype = np .float32 )
830+ return np .stack ([energy , weight ], axis = 0 )
720831
721832 config = BBoxProcessorConfig (
722833 bg_value = bg_value ,
@@ -754,14 +865,67 @@ def compute_edt_with_skeleton(
754865 energy = energy ** context ["alpha" ]
755866 return energy * temp2 .astype (np .float32 )
756867
868+ def compute_skeleton_boost (
869+ label_crop : np .ndarray , instance_id : int , bbox : Tuple [slice , ...], context : Dict
870+ ) -> Optional [np .ndarray ]:
871+ temp2 = remove_small_holes (label_crop == instance_id , 16 , connectivity = 1 )
872+ if not temp2 .any ():
873+ return None
874+
875+ skel_crop = context ["skeleton_vol" ][bbox ]
876+ skeleton_mask = skel_crop == instance_id
877+
878+ if not skeleton_mask .any ():
879+ boundary_edt = distance_transform_edt (temp2 , context ["resolution" ])
880+ if boundary_edt .max () > eps :
881+ return _thin_centerline_weight_boost (
882+ boundary_edt ,
883+ temp2 ,
884+ context ["resolution" ],
885+ context ["weight_param" ],
886+ eps ,
887+ )
888+ return None
889+
890+ skeleton_edt = distance_transform_edt (~ skeleton_mask , context ["resolution" ])
891+ boundary_edt = distance_transform_edt (temp2 , context ["resolution" ])
892+ return _thin_centerline_weight_boost (
893+ skeleton_edt + boundary_edt ,
894+ temp2 ,
895+ context ["resolution" ],
896+ context ["weight_param" ],
897+ eps ,
898+ )
899+
757900 processor = BBoxInstanceProcessor (config )
758- return processor .process (
901+ energy = processor .process (
759902 label ,
760903 compute_edt_with_skeleton ,
761904 skeleton_vol = skeleton_vol ,
762905 resolution = resolution ,
763906 alpha = alpha ,
764907 )
908+ if weight_param <= 0 :
909+ return energy
910+
911+ weight_config = BBoxProcessorConfig (
912+ bg_value = 0.0 ,
913+ relabel = False ,
914+ padding = False ,
915+ pad_size = 2 ,
916+ bbox_relax = 2 ,
917+ combine_mode = "max" ,
918+ )
919+ weight_processor = BBoxInstanceProcessor (weight_config )
920+ boost = weight_processor .process (
921+ label ,
922+ compute_skeleton_boost ,
923+ skeleton_vol = skeleton_vol ,
924+ resolution = resolution ,
925+ weight_param = weight_param ,
926+ )
927+ weight = w_base + boost
928+ return np .stack ([energy , weight .astype (np .float32 , copy = False )], axis = 0 )
765929
766930
767931def sdt_path_for_label (
0 commit comments