Skip to content

Commit 83425ad

Browse files
author
Donglai Wei
committed
Extract runtime dispatch helpers
1 parent eef2ea5 commit 83425ad

19 files changed

Lines changed: 1428 additions & 1475 deletions

connectomics/config/pipeline/config_io.py

Lines changed: 0 additions & 287 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@
1616

1717
from omegaconf import DictConfig, ListConfig, OmegaConf
1818

19-
from ...data.processing.build import count_stacked_label_transform_channels
20-
from ...models.architectures.registry import get_architecture_info
21-
from ...utils.channel_slices import infer_min_required_channels
22-
from ...utils.model_outputs import resolve_configured_output_head, resolve_output_heads
2319
from ..schema import Config
2420
from ..schema.root import MergeContext
2521
from .profile_engine import _YAML_PROFILE_ENGINE
@@ -486,289 +482,6 @@ def validate_config(cfg: Config) -> None:
486482
f"when model.heads has multiple entries ({sorted(model_heads.keys())})"
487483
)
488484

489-
# --- Cross-section coherence validation (UX 2.4 #3) ---
490-
_validate_cross_section_coherence(cfg)
491-
492-
493-
def _architecture_supports_deep_supervision(arch_type: str) -> bool:
    """Infer deep-supervision support from architecture registry metadata.

    Unknown architectures (not present in the registry) are assumed to support
    deep supervision so that validation stays permissive for custom models.
    """
    registry = get_architecture_info()
    if arch_type not in registry:
        return True
    # Current MONAI wrappers are single-scale and do not expose deep supervision.
    module_name = registry[arch_type].get("module", "")
    return not module_name.endswith("monai_models")
