Skip to content

Commit 318d41e

Browse files
Donglai Wei and claude committed
Optimize SDT: batch skeletonization, precomputed volume caching, new defaults
- Batch kimimaro.skeletonize() with parallel=0 replaces N serial per-instance calls
- Auto-precompute full-volume SDT and cache to disk (train-labels_sdt.h5)
- Precomputed SDT flows through spatial transforms (crop/flip/rotate) as "sdt" key
- MultiTaskLabelTransformd uses precomputed SDT when available, skipping per-crop computation
- Default smooth=False (saves ~20% overhead), relabel=False (global SDT)
- Use cc3d for connected components instead of skimage.measure.label

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a8097d2 commit 318d41e

6 files changed

Lines changed: 235 additions & 56 deletions

File tree

.claude/benchmark/SNEMI.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ Files changed:
142142
- `connectomics/models/loss/metadata.py`: added metadata with `spatial_weight_arg="weight"`
143143
- `connectomics/training/loss/plan.py`: defaults `pos_weight` to `1.0` for this loss (prevents orchestrator from double-weighting; the loss handles class balancing internally)
144144
- `tutorials/bases/loss_profiles.yaml`: new `loss_per_channel` profile (single entry)
145-
- `tutorials/bases/pipeline_profiles.yaml`: `affinity-12` pipeline now uses `loss_per_channel` profile
145+
- `tutorials/bases/pipeline_profiles.yaml`: `aff12` pipeline now uses `loss_per_channel` profile
146146

147-
Config (via `affinity-12` pipeline profile → `loss_per_channel` loss profile):
147+
Config (via `aff12` pipeline profile → `loss_per_channel` loss profile):
148148
```yaml
149149
# loss_profiles.yaml
150150
loss_per_channel:
@@ -273,7 +273,7 @@ default:
273273
down_factors: [[1,2,2], [1,2,2], [1,2,2], [1,2,2]]
274274
input_size: [18, 160, 160]
275275
output_size: [18, 160, 160]
276-
# Loss handled by affinity-12 pipeline profile → loss_per_channel → PerChannelBCEWithLogitsLoss
276+
# Loss handled by aff12 pipeline profile → loss_per_channel → PerChannelBCEWithLogitsLoss
277277
```
278278

279279
### Phase 3: Match Augmentation ✅
@@ -290,7 +290,7 @@ All three items implemented and tested (32/32 augmentation tests pass).
290290
3. **Contrast/brightness ±50%** — `contrast_range=[0.5, 1.5]`, `shift_intensity_offset=0.2` (matches DeepEM `MixedGrayscale2D`).
291291

292292
All settings live in the `aug_em_neuron` profile (`tutorials/bases/augmentation_profiles.yaml`),
293-
which is applied automatically via the `affinity-12` pipeline profile in `tutorials/bases/pipeline_profiles.yaml`.
293+
which is applied automatically via the `aff12` pipeline profile in `tutorials/bases/pipeline_profiles.yaml`.
294294
No inline augmentation overrides are needed in `neuron_snemi.yaml`.
295295

296296
### Phase 4: Inference Improvements

connectomics/data/augment/build.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@
5151
)
5252

5353

54+
def _has_precomputed_sdt(cfg: Config) -> bool:
    """Check if the label transform includes skeleton_aware_edt (precomputed SDT)."""
    targets = getattr(cfg.data.label_transform, "targets", None)
    if not targets:
        return False

    def _entry_name(entry):
        # Target entries may be plain dicts or config objects; read "name" either way.
        return entry.get("name") if isinstance(entry, dict) else getattr(entry, "name", None)

    return any(_entry_name(entry) == "skeleton_aware_edt" for entry in targets)
