Skip to content

Commit 3e629c6

Browse files
Donglai Weiclaude
andcommitted
sdt: optional thin-centerline weight channel for skeleton-aware EDT
Add weight_param/w_base to skeleton_aware_distance_transform and skeleton_aware_edt_from_skeleton_vol. When weight_param>0 they emit a [energy, weight] stack where the weight boosts thin centerlines (bounded by base + weight_param/r_vox). Wire through MultiTaskLabelTransformd and channel counting; add base_banis+_sdtw.yaml tutorial and unit tests. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 1b356aa commit 3e629c6

5 files changed

Lines changed: 438 additions & 6 deletions

File tree

connectomics/data/processing/build.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def _task_output_channels(task: Any) -> int:
9898

9999
if name == "affinity":
100100
return len(resolve_affinity_offsets_from_kwargs(resolved_kwargs))
101+
if name == "skeleton_aware_edt":
102+
return 2 if float(resolved_kwargs.get("weight_param", 0.0)) > 0 else 1
101103
if name == "polarity":
102104
return 1 if bool(resolved_kwargs.get("exclusive", False)) else 3
103105
if name == "flow":

connectomics/data/processing/distance.py

Lines changed: 170 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,20 @@ def signed_distance_transform(
282282
return sdt_normalized.astype(np.float32)
283283

284284

285+
def _thin_centerline_weight_boost(
286+
radius_phys: np.ndarray,
287+
instance_mask: np.ndarray,
288+
resolution: Tuple[float, ...],
289+
weight_param: float,
290+
eps: float,
291+
) -> np.ndarray:
292+
"""Convert local physical radius to the bounded thin-centerline boost."""
293+
voxel_size = max(float(min(resolution)), eps)
294+
r_vox = (radius_phys + eps) / voxel_size
295+
boost = float(weight_param) / np.maximum(1.0, r_vox)
296+
return (boost * instance_mask.astype(np.float32)).astype(np.float32, copy=False)
297+
298+
285299
def skeleton_aware_distance_transform(
286300
label: np.ndarray,
287301
bg_value: float = -1.0,
@@ -292,6 +306,8 @@ def skeleton_aware_distance_transform(
292306
smooth: bool = False,
293307
smooth_skeleton_only: bool = True,
294308
max_parallel: int = 1,
309+
weight_param: float = 0.0,
310+
w_base: float = 1.0,
295311
):
296312
"""Skeleton-based distance transform (SDT).
297313
@@ -313,21 +329,32 @@ def skeleton_aware_distance_transform(
313329
smooth: Whether to smooth edges before skeletonization (default False;
314330
adds ~20% overhead with marginal quality impact when using kimimaro)
315331
smooth_skeleton_only: Only smooth skeleton mask (not entire object)
332+
weight_param: Optional thin-centerline weight boost. Defaults to 0.0,
333+
which preserves the single-channel energy output.
334+
w_base: Base value for the optional spatial weight channel.
316335
317336
Returns:
318-
Skeleton-aware distance map with same shape as input
337+
Skeleton-aware distance map with same shape as input. When
338+
``weight_param > 0``, returns ``[energy, weight]`` stacked on a leading
339+
channel axis.
319340
"""
320341
eps = 1e-6
321342

322343
# Fast-path: empty label should produce all background energy.
323344
if np.sum(label > 0) == 0:
324-
return np.full(label.shape, bg_value, dtype=np.float32)
345+
energy = np.full(label.shape, bg_value, dtype=np.float32)
346+
if weight_param <= 0:
347+
return energy
348+
weight = np.full(label.shape, w_base, dtype=np.float32)
349+
return np.stack([energy, weight], axis=0)
325350

326351
# 1. Relabel outside processor so we can batch-skeletonize.
327352
if relabel:
328353
label = cc3d.connected_components(label, connectivity=6)
329354

330355
# 2. Batch skeletonize all instances in one call (parallel across instances).
356+
# The resulting skeleton_vertices are reused by both the energy pass and
357+
# optional weight pass below; weighting does not re-skeletonize.
331358
skeleton_vertices = _batch_skeletonize(label, resolution, max_parallel=max_parallel)
332359
print(f" Skeletonization done: {len(skeleton_vertices)} skeletons extracted")
333360

@@ -392,8 +419,56 @@ def compute_skeleton_edt(
392419

393420
return energy * temp2.astype(np.float32)
394421

422+
def compute_skeleton_boost(
423+
label_crop: np.ndarray, instance_id: int, bbox: Tuple[slice, ...], context: Dict
424+
) -> Optional[np.ndarray]:
425+
"""Compute the optional thin-centerline boost for a single instance."""
426+
temp2 = remove_small_holes(label_crop == instance_id, 16, connectivity=1)
427+
if not temp2.any():
428+
return None
429+
430+
binary = temp2
431+
432+
if context["smooth"]:
433+
binary_smooth = smooth_edge(binary.astype(np.uint8))
434+
if binary_smooth.astype(int).sum() > 32:
435+
if context["smooth_skeleton_only"]:
436+
binary = binary_smooth.astype(bool) & temp2
437+
else:
438+
binary = binary_smooth.astype(bool)
439+
temp2 = binary
440+
441+
skeleton_mask = _skeleton_vertices_to_mask(
442+
context["skeleton_vertices"].get(instance_id),
443+
label_crop.shape,
444+
bbox,
445+
context["pad_offset"],
446+
)
447+
448+
if skeleton_mask is None or not skeleton_mask.any():
449+
boundary_edt = distance_transform_edt(temp2, context["resolution"])
450+
if boundary_edt.max() > eps:
451+
return _thin_centerline_weight_boost(
452+
boundary_edt,
453+
temp2,
454+
context["resolution"],
455+
context["weight_param"],
456+
eps,
457+
)
458+
return None
459+
460+
skeleton_edt = distance_transform_edt(~skeleton_mask, context["resolution"])
461+
boundary_edt = distance_transform_edt(temp2, context["resolution"])
462+
return _thin_centerline_weight_boost(
463+
skeleton_edt + boundary_edt,
464+
temp2,
465+
context["resolution"],
466+
context["weight_param"],
467+
eps,
468+
)
469+
395470
processor = BBoxInstanceProcessor(config)
396-
return processor.process(
471+
energy = processor.process(
397472
label,
398473
compute_skeleton_edt,
399474
num_workers=max_parallel,
@@ -404,6 +479,31 @@ def compute_skeleton_edt(
404479
smooth=smooth,
405480
smooth_skeleton_only=smooth_skeleton_only,
406481
)
482+
if weight_param <= 0:
483+
return energy
484+
485+
weight_config = BBoxProcessorConfig(
486+
bg_value=0.0,
487+
relabel=False,
488+
padding=padding,
489+
pad_size=2,
490+
bbox_relax=2,
491+
combine_mode="max",
492+
)
493+
weight_processor = BBoxInstanceProcessor(weight_config)
494+
boost = weight_processor.process(
495+
label,
496+
compute_skeleton_boost,
497+
num_workers=max_parallel,
498+
skeleton_vertices=skeleton_vertices,
499+
pad_offset=pad_offset,
500+
resolution=resolution,
501+
weight_param=weight_param,
502+
smooth=smooth,
503+
smooth_skeleton_only=smooth_skeleton_only,
504+
)
505+
weight = w_base + boost
506+
return np.stack([energy, weight.astype(np.float32, copy=False)], axis=0)
407507

408508

409509
def kimimaro_config(label: np.ndarray, resolution: Tuple[float, ...]) -> dict:
@@ -695,6 +795,8 @@ def skeleton_aware_edt_from_skeleton_vol(
695795
resolution: Tuple[float, ...] = (1.0, 1.0, 1.0),
696796
alpha: float = 0.8,
697797
bg_value: float = -1.0,
798+
weight_param: float = 0.0,
799+
w_base: float = 1.0,
698800
) -> np.ndarray:
699801
"""Compute skeleton-aware EDT using a precomputed skeleton volume.
700802
@@ -709,14 +811,23 @@ def skeleton_aware_edt_from_skeleton_vol(
709811
resolution: Voxel resolution for anisotropic EDT.
710812
alpha: Skeleton influence exponent.
711813
bg_value: Background fill value.
814+
weight_param: Optional thin-centerline weight boost. Defaults to 0.0,
815+
which preserves the single-channel energy output.
816+
w_base: Base value for the optional spatial weight channel.
712817
713818
Returns:
714-
Skeleton-aware distance map, same shape as label.
819+
Skeleton-aware distance map, same shape as label. When
820+
``weight_param > 0``, returns ``[energy, weight]`` stacked on a leading
821+
channel axis.
715822
"""
716823
eps = 1e-6
717824

718825
if np.sum(label > 0) == 0:
719-
return np.full(label.shape, bg_value, dtype=np.float32)
826+
energy = np.full(label.shape, bg_value, dtype=np.float32)
827+
if weight_param <= 0:
828+
return energy
829+
weight = np.full(label.shape, w_base, dtype=np.float32)
830+
return np.stack([energy, weight], axis=0)
720831

721832
config = BBoxProcessorConfig(
722833
bg_value=bg_value,
@@ -754,14 +865,67 @@ def compute_edt_with_skeleton(
754865
energy = energy ** context["alpha"]
755866
return energy * temp2.astype(np.float32)
756867

868+
def compute_skeleton_boost(
869+
label_crop: np.ndarray, instance_id: int, bbox: Tuple[slice, ...], context: Dict
870+
) -> Optional[np.ndarray]:
871+
temp2 = remove_small_holes(label_crop == instance_id, 16, connectivity=1)
872+
if not temp2.any():
873+
return None
874+
875+
skel_crop = context["skeleton_vol"][bbox]
876+
skeleton_mask = skel_crop == instance_id
877+
878+
if not skeleton_mask.any():
879+
boundary_edt = distance_transform_edt(temp2, context["resolution"])
880+
if boundary_edt.max() > eps:
881+
return _thin_centerline_weight_boost(
882+
boundary_edt,
883+
temp2,
884+
context["resolution"],
885+
context["weight_param"],
886+
eps,
887+
)
888+
return None
889+
890+
skeleton_edt = distance_transform_edt(~skeleton_mask, context["resolution"])
891+
boundary_edt = distance_transform_edt(temp2, context["resolution"])
892+
return _thin_centerline_weight_boost(
893+
skeleton_edt + boundary_edt,
894+
temp2,
895+
context["resolution"],
896+
context["weight_param"],
897+
eps,
898+
)
899+
757900
processor = BBoxInstanceProcessor(config)
758-
return processor.process(
901+
energy = processor.process(
759902
label,
760903
compute_edt_with_skeleton,
761904
skeleton_vol=skeleton_vol,
762905
resolution=resolution,
763906
alpha=alpha,
764907
)
908+
if weight_param <= 0:
909+
return energy
910+
911+
weight_config = BBoxProcessorConfig(
912+
bg_value=0.0,
913+
relabel=False,
914+
padding=False,
915+
pad_size=2,
916+
bbox_relax=2,
917+
combine_mode="max",
918+
)
919+
weight_processor = BBoxInstanceProcessor(weight_config)
920+
boost = weight_processor.process(
921+
label,
922+
compute_skeleton_boost,
923+
skeleton_vol=skeleton_vol,
924+
resolution=resolution,
925+
weight_param=weight_param,
926+
)
927+
weight = w_base + boost
928+
return np.stack([energy, weight.astype(np.float32, copy=False)], axis=0)
765929

766930

767931
def sdt_path_for_label(

connectomics/data/processing/transforms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,8 @@ class MultiTaskLabelTransformd(MapTransform):
790790
"alpha": 0.8,
791791
"smooth": False,
792792
"smooth_skeleton_only": True,
793+
"weight_param": 0.0,
794+
"w_base": 1.0,
793795
},
794796
"semantic_edt": {"mode": "2d", "alpha_fore": 8.0, "alpha_back": 50.0, "resolution": None},
795797
"signed_distance": {"alpha": 8.0},
@@ -1008,7 +1010,15 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
10081010

10091011
# Auto-detect mode: SDT has negative floats, skeleton has integer IDs.
10101012
is_sdt = aux.dtype.kind == "f" and float(aux.min()) < 0
1013+
weight_param = float(spec["kwargs"].get("weight_param", 0.0))
10111014
if is_sdt:
1015+
if weight_param > 0:
1016+
raise ValueError(
1017+
"skeleton_aware_edt weight_param>0 cannot use a precomputed "
1018+
"full-SDT label_aux cache because it stores only the energy "
1019+
"channel. Precompute skeletons with label_aux_type=skeleton "
1020+
"or disable the full-SDT cache."
1021+
)
10121022
result = aux.astype(np.float32)
10131023
else:
10141024
result = skeleton_aware_edt_from_skeleton_vol(
@@ -1017,6 +1027,8 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
10171027
resolution=spec["kwargs"].get("resolution", (1.0, 1.0, 1.0)),
10181028
alpha=spec["kwargs"].get("alpha", 0.8),
10191029
bg_value=spec["kwargs"].get("bg_value", -1.0),
1030+
weight_param=weight_param,
1031+
w_base=spec["kwargs"].get("w_base", 1.0),
10201032
)
10211033
out_arr = self._normalize_output(result, spatial_ndim)
10221034
outputs.append(out_arr)

0 commit comments

Comments
 (0)