Skip to content

Commit 6f90acc

Browse files
author
Donglai Wei
committed
Normalize package layout and add MedNeXt multi-head support
1 parent a055edf commit 6f90acc

110 files changed

Lines changed: 3289 additions & 451 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.claude/feature/mednext_multi_head.md

Lines changed: 424 additions & 0 deletions
Large diffs are not rendered by default.

connectomics/config/pipeline/config_io.py

Lines changed: 282 additions & 40 deletions
Large diffs are not rendered by default.

connectomics/config/schema/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
SlidingWindowConfig,
4040
TestTimeAugmentationConfig,
4141
)
42-
from .model import LossConfig, ModelArchConfig, ModelConfig
42+
from .model import LossConfig, ModelArchConfig, ModelConfig, ModelHeadConfig
4343
from .model_mednext import MedNeXtConfig
4444
from .model_monai import MonaiConfig, TransformerConfig
4545
from .model_nnunet import NNUNetConfig
@@ -80,6 +80,7 @@
8080
# Model configuration
8181
"ModelConfig",
8282
"ModelArchConfig",
83+
"ModelHeadConfig",
8384
"MonaiConfig",
8485
"TransformerConfig",
8586
"MedNeXtConfig",

connectomics/config/schema/data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,9 @@ class DataConfig:
442442
- 2D data support with do_2d parameter
443443
"""
444444

445+
# Root path prepended to train/val/test split paths (empty = no prefix)
446+
root_path: str = ""
447+
445448
# Train/Val Split (inspired by DeepEM)
446449
split_enabled: bool = False # Enable automatic train/val split (default: False)
447450
split_train_range: List[float] = field(default_factory=lambda: [0.0, 0.8]) # Train: 0-80%

connectomics/config/schema/inference.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class SavePredictionConfig:
7979
output_formats: List[str] = field(default_factory=lambda: ["h5"]) # Any of: h5, tiff, png
8080
output_path: Optional[str] = None
8181
cache_suffix: str = "_x1_prediction.h5"
82+
save_all_heads: bool = False
8283

8384
# Data scaling and output typing
8485
# -1 keeps native float probabilities/logits; >0 scales and casts to integer dtype if chosen.
@@ -189,6 +190,9 @@ class InferenceConfig:
189190
Note: stage-specific overrides are merged before runtime; consumers should read `cfg.inference`.
190191
"""
191192

193+
# Named output head selection for multi-head models. When unset, falls back
194+
# to model.primary_head or the sole configured head.
195+
head: Optional[str] = None
192196
sliding_window: SlidingWindowConfig = field(default_factory=SlidingWindowConfig)
193197
test_time_augmentation: TestTimeAugmentationConfig = field(
194198
default_factory=TestTimeAugmentationConfig

connectomics/config/schema/model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ class ModelArchConfig:
6464
params: Dict[str, Any] = field(default_factory=dict)
6565

6666

67+
@dataclass
68+
class ModelHeadConfig:
69+
"""Named task head configuration for MedNeXt multi-head models."""
70+
71+
out_channels: int = 1
72+
num_blocks: int = 0
73+
hidden_channels: Optional[int] = None
74+
target_slice: Optional[Any] = None
75+
76+
6777
@dataclass
6878
class ModelConfig:
6979
"""Model architecture configuration.
@@ -87,7 +97,12 @@ class ModelConfig:
8797
input_size: List[int] = field(default_factory=lambda: [128, 128, 128])
8898
output_size: List[int] = field(default_factory=lambda: [128, 128, 128])
8999
in_channels: int = 1
100+
# Legacy global output width used by single-tensor models and some fallback
101+
# code paths. Multi-head MedNeXt models should define per-head widths in
102+
# model.heads and should not rely on this to mirror the sum of all heads.
90103
out_channels: int = 1
104+
primary_head: Optional[str] = None
105+
heads: Dict[str, ModelHeadConfig] = field(default_factory=dict)
91106

92107
# Architecture-specific nested blocks
93108
monai: MonaiConfig = field(default_factory=MonaiConfig)

connectomics/config/schema/monitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class ImageLoggingConfig:
5858
channels: Optional[Tuple[int, ...]] = None
5959
channel_mode: str = "all" # "argmax", "all", or "selected"
6060
selected_channels: Optional[List[int]] = None
61+
head: Optional[str] = None
6162

6263

6364
@dataclass

connectomics/data/__init__.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,23 @@
1-
"""
2-
Data package for PyTorch Connectomics.
1+
"""Data package for PyTorch Connectomics.
32
43
This package provides:
5-
- Dataset classes (dataset/)
6-
- Data augmentation (augment/)
7-
- Data processing transforms (process/)
4+
- Dataset classes (datasets/)
5+
- Data augmentation (augmentation/)
6+
- Data processing transforms (processing/)
87
- I/O utilities (io/)
98
- DataModules for PyTorch Lightning (see training/lightning/data.py)
109
1110
Recommended imports:
12-
from connectomics.data.dataset import CachedVolumeDataset
13-
from connectomics.data.augment import RandMisAlignmentd, build_train_transforms
14-
from connectomics.data.process import MultiTaskLabelTransformd, create_label_transform_pipeline
11+
from connectomics.data.datasets import CachedVolumeDataset
12+
from connectomics.data.augmentation import RandMisAlignmentd, build_train_transforms
13+
from connectomics.data.processing import MultiTaskLabelTransformd, create_label_transform_pipeline
1514
"""
1615

17-
# Make submodules available
18-
from . import augment, dataset, io, process
16+
from . import augmentation, datasets, io, processing
1917

2018
__all__ = [
21-
# Submodules
22-
"augment",
23-
"dataset",
19+
"augmentation",
20+
"datasets",
2421
"io",
25-
"process",
22+
"processing",
2623
]
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)