64+
65+
5466
def _strict_binarize_mask(mask, threshold: float = 0.0):
5567
"""Binarize mask with strict greater-than semantics (mask > threshold)."""
5668
if torch.is_tensor(mask):
@@ -101,6 +113,9 @@ def build_train_transforms(
101113
keys = ["image", "label"]
102114
if cfg.data.train.mask is not None:
103115
keys.append("mask")
116+
# Include precomputed SDT key if present (auto-detected from label_transform).
117+
if _has_precomputed_sdt(cfg):
118+
keys.append("sdt")
104119

105120
transforms = []
106121

connectomics/data/dataset/data_dicts.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def create_data_dicts_from_paths(
1313
image_paths: List[str],
1414
label_paths: Optional[List[str]] = None,
1515
mask_paths: Optional[List[str]] = None,
16+
extra_paths: Optional[Dict[str, List[str]]] = None,
1617
) -> List[Dict[str, object]]:
1718
"""
1819
Create MONAI-style data dictionaries from file paths.
@@ -21,6 +22,8 @@ def create_data_dicts_from_paths(
2122
image_paths: List of image file paths
2223
label_paths: Optional list of label file paths
2324
mask_paths: Optional list of mask file paths
25+
extra_paths: Optional dict of additional keys to include, e.g.
26+
``{"sdt": ["/path/to/sdt1.h5", "/path/to/sdt2.h5"]}``
2427
2528
Returns:
2629
List of dictionaries with 'image', 'label', and/or 'mask' keys
@@ -36,6 +39,10 @@ def create_data_dicts_from_paths(
3639
if mask_paths is not None:
3740
data_dict["mask"] = mask_paths[i]
3841

42+
if extra_paths is not None:
43+
for key, paths in extra_paths.items():
44+
data_dict[key] = paths[i]
45+
3946
data_dicts.append(data_dict)
4047

4148
return data_dicts

connectomics/data/process/distance.py

Lines changed: 142 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
remove_small_holes,
1414
)
1515

16+
import cc3d
17+
1618
from .bbox_processor import BBoxInstanceProcessor, BBoxProcessorConfig
1719
from .quantize import energy_quantize
1820

@@ -21,6 +23,7 @@
2123
"edt_instance",
2224
"distance_transform",
2325
"skeleton_aware_distance_transform",
26+
"precompute_sdt_volume",
2427
"smooth_edge",
2528
"signed_distance_transform",
2629
]
@@ -283,11 +286,11 @@ def signed_distance_transform(
283286
def skeleton_aware_distance_transform(
284287
label: np.ndarray,
285288
bg_value: float = -1.0,
286-
relabel: bool = True,
289+
relabel: bool = False,
287290
padding: bool = False,
288291
resolution: Tuple[float] = (1.0, 1.0, 1.0),
289292
alpha: float = 0.8,
290-
smooth: bool = True,
293+
smooth: bool = False,
291294
smooth_skeleton_only: bool = True,
292295
):
293296
"""Skeleton-based distance transform (SDT).
@@ -296,8 +299,9 @@ def skeleton_aware_distance_transform(
296299
Distance Transform." International Conference on Medical Image Computing and
297300
Computer-Assisted Intervention. Cham: Springer Nature Switzerland, 2023.
298301
299-
Refactored to use BBoxInstanceProcessor for cleaner code and consistency.
300-
Uses kimimaro for fast skeletonization (10-100x faster than scikit-image).
302+
Uses batch kimimaro skeletonization: all instances are skeletonized in a single
303+
call with automatic parallelism, then per-instance EDT is computed via
304+
BBoxInstanceProcessor.
301305
302306
Args:
303307
label: Instance segmentation (H, W) or (D, H, W)
@@ -306,7 +310,8 @@ def skeleton_aware_distance_transform(
306310
padding: Whether to pad before computing distance
307311
resolution: Voxel resolution for anisotropic data (z, y, x)
308312
alpha: Skeleton influence exponent (higher = stronger skeleton influence)
309-
smooth: Whether to smooth edges before skeletonization
313+
smooth: Whether to smooth edges before skeletonization (default False;
314+
adds ~20% overhead with marginal quality impact when using kimimaro)
310315
smooth_skeleton_only: Only smooth skeleton mask (not entire object)
311316
312317
Returns:
@@ -318,24 +323,33 @@ def skeleton_aware_distance_transform(
318323
if np.sum(label > 0) == 0:
319324
return np.full(label.shape, bg_value, dtype=np.float32)
320325

321-
# Configure bbox processor
326+
# 1. Relabel outside processor so we can batch-skeletonize.
327+
if relabel:
328+
label = cc3d.connected_components(label, connectivity=6)
329+
330+
# 2. Batch skeletonize all instances in one call (parallel across instances).
331+
skeleton_vertices = _batch_skeletonize(label, resolution)
332+
333+
# 3. Per-instance EDT using BBoxProcessor (skeletons already computed).
334+
# Padding coordinate offset: if padding is enabled, the processor pads the
335+
# label internally, shifting coordinates by pad_size. We account for this
336+
# when translating skeleton vertices to bbox-local coordinates.
337+
pad_offset = 2 if padding else 0
338+
322339
config = BBoxProcessorConfig(
323340
bg_value=bg_value,
324-
relabel=relabel,
341+
relabel=False, # already relabeled above
325342
padding=padding,
326343
pad_size=2,
327344
bbox_relax=2,
328345
combine_mode="max",
329346
)
330347

331-
# Define per-instance skeleton EDT computation
332348
def compute_skeleton_edt(
333349
label_crop: np.ndarray, instance_id: int, bbox: Tuple[slice, ...], context: Dict
334350
) -> Optional[np.ndarray]:
335351
"""Compute skeleton-aware EDT for a single instance within bbox."""
336-
# Extract and clean mask
337352
temp2 = remove_small_holes(label_crop == instance_id, 16, connectivity=1)
338-
339353
if not temp2.any():
340354
return None
341355

@@ -351,10 +365,15 @@ def compute_skeleton_edt(
351365
binary = binary_smooth.astype(bool)
352366
temp2 = binary
353367

354-
# Skeletonize using kimimaro
355-
skeleton_mask = _skeletonize_instance(label_crop, instance_id, context["resolution"])
368+
# Look up pre-computed skeleton and translate to bbox-local coordinates.
369+
skeleton_mask = _skeleton_vertices_to_mask(
370+
context["skeleton_vertices"].get(instance_id),
371+
label_crop.shape,
372+
bbox,
373+
context["pad_offset"],
374+
)
356375

357-
# Fallback to regular EDT if skeletonization fails
376+
# Fallback to regular EDT if skeletonization failed for this instance.
358377
if skeleton_mask is None or not skeleton_mask.any():
359378
boundary_edt = distance_transform_edt(temp2, context["resolution"])
360379
edt_max = boundary_edt.max()
@@ -367,70 +386,143 @@ def compute_skeleton_edt(
367386
skeleton_edt = distance_transform_edt(~skeleton_mask, context["resolution"])
368387
boundary_edt = distance_transform_edt(temp2, context["resolution"])
369388

370-
# Normalized energy
371389
energy = boundary_edt / (skeleton_edt + boundary_edt + eps)
372390
energy = energy ** context["alpha"]
373391

374392
return energy * temp2.astype(np.float32)
375393

376-
# Process all instances
377394
processor = BBoxInstanceProcessor(config)
378395
return processor.process(
379396
label,
380397
compute_skeleton_edt,
398+
skeleton_vertices=skeleton_vertices,
399+
pad_offset=pad_offset,
381400
resolution=resolution,
382401
alpha=alpha,
383402
smooth=smooth,
384403
smooth_skeleton_only=smooth_skeleton_only,
385404
)
386405

387406

388-
def _batch_skeletonize(
    label: np.ndarray, resolution: Tuple[float, ...]
) -> Dict[int, np.ndarray]:
    """Skeletonize every instance with a single kimimaro call.

    Returns:
        Dict mapping instance_id -> (N, ndim) integer array of skeleton vertex
        coordinates in the input label's coordinate system.
    """
    try:
        skeletons = kimimaro.skeletonize(
            label.astype(np.uint32),
            anisotropy=resolution,
            fix_branching=False,
            fix_borders=False,
            dust_threshold=5,
            parallel=0,  # auto-detect cores
            progress=False,
        )
    except Exception:
        # Degrade gracefully on any kimimaro failure; callers fall back to plain EDT.
        return {}

    return {
        inst_id: skel.vertices.astype(int)
        for inst_id, skel in skeletons.items()
        if len(skel.vertices) > 0
    }
def _skeleton_vertices_to_mask(
437+
vertices: Optional[np.ndarray],
438+
crop_shape: Tuple[int, ...],
439+
bbox: Tuple[slice, ...],
440+
pad_offset: int,
441+
) -> Optional[np.ndarray]:
442+
"""Convert skeleton vertices (full-volume coords) to a binary mask in bbox-local coords.
443+
444+
Args:
445+
vertices: (N, ndim) vertex coordinates in the original (unpadded) label space,
446+
or None if this instance had no skeleton.
447+
crop_shape: Shape of the bbox crop.
448+
bbox: Tuple of slices defining the bbox in the (possibly padded) label.
449+
pad_offset: Coordinate offset added by padding (0 if no padding).
450+
"""
451+
if vertices is None or len(vertices) == 0:
452+
return None
453+
454+
# Translate: original-label coords → padded-label coords → bbox-local coords.
455+
bbox_origin = np.array([s.start for s in bbox])
456+
local_verts = vertices + pad_offset - bbox_origin
457+
458+
# Filter to valid range.
459+
valid = np.all((local_verts >= 0) & (local_verts < np.array(crop_shape)), axis=1)
460+
local_verts = local_verts[valid]
461+
462+
if len(local_verts) == 0:
463+
return None
464+
465+
mask = np.zeros(crop_shape, dtype=bool)
466+
if len(crop_shape) == 3:
467+
mask[local_verts[:, 0], local_verts[:, 1], local_verts[:, 2]] = True
468+
else:
469+
mask[local_verts[:, 0], local_verts[:, 1]] = True
470+
return mask
471+
472+
473+
def precompute_sdt_volume(
    label_path: str,
    output_path: str,
    resolution: Tuple[float, ...] = (1.0, 1.0, 1.0),
    alpha: float = 0.8,
    bg_value: float = -1.0,
) -> str:
    """Precompute skeleton-aware distance transform on a full label volume.

    Computes the SDT once on the entire volume and saves to HDF5.
    Subsequent training runs load the precomputed result, avoiding
    the expensive per-crop skeletonization.

    Args:
        label_path: Path to the instance segmentation label volume.
        output_path: Path to save the precomputed SDT (HDF5).
        resolution: Voxel resolution (z, y, x) for anisotropic data.
        alpha: Skeleton influence exponent.
        bg_value: Background value for non-instance regions.

    Returns:
        The output_path (for chaining).
    """
    import logging
    import time

    from ..io.io import read_volume, save_volume

    logger = logging.getLogger(__name__)
    # Fix: separate the two paths in the log message (was "{label_path}{output_path}").
    logger.info(f"Precomputing SDT: {label_path} -> {output_path}")

    label = read_volume(label_path)
    # Count foreground instances only: counting nonzero unique values avoids the
    # off-by-one of "len(unique) - 1" when the volume has no background-0 voxels.
    n_instances = int(np.count_nonzero(np.unique(label)))
    logger.info(f"  Label shape: {label.shape}, unique instances: {n_instances}")

    t0 = time.time()
    sdt = skeleton_aware_distance_transform(
        label, resolution=resolution, alpha=alpha, bg_value=bg_value
    )
    elapsed = time.time() - t0
    logger.info(f"  SDT computed in {elapsed:.1f}s, range: [{sdt.min():.3f}, {sdt.max():.3f}]")

    save_volume(output_path, sdt)
    logger.info(f"  Saved to {output_path}")

    return output_path
518+
519+
520+
def sdt_path_for_label(label_path: str) -> str:
    """Derive the SDT cache path from a label file path.

    Example: ``datasets/SNEMI/train-labels.tif`` → ``datasets/SNEMI/train-labels_sdt.h5``
    """
    import os

    root, _ext = os.path.splitext(label_path)
    return f"{root}_sdt.h5"

0 commit comments

Comments
 (0)