Skip to content

Commit 5a0b06f

Browse files
Donglai Wei and claude
committed
Add label_aux support: precomputed skeleton/SDT with cached dataset integration
- Add label_aux and label_aux_type fields to DataInputConfig ("skeleton" precomputes skeleton volume, "sdt" full SDT, "none" per-crop) - Auto-precompute runs before build_train_transforms so keys are correct - CachedVolumeDataset loads/caches/crops label_aux alongside image/label/mask - MultiTaskLabelTransformd auto-detects skeleton vs SDT from data values - skeleton_aware_edt_from_skeleton_vol: compute EDT per crop from precomputed skeleton - kimimaro_config: derive TEASAR params from label statistics and resolution - BBoxProcessor: add ThreadPoolExecutor parallelism for per-instance EDT - precompute_skeleton_volume: rasterize kimimaro skeletons to label-like volume Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 318d41e commit 5a0b06f

9 files changed

Lines changed: 393 additions & 92 deletions

File tree

connectomics/config/schema/data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ class DataInputConfig:
214214
# Using Any to support both str and List[str] (OmegaConf doesn't support Union of containers)
215215
image: Any = None # str, List[str], or None
216216
label: Any = None # str, List[str], or None
217+
label_aux: Any = None # str, List[str], or None (auto-derived from label if null)
218+
label_aux_type: str = "skeleton" # "skeleton", "sdt", or "none"
217219
mask: Any = None # str, List[str], or None (Valid region mask)
218220

219221
# Paths - JSON/filename-based datasets

connectomics/data/augment/build.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,6 @@
5151
)
5252

5353

