Skip to content

Commit 17dc96e

Browse files
Donglai Wei and claude
committed
Unify profile engine to dict-valued profiles and fix legacy tutorial configs
All profile families (loss, decoding, activation) now use dict-valued profiles with merge semantics instead of bare lists with replacement. This makes the profile system consistent: every profile is a config subtree merged into its parent section. Migrate 12 tutorial configs from legacy flat keys (test_image, test_label, test_path, test_resolution, reject_sampling) to the modern nested DataInputConfig structure (test.data.test.*). Add mask_transform field to DataConfig and wire it into build_test_transforms and test_pipeline so vesicle_xm.yaml loads. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6466254 commit 17dc96e

21 files changed

Lines changed: 256 additions & 199 deletions

connectomics/config/pipeline/profile_engine.py

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
target_path: str,
195195
merge_into_existing: bool = True,
196196
cleanup_keys: Optional[Tuple[str, ...]] = None,
197+
list_key: Optional[str] = None,
197198
) -> None:
198199
if isinstance(selector_paths, str):
199200
selector_paths = [selector_paths]
@@ -202,6 +203,7 @@ def __init__(
202203
self.target_path = target_path
203204
self.merge_into_existing = merge_into_existing
204205
self.cleanup_keys = cleanup_keys if cleanup_keys is not None else (profiles_key,)
206+
self.list_key = list_key
205207

206208
def _apply_list_overrides(self, yaml_conf: DictConfig, selector_path: str) -> None:
207209
"""Apply per-index overrides to a list-valued profile.
@@ -224,7 +226,8 @@ def _apply_list_overrides(self, yaml_conf: DictConfig, selector_path: str) -> No
224226
if overrides is None or not isinstance(overrides, (DictConfig, dict)):
225227
return
226228

227-
target_list = OmegaConf.select(yaml_conf, self.target_path)
229+
list_path = f"{self.target_path}.{self.list_key}" if self.list_key else self.target_path
230+
target_list = OmegaConf.select(yaml_conf, list_path)
228231
if not isinstance(target_list, (ListConfig, list)):
229232
return
230233

@@ -343,10 +346,12 @@ def __init__(
343346
profiles_key: str,
344347
target_paths: List[str],
345348
profile_key: str = "profile",
349+
list_key: Optional[str] = None,
346350
) -> None:
347351
self.profiles_key = profiles_key
348352
self.target_paths = target_paths
349353
self.profile_key = profile_key
354+
self.list_key = list_key
350355
self.cleanup_keys = (profiles_key,)
351356

352357
def apply(self, yaml_conf: DictConfig) -> None:
@@ -376,34 +381,40 @@ def apply(self, yaml_conf: DictConfig) -> None:
376381
)
377382

378383
profile_payload = profiles[profile_name]
384+
385+
# Extract nested list from dict-valued profiles.
386+
profile_list = profile_payload
387+
if self.list_key and isinstance(profile_payload, (DictConfig, dict)):
388+
profile_list = profile_payload.get(self.list_key, profile_payload)
389+
379390
overrides = {k: v for k, v in item.items() if k != self.profile_key}
380391

381392
if not overrides:
382393
# Pure profile reference: expand entire profile list inline
383-
if not isinstance(profile_payload, (ListConfig, list)):
394+
if not isinstance(profile_list, (ListConfig, list)):
384395
raise ValueError(
385396
f"Profile '{profile_name}' in {self.profiles_key} must resolve to a list "
386-
f"for target '{target_path}', got {type(profile_payload)}"
397+
f"for target '{target_path}', got {type(profile_list)}"
387398
)
388-
expanded_items.extend(list(profile_payload))
399+
expanded_items.extend(list(profile_list))
389400
else:
390401
# Profile + overrides: use first profile item as base, merge overrides on top
391-
if isinstance(profile_payload, (ListConfig, list)):
392-
if not profile_payload:
402+
if isinstance(profile_list, (ListConfig, list)):
403+
if not profile_list:
393404
raise ValueError(
394405
f"Profile '{profile_name}' in {self.profiles_key} is empty."
395406
)
396407
base = OmegaConf.create(
397-
OmegaConf.to_container(profile_payload[0], resolve=False)
408+
OmegaConf.to_container(profile_list[0], resolve=False)
398409
)
399-
elif isinstance(profile_payload, (DictConfig, dict)):
410+
elif isinstance(profile_list, (DictConfig, dict)):
400411
base = OmegaConf.create(
401-
OmegaConf.to_container(profile_payload, resolve=False)
412+
OmegaConf.to_container(profile_list, resolve=False)
402413
)
403414
else:
404415
raise ValueError(
405416
f"Profile '{profile_name}' in {self.profiles_key} must be a list or dict, "
406-
f"got {type(profile_payload)}"
417+
f"got {type(profile_list)}"
407418
)
408419
merged = OmegaConf.merge(base, OmegaConf.create(overrides))
409420
expanded_items.append(merged)
@@ -460,26 +471,28 @@ def _stage_path(stage: str, rel_path: str) -> str:
460471

461472

462473
# Each tuple:
463-
# (profiles_key, stages, selector_rel, target_rel, merge_into_existing)
474+
# (profiles_key, stages, selector_rel, target_rel, merge_into_existing, list_key)
464475
# Order matters: pipeline/system first, then arch, then the rest.
465-
_VALUE_PROFILE_FAMILIES: List[Tuple[str, Tuple[str, ...], str, str, bool]] = [
466-
("pipeline_profiles", (_STAGE_DEFAULT,), "pipeline_profile", "", True),
467-
("system_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN, _STAGE_TEST, _STAGE_TUNE), "system.profile", "system", True),
468-
("arch_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN, _STAGE_TEST, _STAGE_TUNE), "model.arch.profile", "model", True),
469-
("augmentation_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "data.augmentation.profile", "data.augmentation", True),
470-
("dataloader_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN, _STAGE_TEST, _STAGE_TUNE), "data.dataloader.profile", "data.dataloader", True),
471-
("optimizer_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "optimization.profile", "optimization", True),
472-
("loss_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "model.loss.profile", "model.loss.losses", False),
473-
("label_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "data.label_transform.profile", "data.label_transform", False),
474-
("decoding_profiles", (_STAGE_DEFAULT, _STAGE_TEST, _STAGE_TUNE), "inference.decoding_profile", "inference.decoding", False),
475-
("activation_profiles", (_STAGE_DEFAULT, _STAGE_TEST, _STAGE_TUNE), "inference.test_time_augmentation.activation_profile", "inference.test_time_augmentation.channel_activations", False),
476+
# All profiles are dict-valued and use merge semantics. ``list_key`` identifies the
477+
# nested list within dict-valued profiles (used for positional overrides).
478+
_VALUE_PROFILE_FAMILIES: List[Tuple[str, Tuple[str, ...], str, str, bool, Optional[str]]] = [
479+
("pipeline_profiles", (_STAGE_DEFAULT,), "pipeline_profile", "", True, None),
480+
("system_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN, _STAGE_TEST, _STAGE_TUNE), "system.profile", "system", True, None),
481+
("arch_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN, _STAGE_TEST, _STAGE_TUNE), "model.arch.profile", "model", True, None),
482+
("augmentation_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "data.augmentation.profile", "data.augmentation", True, None),
483+
("dataloader_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN, _STAGE_TEST, _STAGE_TUNE), "data.dataloader.profile", "data.dataloader", True, None),
484+
("optimizer_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "optimization.profile", "optimization", True, None),
485+
("loss_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "model.loss.profile", "model.loss", True, "losses"),
486+
("label_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "data.label_transform.profile", "data.label_transform", True, None),
487+
("decoding_profiles", (_STAGE_DEFAULT, _STAGE_TEST, _STAGE_TUNE), "inference.decoding_profile", "inference", True, "decoding"),
488+
("activation_profiles", (_STAGE_DEFAULT, _STAGE_TEST, _STAGE_TUNE), "inference.test_time_augmentation.activation_profile", "inference.test_time_augmentation", True, "channel_activations"),
476489
]
477490

478491

479-
def _build_value_profile_specs() -> List[Tuple[List[str], str, str, bool, Tuple[str, ...]]]:
480-
specs: List[Tuple[List[str], str, str, bool, Tuple[str, ...]]] = []
492+
def _build_value_profile_specs() -> List[Tuple[List[str], str, str, bool, Tuple[str, ...], Optional[str]]]:
493+
specs: List[Tuple[List[str], str, str, bool, Tuple[str, ...], Optional[str]]] = []
481494

482-
for profiles_key, stages, selector_rel, target_rel, merge in _VALUE_PROFILE_FAMILIES:
495+
for profiles_key, stages, selector_rel, target_rel, merge, list_key in _VALUE_PROFILE_FAMILIES:
483496
for stage in stages:
484497
selector_path = _stage_path(stage, selector_rel)
485498
target_path = _stage_path(stage, target_rel)
@@ -490,6 +503,7 @@ def _build_value_profile_specs() -> List[Tuple[List[str], str, str, bool, Tuple[
490503
target_path,
491504
merge,
492505
(profiles_key, selector_path),
506+
list_key,
493507
)
494508
)
495509

@@ -498,10 +512,10 @@ def _build_value_profile_specs() -> List[Tuple[List[str], str, str, bool, Tuple[
498512

499513
# Each tuple: (profiles_key, stages, target_rel)
500514
_REFERENCE_PROFILE_FAMILIES: List[Tuple[str, Tuple[str, ...], str]] = [
501-
("loss_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "model.loss.losses"),
515+
("loss_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "model.loss"),
502516
("label_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "data.label_transform"),
503-
("decoding_profiles", (_STAGE_DEFAULT, _STAGE_TEST, _STAGE_TUNE), "inference.decoding"),
504-
("activation_profiles", (_STAGE_DEFAULT, _STAGE_TEST, _STAGE_TUNE), "inference.test_time_augmentation.channel_activations"),
517+
("decoding_profiles", (_STAGE_DEFAULT, _STAGE_TEST, _STAGE_TUNE), "inference"),
518+
("activation_profiles", (_STAGE_DEFAULT, _STAGE_TEST, _STAGE_TUNE), "inference.test_time_augmentation"),
505519
("augmentation_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "data.augmentation"),
506520
("dataloader_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN, _STAGE_TEST, _STAGE_TUNE), "data.dataloader"),
507521
("optimizer_profiles", (_STAGE_DEFAULT, _STAGE_TRAIN), "optimization"),
@@ -517,24 +531,25 @@ def _build_reference_profile_specs() -> List[Tuple[str, List[str]]]:
517531
return specs
518532

519533

520-
# Each tuple: (profiles_key, stages, target_rel)
521-
_LIST_REFERENCE_FAMILIES: List[Tuple[str, Tuple[str, ...], str]] = [
522-
("decoding_profiles", (_STAGE_DEFAULT, _STAGE_TUNE, _STAGE_TEST), "inference.decoding"),
534+
# Each tuple: (profiles_key, stages, target_rel, list_key)
535+
# ``list_key`` identifies the nested list within dict-valued profiles.
536+
_LIST_REFERENCE_FAMILIES: List[Tuple[str, Tuple[str, ...], str, str]] = [
537+
("decoding_profiles", (_STAGE_DEFAULT, _STAGE_TUNE, _STAGE_TEST), "inference.decoding", "decoding"),
523538
]
524539

525540

526-
def _build_list_reference_specs() -> List[Tuple[str, List[str]]]:
527-
specs: List[Tuple[str, List[str]]] = []
528-
for profiles_key, stages, target_rel in _LIST_REFERENCE_FAMILIES:
541+
def _build_list_reference_specs() -> List[Tuple[str, List[str], str]]:
542+
specs: List[Tuple[str, List[str], str]] = []
543+
for profiles_key, stages, target_rel, list_key in _LIST_REFERENCE_FAMILIES:
529544
target_paths = [_stage_path(stage, target_rel) for stage in stages]
530-
specs.append((profiles_key, target_paths))
545+
specs.append((profiles_key, target_paths, list_key))
531546
return specs
532547

533548

534549
def _build_allowed_selector_paths() -> set[str]:
535550
paths: set[str] = set()
536551

537-
for _, stages, selector_rel, _, _ in _VALUE_PROFILE_FAMILIES:
552+
for _, stages, selector_rel, _, _, _ in _VALUE_PROFILE_FAMILIES:
538553
for stage in stages:
539554
paths.add(_normalize_selector_path(_stage_path(stage, selector_rel)))
540555

@@ -574,13 +589,14 @@ def _build_profile_engine() -> YamlProfileEngine:
574589
appliers: List[YamlProfileApplier] = []
575590

576591
# 1) Value profile appliers (order defined by _VALUE_PROFILE_FAMILIES table)
577-
for selector_paths, profiles_key, target_path, merge, cleanup_keys in _VALUE_PROFILE_SPECS:
592+
for selector_paths, profiles_key, target_path, merge, cleanup_keys, list_key in _VALUE_PROFILE_SPECS:
578593
appliers.append(ValueProfileApplier(
579594
selector_paths=selector_paths,
580595
profiles_key=profiles_key,
581596
target_path=target_path,
582597
merge_into_existing=merge,
583598
cleanup_keys=cleanup_keys,
599+
list_key=list_key,
584600
))
585601

586602
# 2) Reference profile appliers
@@ -591,10 +607,11 @@ def _build_profile_engine() -> YamlProfileEngine:
591607
))
592608

593609
# 3) List profile reference appliers
594-
for profiles_key, target_paths in _LIST_REFERENCE_SPECS:
610+
for profiles_key, target_paths, list_key in _LIST_REFERENCE_SPECS:
595611
appliers.append(ListProfileReferenceApplier(
596612
profiles_key=profiles_key,
597613
target_paths=target_paths,
614+
list_key=list_key,
598615
))
599616

600617
return YamlProfileEngine(appliers)

connectomics/config/schema/data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,9 @@ class DataConfig:
469469
default_factory=NNUNetPreprocessingConfig
470470
)
471471

472+
# Mask-specific transformation (overrides data_transform for masks when present)
473+
mask_transform: Optional[DataTransformConfig] = None
474+
472475
# Multi-channel label transformation (for affinity maps, distance transforms, etc.)
473476
label_transform: LabelTransformConfig = field(default_factory=LabelTransformConfig)
474477

connectomics/data/augment/build.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,11 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
319319
# Apply resize if configured (before cropping)
320320
image_resize_factors = getattr(data_cfg.image_transform, "resize", None)
321321

322+
# Prefer mask_transform over data_transform for mask-specific settings.
323+
mask_cfg = getattr(data_cfg, "mask_transform", None) or data_cfg.data_transform
322324
mask_resize_factors = None
323-
if mode in {"test", "tune"} and data_cfg.data_transform.resize is not None:
324-
mask_resize_factors = data_cfg.data_transform.resize
325+
if mode in {"test", "tune"} and mask_cfg.resize is not None:
326+
mask_resize_factors = mask_cfg.resize
325327

326328
if image_resize_factors is not None and image_resize_factors:
327329
transforms.append(
@@ -366,8 +368,8 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
366368
mask_binarize = False
367369
mask_threshold = 0.0
368370
if mode in {"test", "tune"}:
369-
mask_binarize = bool(getattr(data_cfg.data_transform, "binarize", False))
370-
mask_threshold = float(getattr(data_cfg.data_transform, "threshold", 0.0))
371+
mask_binarize = bool(getattr(mask_cfg, "binarize", False))
372+
mask_threshold = float(getattr(mask_cfg, "threshold", 0.0))
371373

372374
if "mask" in keys and mask_binarize:
373375
transforms.append(

connectomics/training/lightning/test_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,9 @@ def run_test_step(module, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
831831
logger.info("Starting sliding-window inference...")
832832

833833
mask_align_to_image = False
834-
mask_transform_cfg = getattr(module.cfg.data, "data_transform", None)
834+
mask_transform_cfg = getattr(module.cfg.data, "mask_transform", None) or getattr(
835+
module.cfg.data, "data_transform", None
836+
)
835837
if mask_transform_cfg is not None:
836838
mask_align_to_image = bool(getattr(mask_transform_cfg, "align_to_image", False))
837839

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
activation_profiles:
22
act_binary:
3-
- {channels: ":", activation: sigmoid}
3+
channel_activations:
4+
- {channels: ":", activation: sigmoid}
45

56
act_bcd:
6-
- {channels: "0:2", activation: sigmoid}
7-
- {channels: "2:3", activation: tanh}
7+
channel_activations:
8+
- {channels: "0:2", activation: sigmoid}
9+
- {channels: "2:3", activation: tanh}
810

911
act_bd:
10-
- {channels: "0:-1", activation: sigmoid}
11-
- {channels: "-1:", activation: tanh}
12+
channel_activations:
13+
- {channels: "0:-1", activation: sigmoid}
14+
- {channels: "-1:", activation: tanh}
Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
decoding_profiles:
22
decoding_bcd:
3-
- name: decode_instance_binary_contour_distance
4-
kwargs:
5-
binary_threshold: [0.9, 0.85]
6-
contour_threshold: [0.8, 1.1]
7-
distance_threshold: [0.5, -0.5]
8-
min_instance_size: 30
9-
min_seed_size: 5
3+
decoding:
4+
- name: decode_instance_binary_contour_distance
5+
kwargs:
6+
binary_threshold: [0.9, 0.85]
7+
contour_threshold: [0.8, 1.1]
8+
distance_threshold: [0.5, -0.5]
9+
min_instance_size: 30
10+
min_seed_size: 5
1011
decoding_abiss:
11-
- name: decode_abiss
12-
kwargs:
13-
command: "{python_exe} scripts/run_abiss_single.py --input {input_h5} --output {output_h5}"
14-
# Affinity thresholds accept absolute values (e.g. 0.80) or
15-
# percentiles (e.g. "80%") resolved from the data.
16-
cli_args:
17-
ws_high_threshold: 80%
18-
ws_low_threshold: 1%
19-
ws_size_threshold: 800
20-
ws_dust_threshold: 600
21-
ws_merge_threshold: 20%
12+
decoding:
13+
- name: decode_abiss
14+
kwargs:
15+
command: "{python_exe} scripts/run_abiss_single.py --input {input_h5} --output {output_h5}"
16+
# Affinity thresholds accept absolute values (e.g. 0.80) or
17+
# percentiles (e.g. "80%") resolved from the data.
18+
cli_args:
19+
ws_high_threshold: 80%
20+
ws_low_threshold: 1%
21+
ws_size_threshold: 800
22+
ws_dust_threshold: 600
23+
ws_merge_threshold: 20%

0 commit comments

Comments (0)