502-
503-
504-
def _validate_cross_section_coherence(cfg: Config) -> None:
    """Validate resolved cross-section coherence and raise clear errors.

    Checks performed, in order:
      1. ``model.input_size`` must equal ``data.dataloader.patch_size``.
      2. Output-channel coherence: every channel selector in losses, label
         transforms, head routing, decoding kwargs, and inference settings must
         fit within the channels actually produced (per-head or flat output).
      3. ``model.loss.deep_supervision`` is rejected for architectures whose
         registry metadata indicates no deep-supervision support.

    Raises:
        ValueError: on the first coherence violation found, with a message
            naming the offending config path.
    """
    # 1) model.input_size vs data.dataloader.patch_size mismatch
    model_input = list(cfg.model.input_size)
    patch_size = list(cfg.data.dataloader.patch_size)
    if model_input != patch_size:
        raise ValueError(
            "Cross-section validation failed: model.input_size "
            f"{model_input} must match data.dataloader.patch_size {patch_size}."
        )

    # 2) output-channel coherence vs loss/label/decoding/activation channel usage
    out_channels = cfg.model.out_channels
    model_heads = getattr(cfg.model, "heads", None) or {}
    primary_head = getattr(cfg.model, "primary_head", None)
    # With exactly one head, selectors may omit the head name and still resolve.
    sole_head = next(iter(model_heads.keys())) if len(model_heads) == 1 else None
    # Accumulates (config_path, min_channels) requirements against the flat
    # model output; checked in aggregate at the end of step 2.
    required_output_channels: List[tuple[str, int]] = []

    label_cfg = getattr(cfg.data, "label_transform", None)
    stacked_label_channels = (
        count_stacked_label_transform_channels(label_cfg) if label_cfg is not None else None
    )

    def _resolve_selector_head(entry: Any, *, selector_key: str) -> Optional[str]:
        # pred2_slice may target its own head; fall back to pred_head, then to
        # the configured primary head or the sole head.
        if selector_key == "pred2_slice":
            selector_head = entry.get("pred2_head", entry.get("pred_head", None))
        else:
            selector_head = entry.get("pred_head", None)

        if selector_head is None:
            selector_head = primary_head or sole_head
        return selector_head

    def _validate_head_channel_capacity(
        *,
        selector_key: str,
        selector_head: str,
        min_channels: int,
        loss_idx: int,
    ) -> bool:
        # Returns True when the selector was validated against a known head
        # (so the caller can skip the flat-output fallback).
        if selector_head not in model_heads:
            return False

        head_channels = int(getattr(model_heads[selector_head], "out_channels", 0))
        if min_channels > head_channels:
            raise ValueError(
                "Cross-section validation failed: "
                f"model.loss.losses[{loss_idx}].{selector_key} requires at least "
                f"{min_channels} channels in head '{selector_head}', but "
                f"model.heads.{selector_head}.out_channels is {head_channels}."
            )
        return True

    def _validate_label_channel_capacity(selector_value: Any, *, path: str) -> None:
        min_channels = infer_min_required_channels(selector_value, context=path)
        if min_channels is None:
            return

        if stacked_label_channels is not None:
            if min_channels > stacked_label_channels:
                raise ValueError(
                    "Cross-section validation failed: "
                    f"{path} requires at least {min_channels} stacked label channels, but "
                    f"data.label_transform.targets produces {stacked_label_channels}."
                )
            return

        # No stacked label info and no heads: defer to the flat-output check.
        if not model_heads:
            required_output_channels.append((path, min_channels))

    # 2a) Loss channel selectors
    model_loss_cfg = getattr(cfg.model, "loss", None)
    losses_cfg = getattr(model_loss_cfg, "losses", None) if model_loss_cfg else None
    if losses_cfg is not None:
        for i, entry in enumerate(losses_cfg):
            if not isinstance(entry, dict):
                continue
            selector_keys = ("pred_slice", "target_slice", "mask_slice", "pred2_slice")
            for selector_key in selector_keys:
                min_channels = infer_min_required_channels(
                    entry.get(selector_key),
                    context=f"model.loss.losses[{i}].{selector_key}",
                )
                if min_channels is not None:
                    path = f"model.loss.losses[{i}].{selector_key}"
                    # Prediction selectors are checked against the resolved head
                    # when heads are configured; otherwise fall through.
                    if selector_key in {"pred_slice", "pred2_slice"} and model_heads:
                        selector_head = _resolve_selector_head(entry, selector_key=selector_key)
                        if selector_head is not None and _validate_head_channel_capacity(
                            selector_key=selector_key,
                            selector_head=selector_head,
                            min_channels=min_channels,
                            loss_idx=i,
                        ):
                            continue
                    # Target/mask selectors index into stacked label channels.
                    if selector_key in {"target_slice", "mask_slice"}:
                        _validate_label_channel_capacity(entry.get(selector_key), path=path)
                        continue
                    required_output_channels.append((path, min_channels))

    # 2b) Label transform targets (legacy lower-bound expectation for flat outputs)
    if not model_heads and stacked_label_channels:
        required_output_channels.append(("data.label_transform.targets", stacked_label_channels))

    # 2c) Explicit head-to-label routing
    if model_heads:
        for head_name, head_cfg in model_heads.items():
            target_slice = getattr(head_cfg, "target_slice", None)
            if target_slice is None:
                continue
            _validate_label_channel_capacity(
                target_slice,
                path=f"model.heads.{head_name}.target_slice",
            )

    # 2d) Decoding kwargs channel selectors (*_channels)
    decoding_cfg = getattr(cfg, "decoding", None)
    decode_has_channel_selection = False
    decode_output_head = None
    decode_available_channels = out_channels
    decode_channel_scope = "model output"
    if isinstance(decoding_cfg, list):
        # First pass: detect whether any decode step selects channels at all.
        for i, decode_step in enumerate(decoding_cfg):
            kwargs = getattr(decode_step, "kwargs", None)
            if not isinstance(kwargs, dict):
                continue
            if any(key.endswith("_channels") for key in kwargs):
                decode_has_channel_selection = True
                break

        # Resolve which head(s) the decode selectors operate on.
        if model_heads and decode_has_channel_selection:
            decode_heads = resolve_output_heads(cfg, purpose="decode channel selection")
            if len(model_heads) > 1 and not decode_heads:
                raise ValueError(
                    "Cross-section validation failed: decode channel selectors require "
                    "inference.head or model.primary_head when model.heads has multiple "
                    f"entries ({sorted(model_heads.keys())})."
                )
            if len(decode_heads) > 1:
                decode_available_channels = sum(
                    int(getattr(model_heads[h], "out_channels", 0)) for h in decode_heads
                )
                decode_channel_scope = f"merged heads {decode_heads}"
                decode_output_head = decode_heads[0]
            elif decode_heads:
                decode_output_head = decode_heads[0]
                if decode_output_head in model_heads:
                    decode_available_channels = int(
                        getattr(model_heads[decode_output_head], "out_channels", out_channels)
                    )
                    decode_channel_scope = f"head '{decode_output_head}'"

        # Second pass: check each *_channels selector against the resolved scope.
        for i, decode_step in enumerate(decoding_cfg):
            kwargs = getattr(decode_step, "kwargs", None)
            if not isinstance(kwargs, dict):
                continue
            for key, value in kwargs.items():
                if not key.endswith("_channels"):
                    continue
                min_channels = infer_min_required_channels(
                    value,
                    context=f"decoding[{i}].kwargs.{key}",
                )
                if min_channels is not None:
                    path = f"decoding[{i}].kwargs.{key}"
                    if model_heads and decode_has_channel_selection:
                        if min_channels > decode_available_channels:
                            raise ValueError(
                                "Cross-section validation failed: "
                                f"{path} requires at least {min_channels} channels in "
                                f"{decode_channel_scope}, but only "
                                f"{decode_available_channels} are available."
                            )
                        continue
                    required_output_channels.append((path, min_channels))

    # 2e) Inference channel selectors
    tta_cfg = getattr(cfg.inference, "test_time_augmentation", None)
    channel_activations = getattr(tta_cfg, "channel_activations", None) if tta_cfg else None
    select_channel = getattr(cfg.inference, "select_channel", None)
    inference_has_channel_selection = bool(channel_activations) or select_channel is not None
    tta_heads = (
        resolve_output_heads(cfg, purpose="inference channel selection") if model_heads else []
    )
    tta_output_head = tta_heads[0] if tta_heads else None
    if model_heads and len(model_heads) > 1 and inference_has_channel_selection and not tta_heads:
        raise ValueError(
            "Cross-section validation failed: inference channel selectors require inference.head "
            "or model.primary_head when model.heads has multiple entries "
            f"({sorted(model_heads.keys())})."
        )
    if len(tta_heads) > 1:
        tta_available_channels = sum(
            int(getattr(model_heads[h], "out_channels", 0)) for h in tta_heads
        )
        tta_channel_scope = f"merged heads {tta_heads}"
    else:
        tta_available_channels = (
            int(getattr(model_heads[tta_output_head], "out_channels", out_channels))
            if tta_output_head in model_heads
            else out_channels
        )
        tta_channel_scope = (
            f"head '{tta_output_head}'" if tta_output_head in model_heads else "model output"
        )

    def _validate_tta_channel_capacity(selector_value: Any, *, path: str) -> None:
        min_selector_channels = infer_min_required_channels(
            selector_value,
            context=path,
        )
        if min_selector_channels is None:
            return
        if min_selector_channels > tta_available_channels:
            raise ValueError(
                "Cross-section validation failed: "
                f"{path} requires at least {min_selector_channels} channels in {tta_channel_scope}, "
                f"but only {tta_available_channels} are available."
            )

    if isinstance(channel_activations, list):
        for i, spec in enumerate(channel_activations):
            if not isinstance(spec, dict):
                raise ValueError(
                    "Cross-section validation failed: "
                    f"inference.test_time_augmentation.channel_activations[{i}] "
                    "must be a mapping with 'channels' and 'activation'."
                )
            if "channels" not in spec or "activation" not in spec:
                raise ValueError(
                    "Cross-section validation failed: "
                    f"inference.test_time_augmentation.channel_activations[{i}] "
                    "must define both 'channels' and 'activation'."
                )
            _validate_tta_channel_capacity(
                spec["channels"],
                path=f"inference.test_time_augmentation.channel_activations[{i}].channels",
            )
    _validate_tta_channel_capacity(
        select_channel,
        path="inference.select_channel",
    )

    # Aggregate flat-output requirements collected above.
    if required_output_channels:
        required_max = max(req for _, req in required_output_channels)
        if required_max > out_channels:
            details = ", ".join(
                f"{path} needs >= {req}"
                for path, req in sorted(required_output_channels, key=lambda x: x[1], reverse=True)
            )
            raise ValueError(
                "Cross-section validation failed: model.out_channels is "
                f"{out_channels}, but resolved pipeline components require at least "
                f"{required_max} channels ({details})."
            )

    # 3) deep_supervision=True with architectures that don't support it
    deep_supervision = (
        getattr(model_loss_cfg, "deep_supervision", False) if model_loss_cfg else False
    )
    if deep_supervision:
        arch_type = getattr(cfg.model.arch, "type", "")
        if not _architecture_supports_deep_supervision(arch_type):
            raise ValueError(
                "Cross-section validation failed: model.loss.deep_supervision=True but "
                f"architecture '{arch_type}' does not support deep supervision. "
                "Use MedNeXt/RSUNet or disable deep supervision."
            )