54-
def _has_precomputed_sdt(cfg: Config) -> bool:
55-
"""Check if the label transform includes skeleton_aware_edt (precomputed SDT)."""
56-
targets = getattr(cfg.data.label_transform, "targets", None)
57-
if not targets:
58-
return False
59-
for t in targets:
60-
name = t.get("name") if isinstance(t, dict) else getattr(t, "name", None)
61-
if name == "skeleton_aware_edt":
62-
return True
63-
return False
64-
65-
6654
def _strict_binarize_mask(mask, threshold: float = 0.0):
6755
"""Binarize mask with strict greater-than semantics (mask > threshold)."""
6856
if torch.is_tensor(mask):
@@ -111,11 +99,10 @@ def build_train_transforms(
11199
"""
112100
if keys is None:
113101
keys = ["image", "label"]
102+
if cfg.data.train.label_aux is not None:
103+
keys.append("label_aux")
114104
if cfg.data.train.mask is not None:
115105
keys.append("mask")
116-
# Include precomputed SDT key if present (auto-detected from label_transform).
117-
if _has_precomputed_sdt(cfg):
118-
keys.append("sdt")
119106

120107
transforms = []
121108

connectomics/data/dataset/data_dicts.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,22 @@
1212
def create_data_dicts_from_paths(
1313
image_paths: List[str],
1414
label_paths: Optional[List[str]] = None,
15+
label_aux_paths: Optional[List[str]] = None,
1516
mask_paths: Optional[List[str]] = None,
16-
extra_paths: Optional[Dict[str, List[str]]] = None,
1717
) -> List[Dict[str, object]]:
1818
"""
1919
Create MONAI-style data dictionaries from file paths.
2020
2121
Args:
2222
image_paths: List of image file paths
2323
label_paths: Optional list of label file paths
24+
label_aux_paths: Optional list of auxiliary label file paths
25+
(e.g. precomputed SDT volumes)
2426
mask_paths: Optional list of mask file paths
25-
extra_paths: Optional dict of additional keys to include, e.g.
26-
``{"sdt": ["/path/to/sdt1.h5", "/path/to/sdt2.h5"]}``
2727
2828
Returns:
29-
List of dictionaries with 'image', 'label', and/or 'mask' keys
29+
List of dictionaries with 'image', 'label', 'label_aux',
30+
and/or 'mask' keys
3031
"""
3132
data_dicts: List[Dict[str, object]] = []
3233

@@ -36,13 +37,12 @@ def create_data_dicts_from_paths(
3637
if label_paths is not None:
3738
data_dict["label"] = label_paths[i]
3839

40+
if label_aux_paths is not None:
41+
data_dict["label_aux"] = label_aux_paths[i]
42+
3943
if mask_paths is not None:
4044
data_dict["mask"] = mask_paths[i]
4145

42-
if extra_paths is not None:
43-
for key, paths in extra_paths.items():
44-
data_dict[key] = paths[i]
45-
4646
data_dicts.append(data_dict)
4747

4848
return data_dicts

connectomics/data/dataset/dataset_volume_cached.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
self,
103103
image_paths: List[str],
104104
label_paths: Optional[List[str]] = None,
105+
label_aux_paths: Optional[List[str]] = None,
105106
mask_paths: Optional[List[str]] = None,
106107
patch_size: Tuple[int, ...] = (112, 112, 112),
107108
iter_num: int = 500,
@@ -130,40 +131,48 @@ def __init__(
130131
self.sample_nonzero_mask = sample_nonzero_mask
131132

132133
label_paths = label_paths or [None] * len(image_paths)
134+
label_aux_paths = label_aux_paths or [None] * len(image_paths)
133135
mask_paths = mask_paths or [None] * len(image_paths)
134136

135137
# Load all volumes into memory
136138
logger.info("Loading %d volumes into memory...", len(image_paths))
137139
self.cached_images: List[np.ndarray] = []
138140
self.cached_labels: List[Optional[np.ndarray]] = []
141+
self.cached_label_aux: List[Optional[np.ndarray]] = []
139142
self.cached_masks: List[Optional[np.ndarray]] = []
140143

141-
for i, (img_path, lbl_path, msk_path) in enumerate(
142-
zip(image_paths, label_paths, mask_paths)
144+
for i, (img_path, lbl_path, aux_path, msk_path) in enumerate(
145+
zip(image_paths, label_paths, label_aux_paths, mask_paths)
143146
):
144147
img = self._load_volume(img_path)
145148
lbl = self._load_volume(lbl_path) if lbl_path else None
149+
aux = self._load_volume(aux_path) if aux_path else None
146150
msk = self._load_volume(msk_path) if msk_path else None
147151

148152
# Apply one-time preprocessing before caching
149153
if pre_cache_transforms is not None:
150154
sample = {"image": img}
151155
if lbl is not None:
152156
sample["label"] = lbl
157+
if aux is not None:
158+
sample["label_aux"] = aux
153159
if msk is not None:
154160
sample["mask"] = msk
155161
sample = pre_cache_transforms(sample)
156162
img = sample["image"]
157163
lbl = sample.get("label")
164+
aux = sample.get("label_aux")
158165
msk = sample.get("mask")
159166

160167
# Pad and ensure minimum size
161168
img = self._prepare_volume(img)
162169
lbl = self._prepare_volume(lbl) if lbl is not None else None
170+
aux = self._prepare_volume(aux) if aux is not None else None
163171
msk = self._prepare_volume(msk) if msk is not None else None
164172

165173
self.cached_images.append(img)
166174
self.cached_labels.append(lbl)
175+
self.cached_label_aux.append(aux)
167176
self.cached_masks.append(msk)
168177
logger.info("Volume %d/%d: %s", i + 1, len(image_paths), img.shape)
169178

@@ -210,6 +219,7 @@ def __init__(
210219
def _crop_volumes(self, vol_idx: int, pos: Tuple[int, ...]) -> Dict[str, Any]:
211220
image = self.cached_images[vol_idx]
212221
label = self.cached_labels[vol_idx]
222+
label_aux = self.cached_label_aux[vol_idx]
213223
mask = self.cached_masks[vol_idx]
214224

215225
image_crop = crop_volume(image, self.patch_size, pos, pad_mode="reflect")
@@ -218,13 +228,23 @@ def _crop_volumes(self, vol_idx: int, pos: Tuple[int, ...]) -> Dict[str, Any]:
218228
if label is not None
219229
else None
220230
)
231+
label_aux_crop = (
232+
crop_volume(label_aux, self.patch_size, pos, pad_mode="constant")
233+
if label_aux is not None
234+
else None
235+
)
221236
mask_crop = (
222237
crop_volume(mask, self.patch_size, pos, pad_mode="constant")
223238
if mask is not None
224239
else None
225240
)
226241

227-
return {"image": image_crop, "label": label_crop, "mask": mask_crop}
242+
return {
243+
"image": image_crop,
244+
"label": label_crop,
245+
"label_aux": label_aux_crop,
246+
"mask": mask_crop,
247+
}
228248

229249
def _has_labels(self, vol_idx: int) -> bool:
230250
return self.cached_labels[vol_idx] is not None

connectomics/data/process/bbox_processor.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def process(
7171
self,
7272
label: np.ndarray,
7373
instance_fn: Callable[[np.ndarray, int, Tuple[slice, ...], Dict], Optional[np.ndarray]],
74+
num_workers: int = 0,
7475
**kwargs,
7576
) -> np.ndarray:
7677
"""
@@ -90,6 +91,10 @@ def process(
9091
Returns:
9192
- result_crop: Same shape as label_crop, or None to skip
9293
94+
num_workers: Number of threads for parallel instance processing.
95+
0 = sequential (default). Scipy EDT releases the GIL, so
96+
threads give real parallelism for the numeric heavy lifting.
97+
9398
**kwargs: Additional arguments passed to instance_fn
9499
95100
Returns:
@@ -113,30 +118,40 @@ def process(
113118
distance = self._apply_bg_value(distance)
114119
return self._postprocess(distance, was_padded)
115120

116-
# 5. Process each instance within its bounding box
117-
for i in range(bbox_array.shape[0]):
121+
# 5. Prepare per-instance work items
122+
n = bbox_array.shape[0]
123+
work_items = []
124+
for i in range(n):
118125
instance_id = int(bbox_array[i, 0])
119126
bbox = self._extract_bbox(bbox_array[i], label_shape, label.ndim)
120-
121-
# Extract instance crop
122127
label_crop = label[bbox]
123-
124-
# Call user-provided instance processing function
125-
try:
126-
result_crop = instance_fn(label_crop, instance_id, bbox, kwargs)
127-
except Exception as e:
128-
# Skip instance on error
129-
print(f"Warning: Failed to process instance {instance_id}: {e}")
130-
continue
131-
132-
# Skip if function returned None or empty result
133-
if result_crop is None or not np.any(result_crop):
134-
continue
135-
136-
# Aggregate result back to full volume
137-
self._aggregate_result(distance, bbox, result_crop)
138-
139-
# 6. Postprocessing
128+
work_items.append((label_crop, instance_id, bbox))
129+
130+
# 6. Process instances (parallel or sequential)
131+
if num_workers > 0:
132+
from concurrent.futures import ThreadPoolExecutor
133+
134+
def _run(item):
135+
label_crop, instance_id, bbox = item
136+
try:
137+
return bbox, instance_fn(label_crop, instance_id, bbox, kwargs)
138+
except Exception:
139+
return bbox, None
140+
141+
with ThreadPoolExecutor(max_workers=num_workers) as pool:
142+
for bbox, result_crop in pool.map(_run, work_items):
143+
if result_crop is not None and np.any(result_crop):
144+
self._aggregate_result(distance, bbox, result_crop)
145+
else:
146+
for label_crop, instance_id, bbox in work_items:
147+
try:
148+
result_crop = instance_fn(label_crop, instance_id, bbox, kwargs)
149+
except Exception:
150+
continue
151+
if result_crop is not None and np.any(result_crop):
152+
self._aggregate_result(distance, bbox, result_crop)
153+
154+
# 7. Postprocessing
140155
distance = self._apply_bg_value(distance)
141156
return self._postprocess(distance, was_padded)
142157

0 commit comments

Comments (0)