771-
772485

773486
# ---------------------------------------------------------------------------
774487
# Config hashing and naming

connectomics/config/schema/root.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from __future__ import annotations
22

3-
import importlib
4-
import inspect
5-
from dataclasses import dataclass, field, is_dataclass
3+
from dataclasses import dataclass, field
64
from typing import List, Optional, Set
75

6+
from ...runtime.torch_safe_globals import register_torch_safe_globals
87
from .data import DataConfig
98
from .inference import DecodeModeConfig, EvaluationConfig, InferenceConfig
109
from .model import ModelConfig
@@ -90,46 +89,4 @@ class Config:
9089
)
9190

9291

93-
def _register_torch_safe_globals() -> None:
94-
"""Register schema dataclasses for torch 2.6+ weights_only checkpoint loading."""
95-
try:
96-
import torch
97-
98-
if not (
99-
hasattr(torch, "serialization") and hasattr(torch.serialization, "add_safe_globals")
100-
):
101-
return
102-
103-
safe_dataclasses = []
104-
schema_modules = [
105-
"connectomics.config.schema.system",
106-
"connectomics.config.schema.model",
107-
"connectomics.config.schema.model_monai",
108-
"connectomics.config.schema.model_mednext",
109-
"connectomics.config.schema.model_rsunet",
110-
"connectomics.config.schema.model_nnunet",
111-
"connectomics.config.schema.data",
112-
"connectomics.config.schema.optimization",
113-
"connectomics.config.schema.monitor",
114-
"connectomics.config.schema.inference",
115-
"connectomics.config.schema.stages",
116-
"connectomics.config.schema.root",
117-
]
118-
for module_name in schema_modules:
119-
module = importlib.import_module(module_name)
120-
safe_dataclasses.extend(
121-
obj
122-
for obj in module.__dict__.values()
123-
if inspect.isclass(obj) and is_dataclass(obj)
124-
)
125-
126-
# De-duplicate while preserving order.
127-
deduped = list(dict.fromkeys(safe_dataclasses))
128-
torch.serialization.add_safe_globals(deduped)
129-
except Exception:
130-
# Best-effort registration; ignore if torch not available at import time.
131-
pass
132-
133-
134-
# Register safe globals on import.
135-
_register_torch_safe_globals()
92+
register_torch_safe_globals()

0 commit comments

Comments (0)