diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index df972046..75c81837 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -28,6 +28,10 @@ jobs: pip install pytest pytest-cov pip install -e . + - name: Validate tutorial configs + run: | + python scripts/validate_tutorial_configs.py + - name: Run tests run: | pytest tests/ -v --cov=connectomics --cov-report=xml diff --git a/AGENT.md b/AGENT.md new file mode 100644 index 00000000..baeae698 --- /dev/null +++ b/AGENT.md @@ -0,0 +1,70 @@ +# AGENT.md + +This file provides instructions for Codex and similar coding agents working in this repository. + +## Project + +PyTorch Connectomics (PyTC) is a Hydra/OmegaConf + PyTorch Lightning + MONAI codebase for EM segmentation. +Primary entry point: `scripts/main.py`. + +## Mandatory Guardrails + +- Preserve user-facing behavior unless explicitly requested. +- Preserve `scripts/main.py` CLI arguments and mode semantics. +- Preserve Hydra/OmegaConf key structure and override behavior. +- Prefer small, reviewable refactors over large rewrites. +- Do not add new runtime dependencies unless explicitly approved. + +## Environment + +- Use conda env `pytc` for all validation. +- Do not install into `base`. +- Prefer `conda run -n pytc ` for deterministic execution. + +## Required Verification + +Run these after meaningful changes (or explain why not possible): + +- `conda run -n pytc python scripts/main.py --demo` +- `conda run -n pytc pytest -q` + +For targeted changes, run focused tests plus the demo smoke test. + +## Lint/Type Checks + +Use changed-file scope unless specifically fixing global style/type debt: + +- `black --check ` +- `isort --check-only ` +- `flake8 --max-line-length=100 ` +- `mypy --config-file .github/mypy_changed.ini ` + +Note: repository-wide `black --check connectomics/` is not currently clean. + +## Code Change Style + +- Keep diffs minimal and localized. 
+- Avoid changing config keys and defaults without a clear migration plan. +- Keep module boundaries explicit (config, data, training, decoding). +- Add or update tests when behavior or contracts are touched. + +## Git and PR Expectations + +- Commit logically (one milestone/concern per commit). +- In PR descriptions, include: + - what changed + - why + - exact validation commands run + - key results + +## Repository Hotspots + +- Config system: `connectomics/config/` +- Data pipeline: `connectomics/data/` +- Lightning runtime: `connectomics/training/lit/` +- Decoding/postprocess: `connectomics/decoding/` + +## Safety + +- Do not use destructive git commands (`reset --hard`, checkout discards) unless explicitly requested. +- If unexpected working-tree changes appear, pause and confirm intent before touching them. diff --git a/connectomics/config/hydra_config.py b/connectomics/config/hydra_config.py index 1d962f89..adfe0de1 100644 --- a/connectomics/config/hydra_config.py +++ b/connectomics/config/hydra_config.py @@ -24,7 +24,7 @@ from connectomics.config.hydra_utils import load_config # Load configuration from YAML - cfg = load_config("tutorials/monai_lucchi++.yaml") + cfg = load_config("tutorials/mito_lucchi++.yaml") # Access configuration sections print(f"Model architecture: {cfg.model.architecture}") @@ -37,9 +37,10 @@ """ from __future__ import annotations -from dataclasses import dataclass, field, is_dataclass -from typing import Dict, List, Optional, Tuple, Any, Union + import inspect +from dataclasses import dataclass, field, is_dataclass +from typing import Any, Dict, List, Optional, Tuple, Union # Note: MISSING can be imported from omegaconf if needed for required fields @@ -946,7 +947,8 @@ class SavePredictionConfig: enabled: Enable saving intermediate predictions (default: True) intensity_scale: Scale factor for predictions (e.g., 255 for uint8 visualization) intensity_dtype: Data type for saved predictions (e.g., 'uint8', 'float32') - output_formats: 
List of output formats to save predictions in (e.g., ['h5', 'tiff', 'nii.gz']) + output_formats: List of output formats to save predictions + (e.g., ['h5', 'tiff', 'nii.gz']) Supported formats: 'h5', 'tiff', 'nii', 'nii.gz', 'png' Default: ['h5', 'nii.gz'] for backward compatibility """ diff --git a/connectomics/config/hydra_utils.py b/connectomics/config/hydra_utils.py index 320abee8..ab790440 100644 --- a/connectomics/config/hydra_utils.py +++ b/connectomics/config/hydra_utils.py @@ -5,13 +5,74 @@ """ from __future__ import annotations + from pathlib import Path -from typing import Optional, Union, Dict, Any, List -from omegaconf import OmegaConf, DictConfig +from typing import Any, Dict, List, Optional, Tuple, Union + +from omegaconf import DictConfig, ListConfig, OmegaConf from .hydra_config import Config +def _normalize_base_paths(base_field: Any, config_path: Path) -> List[Path]: + """Normalize `_base_` field to an ordered list of absolute paths.""" + if base_field is None: + return [] + + if isinstance(base_field, (str, Path)): + base_entries = [str(base_field)] + elif isinstance(base_field, (list, tuple, ListConfig)): + base_entries = [str(item) for item in base_field] + else: + raise TypeError( + f"Invalid _base_ value in {config_path}: expected string or list, " + f"got {type(base_field)}" + ) + + resolved_paths: List[Path] = [] + for base_entry in base_entries: + base_path = Path(base_entry) + if not base_path.is_absolute(): + base_path = (config_path.parent / base_path).resolve() + if not base_path.exists(): + raise FileNotFoundError( + f"Base config not found: {base_entry} (resolved to {base_path}) in {config_path}" + ) + resolved_paths.append(base_path) + + return resolved_paths + + +def _load_config_with_bases(config_path: Path, loading_stack: Tuple[Path, ...] 
= ()) -> DictConfig: + """Load YAML config recursively with `_base_` inheritance.""" + config_path = config_path.resolve() + if config_path in loading_stack: + cycle = " -> ".join(str(p) for p in (*loading_stack, config_path)) + raise ValueError(f"Detected cyclic _base_ config inheritance: {cycle}") + + if not config_path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + yaml_conf = OmegaConf.load(config_path) + if yaml_conf is None: + yaml_conf = OmegaConf.create({}) + if not isinstance(yaml_conf, DictConfig): + raise TypeError( + f"Config root must be a mapping in {config_path}, got {type(yaml_conf)} instead" + ) + + base_field = yaml_conf.get("_base_", None) + if "_base_" in yaml_conf: + del yaml_conf["_base_"] + + merged_base = OmegaConf.create({}) + for base_path in _normalize_base_paths(base_field, config_path): + base_conf = _load_config_with_bases(base_path, (*loading_stack, config_path)) + merged_base = OmegaConf.merge(merged_base, base_conf) + + return OmegaConf.merge(merged_base, yaml_conf) + + def load_config(config_path: Union[str, Path]) -> Config: """ Load configuration from YAML file. 
@@ -22,12 +83,8 @@ def load_config(config_path: Union[str, Path]) -> Config: Returns: Config object with defaults merged """ - config_path = Path(config_path) - if not config_path.exists(): - raise FileNotFoundError(f"Config file not found: {config_path}") - - # Load YAML - yaml_conf = OmegaConf.load(config_path) + config_path = Path(config_path).resolve() + yaml_conf = _load_config_with_bases(config_path) # Merge with structured config defaults default_conf = OmegaConf.structured(Config) @@ -190,7 +247,7 @@ def validate_config(cfg: Config) -> None: if cfg.optimization.max_epochs <= 0: raise ValueError("optimization.max_epochs must be positive when max_steps is not set") # If max_steps is set, max_epochs can be anything (will be overridden to -1 in trainer) - + if cfg.optimization.gradient_clip_val < 0: raise ValueError("optimization.gradient_clip_val must be non-negative") if cfg.optimization.accumulate_grad_batches <= 0: @@ -299,9 +356,11 @@ def _combine_path( # Handle list of paths if isinstance(file_path, list): - result = [] + result: List[str] = [] for p in file_path: resolved = _combine_path(base_path, p) + if resolved is None: + continue # If resolved is a list (from glob expansion), extend if isinstance(resolved, list): result.extend(resolved) @@ -375,42 +434,40 @@ def _combine_path( cfg.data.train_image = _combine_path(train_base, cfg.data.train_image) cfg.data.train_label = _combine_path(train_base, cfg.data.train_label) cfg.data.train_mask = _combine_path(train_base, cfg.data.train_mask) - cfg.data.train_json = _combine_path(train_base, cfg.data.train_json) + train_json_resolved = _combine_path(train_base, cfg.data.train_json) + if isinstance(train_json_resolved, list): + cfg.data.train_json = train_json_resolved[0] if train_json_resolved else None + else: + cfg.data.train_json = train_json_resolved # Resolve validation paths (always expand globs, use val_path as base if available) val_base = cfg.data.val_path if cfg.data.val_path else "" 
cfg.data.val_image = _combine_path(val_base, cfg.data.val_image) cfg.data.val_label = _combine_path(val_base, cfg.data.val_label) cfg.data.val_mask = _combine_path(val_base, cfg.data.val_mask) - cfg.data.val_json = _combine_path(val_base, cfg.data.val_json) + val_json_resolved = _combine_path(val_base, cfg.data.val_json) + if isinstance(val_json_resolved, list): + cfg.data.val_json = val_json_resolved[0] if val_json_resolved else None + else: + cfg.data.val_json = val_json_resolved # Resolve test data paths (cfg.test.data.test_*) - if hasattr(cfg, "test") and hasattr(cfg.test, "data"): - test_base = ( - cfg.test.data.test_path - if hasattr(cfg.test.data, "test_path") and cfg.test.data.test_path - else "" - ) - if hasattr(cfg.test.data, "test_image"): - cfg.test.data.test_image = _combine_path(test_base, cfg.test.data.test_image) - if hasattr(cfg.test.data, "test_label"): - cfg.test.data.test_label = _combine_path(test_base, cfg.test.data.test_label) - if hasattr(cfg.test.data, "test_mask"): - cfg.test.data.test_mask = _combine_path(test_base, cfg.test.data.test_mask) + if cfg.test is not None: + test_data = cfg.test.data + test_path_value = getattr(test_data, "test_path", "") + test_base = test_path_value if isinstance(test_path_value, str) else "" + test_data.test_image = _combine_path(test_base, test_data.test_image) + test_data.test_label = _combine_path(test_base, test_data.test_label) + test_data.test_mask = _combine_path(test_base, test_data.test_mask) # Resolve tuning data paths (cfg.tune.data.tune_*) - if hasattr(cfg, "tune") and hasattr(cfg.tune, "data"): - tune_base = ( - cfg.tune.data.test_path - if hasattr(cfg.tune.data, "test_path") and cfg.tune.data.test_path - else "" - ) - if hasattr(cfg.tune.data, "tune_image"): - cfg.tune.data.tune_image = _combine_path(tune_base, cfg.tune.data.tune_image) - if hasattr(cfg.tune.data, "tune_label"): - cfg.tune.data.tune_label = _combine_path(tune_base, cfg.tune.data.tune_label) - if hasattr(cfg.tune.data, 
"tune_mask"): - cfg.tune.data.tune_mask = _combine_path(tune_base, cfg.tune.data.tune_mask) + if cfg.tune is not None: + tune_data = cfg.tune.data + tune_path_value = getattr(tune_data, "test_path", "") + tune_base = tune_path_value if isinstance(tune_path_value, str) else "" + tune_data.tune_image = _combine_path(tune_base, tune_data.tune_image) + tune_data.tune_label = _combine_path(tune_base, tune_data.tune_label) + tune_data.tune_mask = _combine_path(tune_base, tune_data.tune_mask) return cfg diff --git a/connectomics/decoding/optuna_tuner.py b/connectomics/decoding/optuna_tuner.py index 0d59c80f..34653506 100644 --- a/connectomics/decoding/optuna_tuner.py +++ b/connectomics/decoding/optuna_tuner.py @@ -14,19 +14,20 @@ """ from __future__ import annotations -from typing import Dict, Any, Optional, Tuple, List, Callable -from pathlib import Path + import warnings from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, Optional -import numpy as np import h5py +import numpy as np from omegaconf import DictConfig, OmegaConf try: import optuna - from optuna.samplers import TPESampler, CmaEsSampler, RandomSampler - from optuna.pruners import MedianPruner, HyperbandPruner + from optuna.pruners import HyperbandPruner, MedianPruner + from optuna.samplers import CmaEsSampler, RandomSampler, TPESampler OPTUNA_AVAILABLE = True except ImportError: @@ -36,14 +37,12 @@ "Parameter tuning will not work without Optuna." 
) -# Import decoding functions -from .segmentation import decode_instance_binary_contour_distance -from .utils import remove_small_instances - # Import metrics from connectomics.metrics.metrics_seg import adapted_rand -from omegaconf import OmegaConf +# Import decoding functions +from .segmentation import decode_instance_binary_contour_distance +from .utils import remove_small_instances __all__ = ["OptunaDecodingTuner", "run_tuning", "load_and_apply_best_params"] @@ -127,9 +126,13 @@ def _validate_data(self): """Validate data shapes and types.""" # Handle 2D data: (C, H, W) → (C, 1, H, W) if self.predictions.ndim == 3: - print(f" 📐 2D data detected, expanding predictions: {self.predictions.shape} → {self.predictions.shape[:1] + (1,) + self.predictions.shape[1:]}") + expanded_shape = self.predictions.shape[:1] + (1,) + self.predictions.shape[1:] + print( + " 📐 2D data detected, expanding predictions: " + f"{self.predictions.shape} → {expanded_shape}" + ) self.predictions = self.predictions[:, np.newaxis, :, :] - + # Predictions should be (C, D, H, W) if self.predictions.ndim != 4: raise ValueError( @@ -138,9 +141,13 @@ def _validate_data(self): # Handle 2D ground truth: (H, W) → (1, H, W) if self.ground_truth.ndim == 2: - print(f" 📐 2D ground truth detected, expanding: {self.ground_truth.shape} → {(1,) + self.ground_truth.shape}") + expanded_shape = (1,) + self.ground_truth.shape + print( + f" 📐 2D ground truth detected, expanding: {self.ground_truth.shape} → " + f"{expanded_shape}" + ) self.ground_truth = self.ground_truth[np.newaxis, :, :] - + # Ground truth should be (D, H, W) if self.ground_truth.ndim != 3: raise ValueError( @@ -158,9 +165,11 @@ def _validate_data(self): # Handle 2D mask if provided if self.mask is not None: if self.mask.ndim == 2: - print(f" 📐 2D mask detected, expanding: {self.mask.shape} → {(1,) + self.mask.shape}") + print( + f" 📐 2D mask detected, expanding: {self.mask.shape} → {(1,) + self.mask.shape}" + ) self.mask = self.mask[np.newaxis, 
:, :] - + if self.mask.shape != self.ground_truth.shape: raise ValueError( f"Mask shape {self.mask.shape} doesn't match " @@ -457,8 +466,8 @@ def _reconstruct_decoding_params(self, sampled_params: Dict[str, Any]) -> Dict[s decoding_params = dict(decoding_defaults) # Start with defaults # Group tuple parameters - tuple_params = defaultdict(dict) - scalar_params = {} + tuple_params: Dict[str, Dict[int, Any]] = defaultdict(dict) + scalar_params: Dict[str, Any] = {} for param_name, value in sampled_params.items(): # Skip post-processing parameters @@ -553,17 +562,19 @@ def _print_results(self, study: optuna.Study): print(f"Number of finished trials: {len(study.trials)}") print(f"\nBest trial: #{study.best_trial.number}") print(f" Value: {study.best_value:.4f}") - print(f"\n Params:") + print("\n Params:") # Reconstruct and print parameters best_decoding_params = self._reconstruct_decoding_params(study.best_params) for key, value in best_decoding_params.items(): print(f" {key}: {value}") - if getattr(self.param_space_cfg, "postprocessing", None) and getattr(self.param_space_cfg.postprocessing, "enabled", False): + if getattr(self.param_space_cfg, "postprocessing", None) and getattr( + self.param_space_cfg.postprocessing, "enabled", False + ): best_postproc_params = self._reconstruct_postproc_params(study.best_params) if best_postproc_params: - print(f"\n Post-processing params:") + print("\n Post-processing params:") for key, value in best_postproc_params.items(): print(f" {key}: {value}") @@ -669,16 +680,17 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None): print(f"Output directory: {output_dir}") # Step 1: Run inference on tune dataset - from connectomics.training.lit import create_datamodule - from connectomics.data.io import read_volume import glob + from connectomics.data.io import read_volume + from connectomics.training.lit import create_datamodule + print("\n[1/4] Running inference on tuning dataset...") # Get tune config sections (used later for 
loading predictions, ground truth, masks) tune_data = getattr(cfg.tune, "data", None) tune_output = getattr(cfg.tune, "output", None) - + if tune_data is None: raise ValueError("Missing tune.data in configuration") if tune_output is None: @@ -744,7 +756,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None): label_files = sorted(glob.glob(tune_label_pattern)) else: raise TypeError(f"tune_label must be string or list, got {type(tune_label_pattern)}") - + if not label_files: raise FileNotFoundError(f"No label files found matching pattern: {tune_label_pattern}") @@ -776,7 +788,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None): mask_files = sorted(glob.glob(tune_mask_pattern)) else: raise TypeError(f"tune_mask must be string or list, got {type(tune_mask_pattern)}") - + if not mask_files: print(f" ⚠️ No mask files found matching pattern: {tune_mask_pattern}") else: @@ -808,7 +820,7 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None): print("TUNING COMPLETED") print(f"{'='*80}") print(f"✓ Best parameters saved to: {best_params_file}") - print(f"\nBest trial:") + print("\nBest trial:") print(f" Value: {study.best_value:.4f}") print(f" Parameters: {study.best_params}") @@ -827,7 +839,7 @@ def load_and_apply_best_params(cfg): cfg: Updated configuration object with best parameters applied Example: - >>> cfg = load_config('tutorials/hydra-lv.yaml') + >>> cfg = load_config('tutorials/misc/hydra-lv.yaml') >>> cfg = load_and_apply_best_params(cfg) >>> # cfg.test now has optimized decoding parameters """ @@ -849,7 +861,7 @@ def load_and_apply_best_params(cfg): # Load best parameters best_params = OmegaConf.load(best_params_file) - print(f"✓ Loaded best parameters:") + print("✓ Loaded best parameters:") print(OmegaConf.to_yaml(best_params)) # Apply to test.decoding config @@ -869,7 +881,9 @@ def load_and_apply_best_params(cfg): # Find decoder with matching function name decoder_idx = None for idx, decoder in enumerate(cfg.test.decoding): - 
decoder_name = decoder.get("name") if isinstance(decoder, dict) else getattr(decoder, "name", None) + decoder_name = ( + decoder.get("name") if isinstance(decoder, dict) else getattr(decoder, "name", None) + ) if decoder_name == decoding_function: decoder_idx = idx break @@ -882,7 +896,7 @@ def load_and_apply_best_params(cfg): # Update parameters if decoder_idx < len(cfg.test.decoding): decoder = cfg.test.decoding[decoder_idx] - + # Handle both dict and config object if isinstance(decoder, dict): if "kwargs" not in decoder: diff --git a/scripts/cellmap/.gitignore b/scripts/cellmap/.gitignore deleted file mode 100644 index 45082041..00000000 --- a/scripts/cellmap/.gitignore +++ /dev/null @@ -1,25 +0,0 @@ -# Python -__pycache__/ -*.py[cod] -*$py.class -*.so - -# Outputs -outputs/ -submission.zarr -submission.zip -datasplit.csv -*.ckpt -*.pth - -# Logs -tensorboard/ -lightning_logs/ -*.log - -# IDE -.vscode/ -.idea/ -*.swp -*.swo -*~ diff --git a/scripts/cellmap/QUICKSTART.md b/scripts/cellmap/QUICKSTART.md deleted file mode 100644 index 991166d8..00000000 --- a/scripts/cellmap/QUICKSTART.md +++ /dev/null @@ -1,309 +0,0 @@ -# CellMap Challenge - 5-Minute Quickstart - -Get started with CellMap challenge in 5 minutes! - ---- - -## 1. Install (30 seconds) - -```bash -# Activate PyTC environment -source /projects/weilab/weidf/lib/miniconda3/bin/activate pytc - -# Install CellMap packages -pip install cellmap-data cellmap-segmentation-challenge -``` - ---- - -## 2. Quick Test (2 minutes) - -Test the full pipeline with fast-dev-run: - -```bash -# Test with 1 batch (very fast, just to verify setup) -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_cos7.yaml \ - --fast-dev-run - -# Monitor training (in another terminal) -tensorboard --logdir outputs/cellmap_cos7/tensorboard -``` - ---- - -## 3. 
Full Training (8-12 hours) - -Train production model on COS7 multi-organelle segmentation: - -```bash -# Train MedNeXt-M for 500 epochs -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_cos7.yaml - -# Expected time on 1x A100: ~12 hours -# Expected time on 4x A100: ~3-4 hours -``` - ---- - -## 4. Phase 2 Option A: Simple Unified (47 Semantic) - -For fast submission - all 47 classes as semantic: - -```bash -# Train MedNeXt-L for all 47 classes as semantic (8nm resolution) -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_semantic47.yaml - -# Expected time on 4x A100: ~30 hours -# Output: Multi-class binary mask (47 channels) -# No SDT, no instance separation -# Simple unified approach - faster but lower accuracy for instance classes -``` - ---- - -## 5. Phase 2 Option B: Combination (36 Semantic + 11 Instance) - -For best accuracy - specialized models with SDT: - -```bash -# Part 1: Train semantic model (36 classes) -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_semantic_full.yaml - -# Expected time on 4x A100: ~24 hours -# Output: Multi-class binary mask (softmax) -# No post-processing needed - -# Part 2: Train instance model (11 classes with SDT) -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_instance_full.yaml - -# Expected time on 4x A100: ~36 hours -# Output: Multi-class binary mask (11 channels, sigmoid) -# Post-processing: Binary → SDT → Watershed → Instance IDs -# Total time: ~60 hours for both models -``` - ---- - -## 6. Inference (1-2 hours) - -```bash -# Predict on test crops -python scripts/cellmap/predict_cellmap.py \ - --checkpoint outputs/cellmap_cos7/checkpoints/last.ckpt \ - --config tutorials/cellmap_cos7.yaml \ - --output predictions/ -``` - ---- - -## 7. 
Submit (10 minutes) - -```bash -# Package predictions -python scripts/cellmap/submit_cellmap.py \ - --predictions predictions/ \ - --output submission.zarr - -# Upload submission.zarr to: -# https://cellmapchallenge.janelia.org/submissions/ -``` - ---- - -## File Tree - -``` -scripts/cellmap/ -├── README.md # Full documentation -├── QUICKSTART.md # This file -│ -├── train_cellmap.py # Training script (258 lines, Hydra-based) -├── predict_cellmap.py # Inference script -├── submit_cellmap.py # Submission script -│ -tutorials/ -├── cellmap_cos7.yaml # Multi-organelle config (Hydra YAML) -└── cellmap_mito.yaml # Mitochondria config (Hydra YAML) -``` - ---- - -## Development Plan - -### Phase 1: Make Sure Simple One Works ✅ -| Config | Classes | Type | Time (4x A100) | -|--------|---------|------|---------------| -| `cellmap_mito.yaml` | 1 (mito) | Instance + SDT | ~6 hours | -| `cellmap_cos7.yaml` | 5 (nuc, mito, er, golgi, ves) | Semantic | ~3 hours | - -### Phase 2: Full Submission 🚀 - -**Option A: Simple Unified (47 Semantic)** -| Config | Classes | Type | Time (4x A100) | -|--------|---------|------|---------------| -| `cellmap_semantic47.yaml` | **All 47 as semantic** | Multi-class binary mask | ~30 hours | - -**Option B: Combination Approach (Best Accuracy)** -| Config | Classes | Type | Time (4x A100) | -|--------|---------|------|---------------| -| `cellmap_semantic_full.yaml` | **36 semantic** | Multi-class binary mask | ~24 hours | -| `cellmap_instance_full.yaml` | **11 instance + SDT** | Binary mask + SDT | ~36 hours | -| **Total** | **47 classes** | **Combined** | **~60 hours** | - -### Key Differences - -**Phase 1 Configs (Validation):** -- **`cellmap_mito.yaml`**: 1 instance class (mito) - - Binary mask output + SDT post-processing - - Validates instance pipeline: Binary → SDT → Watershed → Instance IDs - -- **`cellmap_cos7.yaml`**: 5 semantic classes - - Multi-class binary mask (softmax activation) - - Validates semantic pipeline: Direct output, no 
post-processing - -**Phase 2 Option A (Simple Unified):** -- **`cellmap_semantic47.yaml`**: All 47 classes as semantic - - Treats everything as multi-class binary mask - - No SDT, no instance separation - - Simpler pipeline, faster training (30h vs 60h) - - Lower accuracy for instance classes - -**Phase 2 Option B (Combination - Best Accuracy):** -- **`cellmap_semantic_full.yaml`**: 36 semantic classes - - Multi-class binary mask (softmax) - - Direct output, no post-processing - - 8nm resolution, 1000 epochs, ~24h - -- **`cellmap_instance_full.yaml`**: 11 instance classes - - Multi-class binary mask (11 channels, sigmoid) - - SDT post-processing: Binary → SDT → Watershed → Instance IDs - - 4nm resolution, 1500 epochs, ~36h - - Critical: `mito` appears in 14/16 test crops - -**Recommendation:** -- **Start with Phase 1** to validate pipelines -- **Phase 2 Option A** for quick submission (30h) -- **Phase 2 Option B** for best leaderboard score (60h) - ---- - -## Config Override Examples - -```bash -# Quick test with smaller model -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_cos7.yaml \ - model.architecture=monai_basic_unet3d \ - model.input_size="[64, 64, 64]" \ - optimization.max_epochs=100 - -# Multi-GPU training -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_cos7.yaml \ - system.training.num_gpus=4 - -# Lower batch size for GPU memory -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_cos7.yaml \ - system.training.batch_size=1 -``` - ---- - -## What You Get - -✅ **Zero PyTC modifications** - Completely isolated -✅ **Official CellMap tools** - Guaranteed compatibility -✅ **PyTC model zoo** - 8+ MONAI architectures -✅ **Hydra configs** - Standard PyTC config format -✅ **Production ready** - Lightning + callbacks + logging -✅ **Easy to use** - Just run 3 commands - ---- - -## Architecture Comparison - -The new design is **much simpler**: - -**Before (Python configs)**: -- 273 lines, custom 
LightningModule -- Custom loss wrappers, optimizer setup -- Python-based configuration files - -**After (Hydra configs)**: -- 258 lines, reuses PyTC's `ConnectomicsModule` -- Reuses all `scripts/main.py` infrastructure -- Standard Hydra YAML configs -- Only custom: `CellMapDataModule` (60 lines) - -**Code reuse**: -```python -from connectomics.training.lit import ( - ConnectomicsModule, # Model wrapper - create_trainer, # Trainer setup - setup_config, # Config loading - # ... everything from main.py -) -``` - ---- - -## Troubleshooting - -### Import Error - -```bash -# Error: No module named 'cellmap_data' -pip install cellmap-data cellmap-segmentation-challenge -``` - -### CUDA Out of Memory - -```bash -# Reduce batch size or patch size -python scripts/cellmap/train_cellmap.py \ - --config tutorials/cellmap_cos7.yaml \ - system.training.batch_size=1 \ - model.input_size="[96, 96, 96]" -``` - -### Data Not Found - -```bash -# Check data location -ls /projects/weilab/dataset/cellmap/ - -# Should see: jrc_cos7-1a, jrc_hela-2, etc. -``` - ---- - -## Next Steps - -1. ✅ Run quick test (`--fast-dev-run`) -2. ✅ Run full training (`cellmap_cos7.yaml`) -3. ✅ Optional: Train mito-specific model (`cellmap_mito.yaml`) -4. ✅ Predict on test set -5. ✅ Submit to challenge -6. 📊 Check leaderboard! - ---- - -## Help - -- **Full documentation**: [README.md](README.md) -- **Instance segmentation guide**: [.claude/CELLMAP_SUBMISSION.md](../../.claude/CELLMAP_SUBMISSION.md) -- **CellMap challenge**: https://www.cellmapchallenge.janelia.org/ -- **PyTC docs**: [../../CLAUDE.md](../../CLAUDE.md) - ---- - -**Time to first results**: 2 minutes (fast-dev-run) -**Time to submission**: ~12 hours (full training + inference) - -Let's go! 
🚀 diff --git a/scripts/cellmap/README.md b/scripts/cellmap/README.md deleted file mode 100644 index 54447698..00000000 --- a/scripts/cellmap/README.md +++ /dev/null @@ -1,344 +0,0 @@ -# CellMap Segmentation Challenge - PyTC Integration - -This directory provides a lightweight integration between the [CellMap Segmentation Challenge](https://www.cellmapchallenge.janelia.org/) and PyTorch Connectomics (PyTC). - -## 🎯 Key Features - -- **Zero PyTC modifications** - All code isolated in `scripts/cellmap/` -- **Official CellMap tools** - Uses `cellmap-data` package for data loading -- **Hydra YAML configs** - Standard PyTC config format (no Python configs needed) -- **Full PyTC features** - Lightning callbacks, checkpointing, logging, TTA, etc. -- **419 test predictions** - Complete coverage for challenge submission - -## 📦 Installation - -```bash -# 1. Activate PyTC environment -source /projects/weilab/weidf/lib/miniconda3/bin/activate pytc - -# 2. Install CellMap packages (official challenge tools) -pip install cellmap-data cellmap-segmentation-challenge - -# 3. 
Verify installation -python -c "from cellmap_segmentation_challenge.utils import get_dataloader; print('✅ CellMap installed')" -``` - -## 🚀 Quick Start (5 minutes) - -```bash -# Train on COS7 multi-organelle (5 classes) -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml --fast-dev-run - -# Full training (8-12 hours on 1 GPU) -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml - -# Mitochondria-specific (optimized for instance segmentation) -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_mito.yaml -``` - -## 📋 Configuration Files - -### Hydra YAML Configs (Recommended) - -All configs use PyTC's standard Hydra format: - -**[tutorials/cellmap_cos7.yaml](../../tutorials/cellmap_cos7.yaml)** - Multi-organelle segmentation -- **Classes**: nuc, mito, er, golgi, ves (5 organelles) -- **Model**: MedNeXt-M (17.6M params) -- **Resolution**: 8nm isotropic -- **Training**: 500 epochs, ~12 hours on 1 GPU -- **Use case**: General-purpose baseline - -**[tutorials/cellmap_mito.yaml](../../tutorials/cellmap_mito.yaml)** - Mitochondria-specific -- **Classes**: mito (single class) -- **Model**: MedNeXt-L (61.8M params) - best for single class -- **Resolution**: 4nm isotropic - higher res for boundaries -- **Training**: 1000 epochs, ~24 hours on 1 GPU -- **Use case**: Best instance segmentation quality - -### Key Config Sections - -```yaml -# CellMap-specific data configuration -data: - dataset_type: cellmap # Special marker for CellMap - - cellmap: - data_root: /projects/weilab/dataset/cellmap - datasplit_path: outputs/cellmap_cos7/datasplit.csv # Auto-generated - classes: [nuc, mito, er, golgi, ves] - force_all_classes: both - - # Patch configuration - input_array_info: - shape: [128, 128, 128] - scale: [8, 8, 8] # 8nm isotropic - - # CellMap-style augmentation - spatial_transforms: - mirror: {axes: {x: 0.5, y: 0.5, z: 0.5}} - transpose: {axes: [x, y, z]} - rotate: {axes: {x: [-180, 180], y: [-180, 180], z: 
[-180, 180]}} -``` - -## 📊 Available Datasets - -CellMap challenge provides 23 datasets with 60+ organelle classes: - -```bash -# List all available datasets -ls /projects/weilab/dataset/cellmap/ - -# Example datasets: -# - jrc_cos7-1a, jrc_cos7-1b : COS7 cells -# - jrc_hela-2, jrc_hela-3 : HeLa cells -# - jrc_jurkat-1 : Jurkat cells -# - jrc_macrophage-2 : Macrophages -# - jrc_mus-liver, jrc_mus-kidney: Mouse organs -``` - -## 🎓 Training Workflow - -### 1. Data Preparation (Automatic) - -The datasplit is automatically generated on first run: - -```bash -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml -# Generates: outputs/cellmap_cos7/datasplit.csv -``` - -The datasplit includes: -- Train/validation split -- Crop coordinates -- Class availability per crop -- Uses CellMap's official `make_datasplit_csv()` - -### 2. Training - -```bash -# Standard training -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml - -# Override config from CLI -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml \ - system.training.num_gpus=4 \ - optimization.max_epochs=1000 - -# Multi-GPU training (automatic DDP) -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml \ - system.training.num_gpus=4 -``` - -### 3. Inference (Challenge Submission) - -For challenge submission, use the official CellMap prediction scripts: - -```bash -# 1. Run inference on all test crops (uses sliding window + TTA) -python scripts/cellmap/predict_cellmap.py \ - --checkpoint outputs/cellmap_cos7/checkpoints/last.ckpt \ - --config tutorials/cellmap_cos7.yaml \ - --output predictions/ - -# 2. Package predictions for submission -python scripts/cellmap/submit_cellmap.py \ - --predictions predictions/ \ - --output submission.zarr - -# 3. 
Upload submission.zarr to challenge platform -``` - -## 📁 File Structure - -``` -scripts/cellmap/ -├── train_cellmap.py # Training script (258 lines) -├── predict_cellmap.py # Inference script (for challenge submission) -├── submit_cellmap.py # Submission packaging (uses official tool) -├── README.md # This file -└── QUICKSTART.md # 5-minute quickstart guide - -tutorials/ -├── cellmap_cos7.yaml # Multi-organelle config (Hydra format) -└── cellmap_mito.yaml # Mitochondria-specific config (Hydra format) -``` - -## 🔧 How It Works - -### Simplified Architecture - -The new design is **much simpler** than the original: - -**New approach** (258 lines, Hydra configs): -- Reuses `ConnectomicsModule` from PyTC -- Reuses `create_trainer()` from PyTC -- Reuses all callbacks, logging, checkpointing -- Only custom component: `CellMapDataModule` (60 lines) - -### Code Reuse from scripts/main.py - -`train_cellmap.py` reuses almost everything from `scripts/main.py`: - -```python -from connectomics.training.lit import ( - ConnectomicsModule, # Model wrapper - create_trainer, # Trainer setup - setup_config, # Config loading - setup_run_directory, # Directory management - modify_checkpoint_state, # Checkpoint handling - # ... 
and more -) -``` - -**Only custom component**: `CellMapDataModule` (60 lines) -- Wraps CellMap's `get_dataloader()` in Lightning interface -- Auto-generates datasplit if missing -- Handles train/val/test splits - -## 🎯 Challenge Details - -### Task Statistics - -- **Total predictions**: 419 predictions -- **Test crops**: 16 crops across 6 datasets -- **Classes**: 47 classes (11 instance + 36 semantic) - -### Instance vs Semantic Segmentation - -**Instance Segmentation** (11 classes - harder): -- nuc, mito, ves, endo, lyso, ld, perox, np, mt, cell, vim -- Requires unique IDs per object -- Evaluated with Adapted Rand Error, VOI -- **Server auto-runs connected components** on submission - -**Semantic Segmentation** (36 classes - easier): -- All membrane/lumen subclasses, cytoplasm, etc. -- Binary masks (0/1) -- Evaluated with Dice, IoU - -See [.claude/CELLMAP_SUBMISSION.md](../../.claude/CELLMAP_SUBMISSION.md) for detailed guide. - -## 💡 Tips & Best Practices - -### Model Selection - -**Multi-class segmentation** (5+ classes): -- Use MedNeXt-M or MedNeXt-B (good balance) -- 8nm resolution is sufficient -- 500 epochs usually enough - -**Single-class segmentation** (e.g., mitochondria): -- Use MedNeXt-L (61.8M params) for best quality -- Higher resolution (4nm) for better boundaries -- Extended training (1000 epochs) -- Critical for instance segmentation - -### Training Strategy - -1. **Quick baseline** (1-2 hours): - ```bash - # Test with MONAI BasicUNet on small patch size - python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml \ - model.architecture=monai_basic_unet3d \ - model.input_size="[64, 64, 64]" \ - optimization.max_epochs=100 - ``` - -2. **Production training** (8-12 hours): - ```bash - # Full MedNeXt training - python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml - ``` - -3. 
**Mitochondria optimization** (24 hours): - ```bash - # Mito-specific config for best instance segmentation - python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_mito.yaml - ``` - -### Instance Segmentation Quality - -For best instance segmentation results: - -1. **Binary masks are sufficient** for initial submission - - Server automatically runs connected components - - Focus on clean boundaries - -2. **Optional: Watershed post-processing** for better quality - - Implement in `predict_cellmap.py` - - See `.claude/CELLMAP_SUBMISSION.md` for code examples - -3. **Mitochondria is the hardest task** - - Use dedicated config (`tutorials/cellmap_mito.yaml`) - - Higher resolution (4nm) - - Larger model (MedNeXt-L) - - Extended training (1000 epochs) - -## 📚 Additional Resources - -**Documentation:** -- [QUICKSTART.md](QUICKSTART.md) - 5-minute quickstart -- [.claude/CELLMAP_SUBMISSION.md](../../.claude/CELLMAP_SUBMISSION.md) - Instance segmentation guide -- [.claude/CELLMAP_CHALLENGE_SUMMARY.md](../../.claude/CELLMAP_CHALLENGE_SUMMARY.md) - Challenge overview -- [.claude/CELLMAP_INTEGRATION_DESIGN_V2.md](../../.claude/CELLMAP_INTEGRATION_DESIGN_V2.md) - Design decisions - -**Challenge Links:** -- [Challenge Website](https://www.cellmapchallenge.janelia.org/) -- [CellMap Data Package](https://github.com/janelia-cellmap/cellmap-data) -- [Challenge Utils](https://github.com/janelia-cellmap/cellmap-segmentation-challenge) - -**PyTC Documentation:** -- [Main README](../../README.md) -- [CLAUDE.md](../../CLAUDE.md) - PyTC architecture guide -- [PyTC Models](../../connectomics/models/arch/) - Available architectures - -## 🐛 Troubleshooting - -### Import Error: cellmap-data not found - -```bash -pip install cellmap-data cellmap-segmentation-challenge -``` - -### Datasplit generation fails - -```bash -# Check data exists -ls /projects/weilab/dataset/cellmap/jrc_cos7-1a/ - -# Manually generate datasplit -python -c " -from cellmap_segmentation_challenge.utils import 
make_datasplit_csv -make_datasplit_csv( - csv_path='outputs/cellmap_cos7/datasplit.csv', - raw_path='/projects/weilab/dataset/cellmap', - classes=['nuc', 'mito', 'er', 'golgi', 'ves'], -) -" -``` - -### GPU out of memory - -Reduce batch size or patch size: -```bash -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml \ - system.training.batch_size=1 \ - model.input_size="[96, 96, 96]" -``` - -### Training too slow - -Increase workers or use multiple GPUs: -```bash -python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml \ - system.training.num_workers=8 \ - system.training.num_gpus=4 -``` - -## 🙋 Getting Help - -1. Check [TROUBLESHOOTING.md](../../TROUBLESHOOTING.md) -2. Review [.claude/CELLMAP_*.md](../../.claude/) documentation -3. Open an issue on [PyTC GitHub](https://github.com/zudi-lin/pytorch_connectomics/issues) -4. Join [PyTC Slack](https://join.slack.com/t/pytorchconnectomics/shared_invite/zt-obufj5d1-v5_NndNS5yog8vhxy4L12w) diff --git a/scripts/cellmap/predict_cellmap.py b/scripts/cellmap/predict_cellmap.py deleted file mode 100755 index 445b48f4..00000000 --- a/scripts/cellmap/predict_cellmap.py +++ /dev/null @@ -1,315 +0,0 @@ -#!/usr/bin/env python -""" -Inference on CellMap test crops using trained PyTC model. 
- -Uses: -- CellMap's TEST_CROPS for official metadata -- MONAI's SlidingWindowInferer for efficient inference -- PyTC's trained models - -Usage: - python scripts/cellmap/predict_cellmap.py \ - --checkpoint outputs/cellmap_cos7/checkpoints/last.ckpt \ - --config tutorials/cellmap_cos7.yaml \ - --output predictions/ - - # Predict specific crops only - python scripts/cellmap/predict_cellmap.py \ - --checkpoint outputs/cellmap_cos7/checkpoints/last.ckpt \ - --config tutorials/cellmap_cos7.yaml \ - --crops 234,236,237 - -Requirements: - pip install cellmap-data cellmap-segmentation-challenge -""" - -import os -import sys -from pathlib import Path - -PYTC_ROOT = Path(__file__).parent.parent.parent -sys.path.insert(0, str(PYTC_ROOT)) - -import torch -import zarr -import numpy as np -from tqdm import tqdm -from monai.inferers import SlidingWindowInferer -import torch.nn.functional as F - -# CellMap utilities -from cellmap_segmentation_challenge.utils import TEST_CROPS, load_safe_config -from cellmap_segmentation_challenge.evaluate import match_crop_space - -# PyTC models -from connectomics.models import build_model -from omegaconf import OmegaConf - - -def select_scale_level(zarr_path, target_resolution): - """Return the scale path plus voxel size/translation metadata closest to target resolution.""" - store = zarr.open(zarr_path, mode='r') - - multiscale_meta = store.attrs.get('multiscales', [{}])[0] - datasets_meta = multiscale_meta.get('datasets', []) - - # Default fallback if metadata is missing - if not datasets_meta: - return { - "path": "s2", - "voxel_size": np.array(target_resolution, dtype=float), - "translation": np.zeros(3, dtype=float), - } - - best = datasets_meta[0] - min_diff = float('inf') - - for ds_meta in datasets_meta: - transforms = ds_meta.get('coordinateTransformations', []) - scale = next( - (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'), - np.ones(3, dtype=float), - ) - avg_resolution = 
np.mean(scale) - diff = abs(avg_resolution - np.mean(target_resolution)) - if diff < min_diff: - min_diff = diff - best = ds_meta - - transforms = best.get('coordinateTransformations', []) - voxel_size = next( - (np.array(t.get('scale', [1, 1, 1]), dtype=float) for t in transforms if t.get('type') == 'scale'), - np.array(target_resolution, dtype=float), - ) - translation = next( - (np.array(t.get('translation', [0, 0, 0]), dtype=float) for t in transforms if t.get('type') == 'translation'), - np.zeros(3, dtype=float), - ) - - return { - "path": best.get('path', 's0'), - "voxel_size": voxel_size, - "translation": translation, - } - - -def predict_cellmap(checkpoint_path, config_path, output_dir, crop_filter=None): - """ - Run inference on all test crops. - - Args: - checkpoint_path: Path to trained model checkpoint - config_path: Path to training config file - output_dir: Directory to save predictions - crop_filter: List of crop IDs to predict (None = all crops) - """ - - # Load config - print(f"Loading config from: {config_path}") - config = load_safe_config(config_path) - classes = getattr(config, 'classes', ['nuc', 'mito', 'er']) - model_name = getattr(config, 'model_name', 'mednext') - target_resolution = getattr(config, 'input_array_info', {}).get('scale', (8, 8, 8)) - - print(f"Prediction configuration:") - print(f" Model: {model_name}") - print(f" Classes: {classes}") - print(f" Target resolution: {target_resolution} nm") - - # Build model - print(f"Building model: {model_name}") - model_config = OmegaConf.create({ - 'model': { - 'architecture': model_name, - 'in_channels': 1, - 'out_channels': len(classes), - 'mednext_size': getattr(config, 'mednext_size', 'B'), - 'mednext_kernel_size': getattr(config, 'mednext_kernel_size', 5), - 'deep_supervision': getattr(config, 'deep_supervision', True), - } - }) - model = build_model(model_config) - - # Load checkpoint - print(f"Loading checkpoint: {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, 
map_location='cpu', weights_only=False) - - # Handle Lightning checkpoint format - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - # Remove 'model.' prefix from keys if present - state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()} - else: - state_dict = checkpoint - - model.load_state_dict(state_dict, strict=False) - model.eval() - - # Move to GPU if available - device = 'cuda' if torch.cuda.is_available() else 'cpu' - model = model.to(device) - print(f"Using device: {device}") - - base_roi = (128, 128, 128) - inferer_cache: dict[tuple[int, int, int], SlidingWindowInferer] = {} - - def get_inferer(roi_size: tuple[int, int, int]) -> SlidingWindowInferer: - if roi_size not in inferer_cache: - inferer_cache[roi_size] = SlidingWindowInferer( - roi_size=roi_size, - sw_batch_size=4, - overlap=0.5, - mode='gaussian', - device=torch.device(device), - ) - return inferer_cache[roi_size] - - # Filter test crops if specified - if crop_filter: - crop_ids = [int(c) for c in crop_filter] - test_crops = [crop for crop in TEST_CROPS if crop.id in crop_ids] - print(f"Predicting on {len(test_crops)} crops: {crop_ids}") - else: - test_crops = TEST_CROPS - print(f"Predicting on all {len(test_crops)} test crops") - - # Group crops by dataset for efficiency - crops_by_dataset = {} - for crop in test_crops: - if crop.dataset not in crops_by_dataset: - crops_by_dataset[crop.dataset] = [] - crops_by_dataset[crop.dataset].append(crop) - - # Predict on all test crops - os.makedirs(output_dir, exist_ok=True) - - for dataset, dataset_crops in crops_by_dataset.items(): - print(f"\nProcessing dataset: {dataset}") - - zarr_path = f"/projects/weilab/dataset/cellmap/{dataset}/{dataset}.zarr" - - # Find appropriate scale level for target resolution - em_path = f"{zarr_path}/recon-1/em/fibsem-uint8" - scale_info = select_scale_level(em_path, target_resolution) - scale_level = scale_info['path'] - scale_voxel_size = scale_info['voxel_size'] - 
scale_translation = scale_info['translation'] - print(f" Using scale level: {scale_level} (voxel size: {scale_voxel_size} nm)") - - # Load EM data once for all crops in this dataset - try: - raw_array = zarr.open(f"{em_path}/{scale_level}", mode='r') - except Exception as e: - print(f" Error loading EM data: {e}") - print(f" Skipping dataset {dataset}") - continue - raw_shape = np.array(raw_array.shape, dtype=int) - - for crop in tqdm(dataset_crops, desc=f" Crops in {dataset}"): - crop_id = crop.id - class_label = crop.class_label - - # Skip if this class is not in our training classes - if class_label not in classes: - continue - - # Extract crop region using precise metadata - crop_output_dir = f"{output_dir}/{dataset}/crop{crop_id}" - os.makedirs(crop_output_dir, exist_ok=True) - - try: - target_shape = np.array(crop.shape, dtype=int) - target_voxel = np.array(crop.voxel_size, dtype=float) - translation_nm = np.array(crop.translation, dtype=float) - - physical_extent = target_shape * target_voxel - start_idx = np.floor((translation_nm - scale_translation) / scale_voxel_size).astype(int) - end_idx = np.ceil((translation_nm + physical_extent - scale_translation) / scale_voxel_size).astype(int) - - end_idx = np.maximum(end_idx, start_idx + 1) - start_idx = np.clip(start_idx, 0, np.maximum(raw_shape - 1, 0)) - end_idx = np.clip(end_idx, start_idx + 1, raw_shape) - - slices = tuple(slice(int(s), int(e)) for s, e in zip(start_idx, end_idx)) - raw_volume = raw_array[slices] - - # Normalize and convert to tensor - raw_volume = np.array(raw_volume).astype(np.float32) / 255.0 - raw_tensor = torch.from_numpy(raw_volume[None, None, ...]).to(device) # (1, 1, D, H, W) - - roi_size = tuple( - int(max(1, min(base_dim, vol_dim))) - for base_dim, vol_dim in zip(base_roi, raw_volume.shape) - ) - inferer = get_inferer(roi_size) - - # Run inference - with torch.no_grad(): - predictions = inferer(raw_tensor, model) - predictions = torch.sigmoid(predictions).cpu().numpy()[0] # (C, D, 
H, W) - - # Resize predictions back to the official crop shape if needed - target_shape_tuple = tuple(int(x) for x in target_shape) - if predictions.shape[1:] != target_shape_tuple: - pred_tensor = torch.from_numpy(predictions).unsqueeze(0) - predictions = ( - F.interpolate( - pred_tensor, - size=target_shape_tuple, - mode="trilinear", - align_corners=False, - ) - .squeeze(0) - .cpu() - .numpy() - ) - - # Save predictions for each class - for i, cls in enumerate(classes): - pred_array = (predictions[i] > 0.5).astype(np.uint8) - - # Save as Zarr (CellMap format) - zarr_out_path = f"{crop_output_dir}/{cls}" - os.makedirs(zarr_out_path, exist_ok=True) - - zarr_out = zarr.open( - f"{zarr_out_path}/s0", - mode='w', - shape=pred_array.shape, - dtype='uint8', - chunks=(64, 64, 64), - compressor=zarr.Blosc(cname='zstd', clevel=5), - ) - zarr_out[:] = pred_array - - # Add metadata - zarr_out.attrs['voxel_size'] = crop.voxel_size - zarr_out.attrs['translation'] = crop.translation - zarr_out.attrs['shape'] = crop.shape - - except Exception as e: - print(f" Error processing crop {crop_id}: {e}") - continue - - print(f"\nPredictions saved to: {output_dir}") - print(f"Next step: python scripts/cellmap/submit_cellmap.py --predictions {output_dir}") - - -if __name__ == '__main__': - import argparse - - parser = argparse.ArgumentParser(description='Run inference on CellMap test crops') - parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint') - parser.add_argument('--config', required=True, help='Path to config file') - parser.add_argument('--output', default='outputs/cellmap/predictions', help='Output directory') - parser.add_argument('--crops', type=str, help='Comma-separated crop IDs to predict (default: all)') - args = parser.parse_args() - - crop_filter = args.crops.split(',') if args.crops else None - - predict_cellmap( - checkpoint_path=args.checkpoint, - config_path=args.config, - output_dir=args.output, - crop_filter=crop_filter, - ) diff --git 
a/scripts/cellmap/submit_cellmap.py b/scripts/cellmap/submit_cellmap.py deleted file mode 100755 index d489f7b2..00000000 --- a/scripts/cellmap/submit_cellmap.py +++ /dev/null @@ -1,165 +0,0 @@ -#!/usr/bin/env python -""" -Package predictions for CellMap challenge submission. - -Uses CellMap's official packaging utility - guaranteed to work! - -This script: -1. Resamples predictions to match test crop resolutions -2. Validates prediction format -3. Packages into submission.zarr -4. Creates submission.zip for upload - -Usage: - python scripts/cellmap/submit_cellmap.py \ - --predictions outputs/cellmap/predictions \ - --output submission.zarr - - # Then upload submission.zip to challenge portal - # https://cellmapchallenge.janelia.org/submissions/ - -Requirements: - pip install cellmap-data cellmap-segmentation-challenge -""" - -import os -import sys -from pathlib import Path - -PYTC_ROOT = Path(__file__).parent.parent.parent -sys.path.insert(0, str(PYTC_ROOT)) - -from cellmap_segmentation_challenge.utils import package_submission - - -def submit_cellmap(predictions_dir, output_path, overwrite=True, max_workers=None): - """ - Package predictions for CellMap challenge submission. 
- - This uses CellMap's official packaging utility which: - - Resamples predictions to match test crop resolution/shape - - Validates format and metadata - - Creates Zarr archive - - Zips for upload - - Args: - predictions_dir: Directory containing predictions (from predict_cellmap.py) - output_path: Output path for submission.zarr - overwrite: Whether to overwrite existing submission - max_workers: Number of parallel workers (default: CPU count) - """ - - if max_workers is None: - max_workers = os.cpu_count() - - print("CellMap Challenge Submission Packager") - print("=" * 60) - print(f"Input predictions: {predictions_dir}") - print(f"Output: {output_path}") - print(f"Workers: {max_workers}") - print() - - # Check predictions directory exists - if not os.path.exists(predictions_dir): - print(f"Error: Predictions directory not found: {predictions_dir}") - print("Run predict_cellmap.py first to generate predictions.") - sys.exit(1) - - # Use official packaging (handles resampling, validation, zipping) - print("Packaging submission...") - print("This will:") - print(" 1. Resample predictions to match test crop resolutions") - print(" 2. Validate format and metadata") - print(" 3. Create submission.zarr") - print(" 4. Create submission.zip") - print() - - try: - package_submission( - input_search_path=predictions_dir, - output_path=output_path, - overwrite=overwrite, - max_workers=max_workers, - ) - except Exception as e: - print(f"\nError during packaging: {e}") - print("\nPlease check:") - print(" 1. Predictions directory structure is correct") - print(" 2. All required test crops have predictions") - print(" 3. 
Zarr arrays have correct metadata") - sys.exit(1) - - # Check output - zip_path = output_path.replace('.zarr', '.zip') - - if os.path.exists(zip_path): - print() - print("=" * 60) - print("Submission packaged successfully!") - print() - print(f"Submission file: {zip_path}") - print(f"File size: {os.path.getsize(zip_path) / 1e9:.2f} GB") - print() - print("Next steps:") - print(" 1. Verify submission.zip is complete") - print(" 2. Upload to: https://cellmapchallenge.janelia.org/submissions/") - print(" 3. Check evaluation results on leaderboard") - print() - else: - print("\nWarning: Submission zip file not found!") - print("Packaging may have failed.") - sys.exit(1) - - -if __name__ == '__main__': - import argparse - - parser = argparse.ArgumentParser( - description='Package predictions for CellMap challenge submission', - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Package predictions - python scripts/cellmap/submit_cellmap.py --predictions outputs/predictions - - # Custom output path - python scripts/cellmap/submit_cellmap.py \\ - --predictions outputs/predictions \\ - --output my_submission.zarr - - # Use more workers for faster packaging - python scripts/cellmap/submit_cellmap.py \\ - --predictions outputs/predictions \\ - --workers 32 - """ - ) - parser.add_argument( - '--predictions', - default='outputs/cellmap/predictions', - help='Directory containing predictions (default: outputs/cellmap/predictions)' - ) - parser.add_argument( - '--output', - default='submission.zarr', - help='Output path for submission.zarr (default: submission.zarr)' - ) - parser.add_argument( - '--no-overwrite', - action='store_true', - help='Do not overwrite existing submission' - ) - parser.add_argument( - '--workers', - type=int, - default=None, - help='Number of parallel workers (default: CPU count)' - ) - - args = parser.parse_args() - - submit_cellmap( - predictions_dir=args.predictions, - output_path=args.output, - overwrite=not 
args.no_overwrite, - max_workers=args.workers, - ) diff --git a/scripts/cellmap/train_cellmap.py b/scripts/cellmap/train_cellmap.py deleted file mode 100755 index 9dbb27d1..00000000 --- a/scripts/cellmap/train_cellmap.py +++ /dev/null @@ -1,326 +0,0 @@ -#!/usr/bin/env python3 -""" -CellMap training script using PyTC's Lightning framework. - -This script provides a thin wrapper that: -1. Creates CellMap dataloaders using cellmap-data package -2. Wraps them in PyTC's Lightning DataModule interface -3. Reuses all PyTC training infrastructure (model building, checkpointing, logging) - -Usage: - python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml - python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_mito.yaml - -Requirements: - pip install cellmap-data cellmap-segmentation-challenge -""" - -import sys -from pathlib import Path - -# Add parent directory to path for imports -PYTC_ROOT = Path(__file__).parent.parent.parent -sys.path.insert(0, str(PYTC_ROOT)) - -import torch -import pytorch_lightning as pl - -from connectomics.config import Config -from connectomics.training.lit import ( - ConnectomicsModule, - cleanup_run_directory, - create_trainer, - modify_checkpoint_state, - parse_args, - setup_config, - setup_run_directory, - setup_seed_everything, -) - -# CellMap data loading (official) -try: - from cellmap_segmentation_challenge.utils import get_dataloader, make_datasplit_csv -except ImportError: - print("❌ Error: cellmap-data not installed") - print(" Please run: pip install cellmap-data cellmap-segmentation-challenge") - sys.exit(1) - -# Setup seed_everything with version fallback -seed_everything = setup_seed_everything() - - -class CellMapDataModule(pl.LightningDataModule): - """ - Lightning DataModule wrapper for CellMap dataloaders. - - This class bridges CellMap's get_dataloader() with PyTC's Lightning framework. 
- """ - - class _KeyMappingLoader: - """Adapter to rename CellMap batch keys to PyTC conventions.""" - - def __init__(self, loader): - self.loader = loader - - def __iter__(self): - for batch in self.loader: - yield self._map_batch(batch) - - def __len__(self): - return len(self.loader) - - @property - def dataset(self): - return getattr(self.loader, "dataset", None) - - @property - def batch_size(self): - return getattr(self.loader, "batch_size", None) - - def _map_batch(self, batch): - mapped = {} - if "input" in batch: - mapped["image"] = batch["input"] - if "output" in batch: - label = batch["output"] - # Replace any NaNs/infs coming from upstream data transforms - if torch.isnan(label).any() or torch.isinf(label).any(): - label = torch.nan_to_num(label, nan=0.0, posinf=0.0, neginf=0.0) - mapped["label"] = label.clamp_(0.0, 1.0) - for key, value in batch.items(): - if key in {"input", "output"}: - continue - if key == "__metadata__": - mapped["metadata"] = value - else: - mapped[key] = value - return mapped - - def __init__( - self, - cfg: Config, - mode: str = "train", - ): - super().__init__() - self.cfg = cfg - self.mode = mode - self.train_loader = None - self.val_loader = None - self.test_loader = None - - def prepare_data(self): - """Prepare data (download, generate datasplit, etc.)""" - cellmap_cfg = self.cfg.data.cellmap - - # Ensure datasplit exists - datasplit_path = Path(cellmap_cfg["datasplit_path"]) - if not datasplit_path.exists(): - print(f"🔧 Generating datasplit: {datasplit_path}") - datasplit_path.parent.mkdir(parents=True, exist_ok=True) - - # Extract scale from input_array_info - scale = cellmap_cfg["input_array_info"]["scale"] - - # Build search path from data_root - data_root = cellmap_cfg["data_root"] - search_path = f"{data_root}/{{dataset}}/{{dataset}}.zarr/recon-1/{{name}}" - - # Use CellMap's official datasplit generator - make_datasplit_csv( - csv_path=str(datasplit_path), - classes=cellmap_cfg["classes"], - scale=scale, - 
force_all_classes=cellmap_cfg["force_all_classes"], - search_path=search_path, - ) - print(f"✅ Datasplit generated: {datasplit_path}") - - @staticmethod - def _unwrap_loader(loader): - """Return the underlying PyTorch DataLoader if wrapped by CellMapDataLoader.""" - if loader is None: - return None - return getattr(loader, "loader", loader) - - def setup(self, stage: str = None): - """Setup train/val/test dataloaders""" - cellmap_cfg = self.cfg.data.cellmap - - # Get system config based on mode - if stage == "fit" or stage is None: - system_cfg = self.cfg.system.training - else: - system_cfg = self.cfg.system.inference - - # Use CUDA only when there are no multiprocessing workers (safe on main process); - # otherwise force CPU to avoid CUDA init in forked workers. - dataloader_device = ( - "cuda" - if torch.cuda.is_available() - and system_cfg.num_gpus > 0 - and system_cfg.num_workers == 0 - else "cpu" - ) - - # Get absolute path to datasplit - from pathlib import Path as PathLib - csv_path_abs = str(PathLib(cellmap_cfg["datasplit_path"]).absolute()) - print(f"📂 Loading datasplit from: {csv_path_abs}") - - # Common dataloader kwargs - dataloader_kwargs = { - "batch_size": system_cfg.batch_size, - "datasplit_path": csv_path_abs, - "classes": cellmap_cfg["classes"], - "input_array_info": cellmap_cfg["input_array_info"], - "target_array_info": cellmap_cfg["target_array_info"], - "num_workers": system_cfg.num_workers, - "device": dataloader_device, - } - - if stage == "fit" or stage is None: - print("📦 Creating CellMap train/val dataloaders...") - train_loader, val_loader = get_dataloader( - **dataloader_kwargs, - spatial_transforms=cellmap_cfg["spatial_transforms"], - iterations_per_epoch=self.cfg.data.iter_num_per_epoch, - ) - self.train_loader = self._KeyMappingLoader(self._unwrap_loader(train_loader)) - self.val_loader = ( - self._KeyMappingLoader(self._unwrap_loader(val_loader)) - if val_loader is not None - else None - ) - if self.train_loader is not None: - 
print(f" Train batches per epoch: {len(self.train_loader)}") - if self.val_loader is not None: - print(f" Val batches: {len(self.val_loader)}") - - if stage == "test": - print("📦 Creating CellMap test dataloader...") - # Note: For CellMap challenge submission, you'd need test crops - # This is a placeholder for when test data has labels - if hasattr(self.cfg, "test") and self.cfg.test.data.test_image: - test_loader, _ = get_dataloader( - **dataloader_kwargs, - spatial_transforms=None, - iterations_per_epoch=self.cfg.data.iter_num_per_epoch, - ) - self.test_loader = self._KeyMappingLoader(self._unwrap_loader(test_loader)) - if self.test_loader is not None: - print(f" Test batches: {len(self.test_loader)}") - - def train_dataloader(self): - return self.train_loader - - def val_dataloader(self): - return self.val_loader - - def test_dataloader(self): - return self.test_loader - - -def main(): - """Main training function (reuses main.py logic)""" - # Parse arguments (same as main.py) - args = parse_args() - - # Validate config is provided - if not args.config: - print("❌ Error: --config is required") - print("\nUsage:") - print(" python scripts/cellmap/train_cellmap.py --config tutorials/cellmap_cos7.yaml") - sys.exit(1) - - # Setup config (same as main.py) - print("\n" + "=" * 60) - print("🚀 CellMap Training with PyTC Lightning Framework") - print("=" * 60) - cfg = setup_config(args) - - # Validate CellMap config - if not hasattr(cfg.data, "cellmap"): - print("❌ Error: Config must have data.cellmap section") - print(" See tutorials/cellmap_cos7.yaml for example") - sys.exit(1) - - # Setup run directory (same as main.py) - dirpath = cfg.monitor.checkpoint.dirpath - run_dir = setup_run_directory(args.mode, cfg, dirpath) - output_base = run_dir.parent - - # Set random seed (same as main.py) - if cfg.system.seed is not None: - print(f"🎲 Random seed set to: {cfg.system.seed}") - seed_everything(cfg.system.seed, workers=True) - - # Create model (same as main.py) - 
print(f"Creating model: {cfg.model.architecture}") - model = ConnectomicsModule(cfg) - - # Count parameters - num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f" Model parameters: {num_params:,}") - - # Handle checkpoint (same as main.py) - ckpt_path = modify_checkpoint_state( - args.checkpoint, - run_dir, - reset_optimizer=args.reset_optimizer, - reset_scheduler=args.reset_scheduler, - reset_epoch=args.reset_epoch, - reset_early_stopping=args.reset_early_stopping, - ) - - # Create trainer (same as main.py) - trainer = create_trainer( - cfg, - run_dir=run_dir, - fast_dev_run=args.fast_dev_run, - ckpt_path=ckpt_path, - mode=args.mode, - ) - - # Create CellMap datamodule (custom for CellMap) - datamodule = CellMapDataModule(cfg, mode=args.mode) - - # Training/testing workflow (same as main.py) - try: - if args.mode == "train": - print("\n" + "=" * 60) - print("🏃 STARTING TRAINING") - print("=" * 60) - - trainer.fit( - model, - datamodule=datamodule, - ckpt_path=ckpt_path, - ) - print("\n✅ Training completed successfully!") - - elif args.mode == "test": - print("\n" + "=" * 60) - print("🧪 RUNNING TEST") - print("=" * 60) - - trainer.test( - model, - datamodule=datamodule, - ckpt_path=ckpt_path, - ) - - except Exception as e: - mode_name = args.mode.capitalize() if args.mode else "Operation" - print(f"\n❌ {mode_name} failed: {e}") - import traceback - - traceback.print_exc() - sys.exit(1) - finally: - # Cleanup (same as main.py) - if args.mode == "train": - cleanup_run_directory(output_base) - - -if __name__ == "__main__": - main() diff --git a/scripts/main.py b/scripts/main.py index fc248f19..27c7f784 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -10,20 +10,20 @@ Usage: # Basic training - python scripts/main.py --config tutorials/lucchi.yaml + python scripts/main.py --config tutorials/mito_lucchi++.yaml # Testing mode - python scripts/main.py --config tutorials/lucchi.yaml --mode test --checkpoint path/to/checkpoint.ckpt + 
python scripts/main.py --config tutorials/mito_lucchi++.yaml --mode test --checkpoint path/to/checkpoint.ckpt # Fast dev run (1 batch for debugging, auto-sets num_gpus=1, num_cpus=1, num_workers=1) - python scripts/main.py --config tutorials/lucchi.yaml --fast-dev-run - python scripts/main.py --config tutorials/lucchi.yaml --fast-dev-run 2 # Run 2 batches + python scripts/main.py --config tutorials/mito_lucchi++.yaml --fast-dev-run + python scripts/main.py --config tutorials/mito_lucchi++.yaml --fast-dev-run 2 # Run 2 batches # Override config parameters - python scripts/main.py --config tutorials/lucchi.yaml data.batch_size=8 optimization.max_epochs=200 + python scripts/main.py --config tutorials/mito_lucchi++.yaml data.batch_size=8 optimization.max_epochs=200 # Resume training with different max_epochs - python scripts/main.py --config tutorials/lucchi.yaml --checkpoint path/to/ckpt.ckpt --reset-max-epochs 500 + python scripts/main.py --config tutorials/mito_lucchi++.yaml --checkpoint path/to/ckpt.ckpt --reset-max-epochs 500 """ import sys @@ -153,7 +153,7 @@ def main(): if not args.config: print("❌ Error: --config is required (or use --demo for a quick test)") print("\nUsage:") - print(" python scripts/main.py --config tutorials/lucchi.yaml") + print(" python scripts/main.py --config tutorials/mito_lucchi++.yaml") print(" python scripts/main.py --demo") sys.exit(1) diff --git a/scripts/profile_dataloader.py b/scripts/profile_dataloader.py index b7ae5f59..f305941e 100644 --- a/scripts/profile_dataloader.py +++ b/scripts/profile_dataloader.py @@ -12,6 +12,7 @@ from connectomics.config import load_config from scripts.main import create_datamodule + def profile_dataloader(config_path: str, num_batches: int = 10): """Profile dataloader performance.""" @@ -22,8 +23,8 @@ def profile_dataloader(config_path: str, num_batches: int = 10): # Load config cfg = load_config(config_path) print(f"Config: {config_path}") - print(f"Batch size: {cfg.data.batch_size}") - print(f"Num 
workers: {cfg.data.num_workers}") + print(f"Batch size: {cfg.system.training.batch_size}") + print(f"Num workers: {cfg.system.training.num_workers}") print(f"Iter num per epoch: {cfg.data.iter_num_per_epoch}") print() @@ -77,7 +78,7 @@ def profile_dataloader(config_path: str, num_batches: int = 10): print(f"Min batch time: {min(batch_times):.3f}s") print(f"Max batch time: {max(batch_times):.3f}s") print(f"Throughput: {num_batches/total_time:.2f} batches/sec") - print(f"Samples/sec: {num_batches * cfg.data.batch_size / total_time:.2f}") + print(f"Samples/sec: {num_batches * cfg.system.training.batch_size / total_time:.2f}") print() # Recommendations @@ -90,14 +91,14 @@ def profile_dataloader(config_path: str, num_batches: int = 10): if avg_time > 1.0: print("⚠️ SLOW: Average batch time > 1s") print(" Recommendations:") - print(f" - Increase num_workers (current: {cfg.data.num_workers})") + print(f" - Increase num_workers (current: {cfg.system.training.num_workers})") print(" - Enable caching if possible") print(" - Check disk I/O performance") print(" - Simplify transform pipeline") elif avg_time > 0.5: print("⚠️ MODERATE: Average batch time 0.5-1.0s") print(" Could be improved with:") - print(f" - More workers (current: {cfg.data.num_workers})") + print(f" - More workers (current: {cfg.system.training.num_workers})") print(" - Caching strategy") else: print("✓ GOOD: Average batch time < 0.5s") @@ -105,7 +106,7 @@ def profile_dataloader(config_path: str, num_batches: int = 10): print() if __name__ == "__main__": - config_path = "tutorials/monai_lucchi.yaml" + config_path = "tutorials/mito_lucchi++.yaml" if len(sys.argv) > 1: config_path = sys.argv[1] diff --git a/scripts/validate_tutorial_configs.py b/scripts/validate_tutorial_configs.py new file mode 100644 index 00000000..90b7c0ad --- /dev/null +++ b/scripts/validate_tutorial_configs.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +"""Validate top-level tutorial YAML configs. + +Checks: +1. 
Config can be loaded by the Hydra/OmegaConf loader. +2. Legacy keys that should not appear in top-level tutorials are absent. +""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Any, Iterable, List, Tuple + +import yaml + +from connectomics.config import load_config + + +LEGACY_PATTERNS: List[Tuple[Tuple[str, ...], str]] = [ + (("inference", "data"), "Use `test.data` instead of `inference.data`."), + ( + ("data", "augmentation", "enabled"), + "Use augmentation `preset` + per-transform `enabled` flags.", + ), + ( + ("inference", "test_time_augmentation", "act"), + "Use `inference.test_time_augmentation.channel_activations`.", + ), +] + + +def _has_path(data: Any, path: Tuple[str, ...]) -> bool: + cur = data + for key in path: + if not isinstance(cur, dict) or key not in cur: + return False + cur = cur[key] + return True + + +def _load_yaml(path: Path) -> Any: + with path.open("r", encoding="utf-8") as f: + return yaml.safe_load(f) or {} + + +def _iter_config_paths(glob_patterns: Iterable[str]) -> List[Path]: + paths: List[Path] = [] + for pattern in glob_patterns: + paths.extend(Path().glob(pattern)) + # Keep deterministic order and unique paths. + return sorted(set(p for p in paths if p.is_file())) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Validate tutorial YAML configs.") + parser.add_argument( + "--glob", + action="append", + default=["tutorials/*.yaml"], + help="Glob pattern to include (can be passed multiple times).", + ) + args = parser.parse_args() + + config_paths = _iter_config_paths(args.glob) + if not config_paths: + print("No matching config files found.") + return 1 + + errors: List[str] = [] + for config_path in config_paths: + raw = _load_yaml(config_path) + + for pattern, message in LEGACY_PATTERNS: + if _has_path(raw, pattern): + dotted = ".".join(pattern) + errors.append(f"{config_path}: legacy key `{dotted}` found. 
{message}") + + try: + load_config(config_path) + except Exception as exc: # pragma: no cover - exact exception type may vary. + errors.append(f"{config_path}: failed to load ({type(exc).__name__}: {exc})") + + if errors: + print("Tutorial config validation failed:") + for err in errors: + print(f" - {err}") + return 1 + + print(f"Validated {len(config_paths)} tutorial configs successfully.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/unit/test_hydra_utils_inheritance.py b/tests/unit/test_hydra_utils_inheritance.py new file mode 100644 index 00000000..d3b30154 --- /dev/null +++ b/tests/unit/test_hydra_utils_inheritance.py @@ -0,0 +1,78 @@ +from pathlib import Path + +import pytest + +from connectomics.config import load_config + + +def test_load_config_with_base_relative_path(tmp_path: Path): + base = tmp_path / "base.yaml" + base.write_text( + """ +system: + seed: 123 +model: + architecture: monai_unet +""" + ) + + child = tmp_path / "child.yaml" + child.write_text( + """ +_base_: base.yaml +system: + seed: 42 +""" + ) + + cfg = load_config(child) + assert cfg.system.seed == 42 + assert cfg.model.architecture == "monai_unet" + + +def test_load_config_with_multiple_bases_order(tmp_path: Path): + base_a = tmp_path / "base_a.yaml" + base_a.write_text( + """ +system: + seed: 1 +model: + architecture: monai_unet +""" + ) + + base_b = tmp_path / "base_b.yaml" + base_b.write_text( + """ +system: + seed: 2 +model: + architecture: mednext +""" + ) + + child = tmp_path / "child.yaml" + child.write_text( + """ +_base_: + - base_a.yaml + - base_b.yaml +system: + seed: 3 +""" + ) + + cfg = load_config(child) + # Later bases override earlier ones, child overrides both. 
+ assert cfg.system.seed == 3 + assert cfg.model.architecture == "mednext" + + +def test_load_config_with_cyclic_base_raises(tmp_path: Path): + a = tmp_path / "a.yaml" + b = tmp_path / "b.yaml" + a.write_text("_base_: b.yaml\n") + b.write_text("_base_: a.yaml\n") + + with pytest.raises(ValueError, match="cyclic _base_ config inheritance"): + load_config(a) diff --git a/tests/unit/test_main_cli_contract.py b/tests/unit/test_main_cli_contract.py index 720c9124..de485d39 100644 --- a/tests/unit/test_main_cli_contract.py +++ b/tests/unit/test_main_cli_contract.py @@ -37,13 +37,13 @@ def test_parse_args_preserves_overrides_passthrough(monkeypatch): monkeypatch, [ "--config", - "tutorials/lucchi++.yaml", + "tutorials/mito_lucchi++.yaml", "data.batch_size=8", "optimization.max_epochs=3", ], ) - assert args.config == "tutorials/lucchi++.yaml" + assert args.config == "tutorials/mito_lucchi++.yaml" assert args.overrides == ["data.batch_size=8", "optimization.max_epochs=3"] diff --git a/tutorials/README.md b/tutorials/README.md new file mode 100644 index 00000000..e2f37b92 --- /dev/null +++ b/tutorials/README.md @@ -0,0 +1,42 @@ +# Tutorial Configs + +Top-level tutorial configs are in this folder and are intended to be runnable with: + +```bash +python scripts/main.py --config tutorials/<config_name>.yaml +``` + +## Active top-level configs + +- `tutorials/mito_lucchi++.yaml`: Lucchi++ mitochondria segmentation (MONAI UNet). +- `tutorials/mito_mitoEM.yaml`: MitoEM mitochondria instance segmentation (MedNeXt, SDT). +- `tutorials/mito_mitolab.yaml`: CEM-MitoLab 2D mitochondria segmentation (MedNeXt). +- `tutorials/mito_betaseg.yaml`: BetaSeg mitochondria instance segmentation (MedNeXt, affinity+SDT). +- `tutorials/neuron_snemi.yaml`: SNEMI3D neuron segmentation (RSUNet, affinities). +- `tutorials/nuc_nucmm-z.yaml`: NucMM zebrafish nuclei segmentation (MONAI UNet, multi-task). +- `tutorials/fiber_linghu26.yaml`: Fiber segmentation (MedNeXt, binary+boundary+distance). 
+ +## Config composition (`_base_`) + +Top-level configs now use inheritance via `_base_`: + +- `tutorials/bases/common.yaml`: Shared defaults across top-level tutorials. +- `tutorials/bases/mednext.yaml`: MedNeXt-specific shared settings. +- `tutorials/bases/monai_unet.yaml`: MONAI UNet shared settings. +- `tutorials/bases/rsunet.yaml`: RSUNet shared settings. + +`_base_` supports: + +- A single file path (`_base_: bases/common.yaml`) +- A list of files (`_base_: [a.yaml, b.yaml]`) with left-to-right merge order +- Relative paths resolved from the current config file + +## Validation + +Validate top-level tutorial configs: + +```bash +python scripts/validate_tutorial_configs.py +``` + +This check fails if a config cannot load or if legacy keys reappear (`inference.data`, `data.augmentation.enabled`, or `inference.test_time_augmentation.act`). diff --git a/tutorials/bases/common.yaml b/tutorials/bases/common.yaml new file mode 100644 index 00000000..c78c1a6f --- /dev/null +++ b/tutorials/bases/common.yaml @@ -0,0 +1,33 @@ +# Shared tutorial defaults used by top-level tutorials/*.yaml. 
+ +system: + inference: + num_gpus: 1 + +model: + in_channels: 1 + +data: + image_transform: + normalize: "0-1" + +optimization: + optimizer: + betas: [0.9, 0.999] + +monitor: + detect_anomaly: false + logging: + scalar: + benchmark: true + images: + enabled: true + selected_channels: null + checkpoint: + mode: min + save_last: true + use_timestamp: true + +inference: + test_time_augmentation: + ensemble_mode: mean diff --git a/tutorials/bases/mednext.yaml b/tutorials/bases/mednext.yaml new file mode 100644 index 00000000..6c47a5f0 --- /dev/null +++ b/tutorials/bases/mednext.yaml @@ -0,0 +1,22 @@ +_base_: common.yaml + +model: + architecture: mednext + +optimization: + gradient_clip_val: 1.0 + optimizer: + name: AdamW + +monitor: + early_stopping: + enabled: true + mode: min + check_finite: true + logging: + images: + channel_mode: all + +inference: + test_time_augmentation: + flip_axes: null diff --git a/tutorials/bases/monai_unet.yaml b/tutorials/bases/monai_unet.yaml new file mode 100644 index 00000000..3729c9c5 --- /dev/null +++ b/tutorials/bases/monai_unet.yaml @@ -0,0 +1,58 @@ +_base_: common.yaml + +system: + training: + num_gpus: 4 + num_cpus: 8 + num_workers: 8 + seed: 42 + +model: + architecture: monai_unet + filters: [32, 64, 128, 256] + +data: + use_preloaded_cache: true + iter_num_per_epoch: 1280 + image_transform: + clip_percentile_low: 0.0 + clip_percentile_high: 1.0 + +optimization: + max_epochs: 1000 + accumulate_grad_batches: 1 + precision: "bf16-mixed" + optimizer: + lr: 0.001 + eps: 1.0e-8 + scheduler: + min_lr: 1.0e-6 + +monitor: + early_stopping: + enabled: true + monitor: train_loss_total_epoch + mode: min + check_finite: true + logging: + scalar: + loss: [train_loss_total_epoch] + loss_every_n_steps: 10 + val_check_interval: 1.0 + images: + log_every_n_epochs: 1 + checkpoint: + save_top_k: 1 + save_every_n_epochs: 10 + +inference: + sliding_window: + blending: gaussian + sigma_scale: 0.25 + test_time_augmentation: + enabled: true + 
select_channel: null + save_prediction: + enabled: true + intensity_scale: 255 + intensity_dtype: uint8 diff --git a/tutorials/bases/rsunet.yaml b/tutorials/bases/rsunet.yaml new file mode 100644 index 00000000..f01ecb6d --- /dev/null +++ b/tutorials/bases/rsunet.yaml @@ -0,0 +1,4 @@ +_base_: common.yaml + +model: + architecture: rsunet diff --git a/tutorials/fiber_linghu26.yaml b/tutorials/fiber_linghu26.yaml index 9b9be8c9..51cc9117 100644 --- a/tutorials/fiber_linghu26.yaml +++ b/tutorials/fiber_linghu26.yaml @@ -1,274 +1,245 @@ -# Fiber segmentation with MedNeXt -# Multi-task learning: Binary + Contour + Signed Distance Transform (BCS) -# -# This config uses MedNeXt for fiber segmentation with multi-task learning to predict: -# - Channel 0: Binary fiber masks (sigmoid activation) -# - Channel 1: Fiber contour/boundary maps (sigmoid activation) -# - Channel 2: Signed distance transforms (tanh activation) -# -# Based on barcode-R-BCS.yaml (legacy v1.0 config): -# TARGET_OPT: ["0", "4-0-1", "a-0-40-16-16"] -# - Channel 0: Binary mask (target "0") -# - Channel 1: Contour (target "4-0-1" = instance boundary, thickness=1, edge_mode=0) -# - Channel 2: SDT (target "a-0-40-16-16" = affinity-based distance, resolution=[40,16,16]) -# -# LOSS_OPTION: [[WeightedBCE, Dice], [WeightedBCE, Dice], [WeightedMSE]] -# LOSS_WEIGHT: [[1.0, 0.5], [1.0, 0.5], [4.0]] -# OUTPUT_ACT: [["none", "sigmoid"], ["none", "sigmoid"], ["tanh"]] - +_base_: bases/mednext.yaml experiment_name: fiber_mednext_bcs description: Fiber segmentation with MedNeXt and multi-task learning (Binary + Contour + SDT) - -# System system: training: - num_gpus: 4 # 4 GPU - num_cpus: 8 # 8 CPUs for data loading - num_workers: 8 # Use main process (avoids /dev/shm issues) - batch_size: 4 # Batch size + num_gpus: 4 + num_cpus: 8 + num_workers: 8 + batch_size: 4 inference: - num_gpus: 1 num_cpus: 1 num_workers: 1 - batch_size: 8 # Inference batch size + batch_size: 8 seed: 42 - -# Model - MedNeXt for multi-task fiber 
segmentation model: - in_channels: 1 # Single-channel grayscale EM images - out_channels: 3 # 3 outputs: Binary + Contour + SDT (matching barcode-R-BCS.yaml) - - architecture: mednext - input_size: [32, 96, 96] # Training input size - output_size: [32, 96, 96] # Training output size - - # MedNeXt-specific settings - mednext_size: S # S (5.6M), B (10.5M), M (17.6M), L (61.8M) - mednext_kernel_size: 3 # 3, 5, or 7 - deep_supervision: false # Disable deep supervision - - # Loss configuration matching barcode-R-BCS.yaml - # LOSS_OPTION: [[WeightedBCEWithLogitsLoss, DiceLoss], [WeightedBCEWithLogitsLoss, DiceLoss], [WeightedMSELoss]] - # LOSS_WEIGHT: [[1.0, 0.5], [1.0, 0.5], [4.0]] - # Using 5 separate loss function instances to avoid gradient accumulation issues - loss_functions: [WeightedBCEWithLogitsLoss, DiceLoss, WeightedBCEWithLogitsLoss, DiceLoss, WeightedMSELoss] - loss_weights: [1.0, 0.5, 1.0, 0.5, 4.0] # Weights for each loss: [BCE1, Dice1, BCE2, Dice2, MSE] + out_channels: 3 + input_size: + - 32 + - 96 + - 96 + output_size: + - 32 + - 96 + - 96 + mednext_size: S + mednext_kernel_size: 3 + deep_supervision: false + loss_functions: + - WeightedBCEWithLogitsLoss + - DiceLoss + - WeightedBCEWithLogitsLoss + - DiceLoss + - WeightedMSELoss + loss_weights: + - 1.0 + - 0.5 + - 1.0 + - 0.5 + - 4.0 loss_kwargs: - - {reduction: mean} # WeightedBCEWithLogitsLoss for Binary (channel 0): average over batch - - {sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5} # DiceLoss with sigmoid (include_background ignored for single channel) - - {reduction: mean} # WeightedBCEWithLogitsLoss for Contour (channel 1): average over batch - - {sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5} # DiceLoss with sigmoid (include_background ignored for single channel) - - {tanh: true} # WeightedMSELoss for SDT (channel 2, with tanh activation) - - # Multi-task configuration matching barcode-R-BCS.yaml - # Format: [[start_ch, end_ch, target_name, [loss_indices]], ...] 
- # Target names must match label_transform target names + - reduction: mean + - sigmoid: true + smooth_nr: 1e-5 + smooth_dr: 1e-5 + - reduction: mean + - sigmoid: true + smooth_nr: 1e-5 + smooth_dr: 1e-5 + - tanh: true multi_task_config: - - [0, 1, "binary", [0, 1]] # Channel 0: Binary (loss #0 weight=1.0 + loss #1 weight=0.5) - - [1, 2, "instance_boundary", [2, 3]] # Channel 1: Contour (loss #2 weight=1.0 + loss #3 weight=0.5) - - [2, 3, "skeleton_aware_edt", [4]] # Channel 2: SDT (loss #4 weight=4.0) - -# Data - Fiber dataset configuration (based on barcode-R-Base.yaml) + - - 0 + - 1 + - binary + - - 0 + - 1 + - - 1 + - 2 + - instance_boundary + - - 2 + - 3 + - - 2 + - 3 + - skeleton_aware_edt + - - 4 data: - # Base paths (NEW: will be combined with train_image/train_label) - train_path: '/projects/weilab/dataset/barcode/train_r2/' - - # Volume configuration - TIFF files (barcode fiber structure) - # Matches barcode-R dataset structure: ["1-xri_deconvolved.tif", "2-xri_deconvolved.tif"] - # These paths will be combined with train_path above - - train_image: ["PT37/*_raw.tif", "CA1_LZ58/raw_p1-w2-CA1-8d-1-fiber-1.tif", "DG_LZ58/*-raw.tif"] - train_label: ["PT37/*-mask.tif", "CA1_LZ58/final_p1-w2-CA1-8d-1-segmentation-1.tif", "DG_LZ58/*-mask.tif"] - #train_image: ["DG_LZ58/0702-2-C4-DG-40X002_1-raw.tif"] - #train_label: ["DG_LZ58/0702-2-C4-DG-40X002_1-mask.tif"] - - train_resolution: [40, 16, 16] # Isotropic resolution (adjust based on actual data) - use_preloaded_cache: true # Pre-load raw volumes into RAM for fast random cropping - persistent_workers: false # Disable persistent workers (avoids /dev/shm space issues) - - # Patch configuration - patch_size: [32, 96, 96] # Training patch size (matching INPUT_SIZE/OUTPUT_SIZE) - - iter_num_per_epoch: 1000 # Iterations per epoch - - # Image normalization + train_path: /projects/weilab/dataset/barcode/train_r2/ + train_image: + - PT37/*_raw.tif + - CA1_LZ58/raw_p1-w2-CA1-8d-1-fiber-1.tif + - DG_LZ58/*-raw.tif + 
train_label: + - PT37/*-mask.tif + - CA1_LZ58/final_p1-w2-CA1-8d-1-segmentation-1.tif + - DG_LZ58/*-mask.tif + train_resolution: + - 40 + - 16 + - 16 + use_preloaded_cache: true + persistent_workers: false + patch_size: + - 32 + - 96 + - 96 + iter_num_per_epoch: 1000 image_transform: - normalize: "0-1" # Min-max normalization to [0, 1] - clip_percentile_low: 0.005 # No clipping + clip_percentile_low: 0.005 clip_percentile_high: 0.995 - pad_size: [8, 16, 16] # Reflection padding for context - pad_mode: reflect # Reflection padding at boundaries - - # Label transformation for multi-task learning (matching barcode-R-BCS.yaml) - # TARGET_OPT: ["0", "4-0-1", "a-0-40-16-16"] - # - "0": Binary mask (channel 0) - # - "4-0-1": Instance boundary (channel 1) - thickness=1, edge_mode=0 - # - "a-0-40-16-16": Affinity-based distance (channel 2) - resolution=[40,16,16] + pad_size: + - 8 + - 16 + - 16 + pad_mode: reflect label_transform: targets: - - name: binary # Channel 0: Binary mask ("0") - kwargs: {} - - name: instance_boundary # Channel 1: Instance boundary ("4-0-1") - kwargs: - thickness: 1 # Boundary thickness in pixels - edge_mode: all # edge_mode=0 → "all" edges - mode: "2d" - - name: skeleton_aware_edt # Channel 2: Signed distance transform ("a-0-40-16-16") - kwargs: - resolution: [40, 16, 16] # Physical voxel resolution (z, y, x) - alpha: 1 # Affinity-based (alpha=1 for skeleton-aware distance) - bg_value: -1.0 # Background value for distance map - relabel: true # Relabel connected components - - # Augmentation Configuration with Presets - # - # Choose a preset mode and set individual augmentations below: - # - # Preset modes: - # - "all": Start with ALL augmentations enabled by default - # (Manually set enabled: false to disable specific ones) - # WARNING: Requires use_preloaded_cache: true - # - # - "some": Start with NO augmentations, ONLY respect manually enabled ones - # (Manually set enabled: true to enable specific ones) - # RECOMMENDED: Safe and flexible - # - 
# - "none": Disable all augmentations completely - # (Individual settings ignored) - # + - name: binary + kwargs: {} + - name: instance_boundary + kwargs: + thickness: 1 + edge_mode: all + mode: 2d + - name: skeleton_aware_edt + kwargs: + resolution: + - 40 + - 16 + - 16 + alpha: 1 + bg_value: -1.0 + relabel: true augmentation: - preset: "some" # Choose: "all", "some", or "none" + preset: some flip: enabled: true rotate: enabled: true - spatial_axes: [1, 2] # Rotate only in Y-X plane (preserves Z-axis) + spatial_axes: + - 1 + - 2 affine: - enabled: true # Affine transform (rotation + scaling + shearing) - prob: 0.5 # Probability of applying affine (0-1) - rotate_range: [0.2, 0.2, 0.2] # Rotation range in radians (~11° per axis) - scale_range: [0.1, 0.1, 0.1] # Scaling range (±10% per axis) - shear_range: [0.1, 0.1, 0.1] # Shearing range (±10° per axis) + enabled: true + prob: 0.5 + rotate_range: + - 0.2 + - 0.2 + - 0.2 + scale_range: + - 0.1 + - 0.1 + - 0.1 + shear_range: + - 0.1 + - 0.1 + - 0.1 elastic: enabled: true - prob: 0.3 # Probability of applying elastic deformation (0-1) - sigma_range: [8.0, 10.0] # Gaussian filter sigma range for deformation field (in pixels) - magnitude_range: [20.0, 50.0] # Deformation magnitude range (in pixels) + prob: 0.3 + sigma_range: + - 8.0 + - 10.0 + magnitude_range: + - 20.0 + - 50.0 intensity: enabled: true - # Grayscale intensity augmentations (applied only to image, not labels) - gaussian_noise_prob: 0 # Probability of adding Gaussian noise (0-1) - shift_intensity_prob: 0.3 # Probability of shifting intensity values (0-1) - shift_intensity_offset: 0.1 # Intensity shift range as fraction of image range - contrast_prob: 0.3 # Probability of adjusting contrast (0-1) - contrast_range: [0.7, 1.4] # Contrast adjustment range (0.7 = darker, 1.4 = brighter) - - -# Optimizer - AdamW with cosine annealing (based on barcode-R-Base.yaml) + gaussian_noise_prob: 0 + shift_intensity_prob: 0.3 + shift_intensity_offset: 0.1 + contrast_prob: 
0.3 + contrast_range: + - 0.7 + - 1.4 optimization: - max_epochs: 100 # 100k iterations / (1000 iters/epoch) = 100 epochs - gradient_clip_val: 1.0 + max_epochs: 100 accumulate_grad_batches: 1 - precision: "16-mixed" # Mixed precision training - + precision: 16-mixed optimizer: - name: AdamW - lr: 0.001 # Reduced from 0.02 to prevent NaN (will use warmup) + lr: 0.001 weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1.0e-8 - - # Scheduler - Cosine annealing with warmup (matches barcode LR_SCHEDULER_NAME: WarmupCosineLR) + eps: 1.0e-08 scheduler: name: CosineAnnealingLR warmup_epochs: 5 - warmup_start_lr: 1.0e-5 - min_lr: 1.0e-6 - t_max: 95 # max_epochs - warmup_epochs - + warmup_start_lr: 1.0e-05 + min_lr: 1.0e-06 + t_max: 95 monitor: - # Loss monitoring and validation frequency - detect_anomaly: false logging: - # scalar loss scalar: - loss: [train_loss_total_epoch] + loss: + - train_loss_total_epoch loss_every_n_steps: 10 val_check_interval: 1.0 - benchmark: true - - # visualization images: - enabled: true max_images: 2 num_slices: 4 - log_every_n_epochs: 1 # Log less frequently for fiber data - channel_mode: all # Show all 3 channels for multi-task - selected_channels: null - - # Checkpointing (matches barcode-R-Base.yaml ITERATION_SAVE: 5000) + log_every_n_epochs: 1 checkpoint: - mode: min save_top_k: 3 - save_last: true - save_every_n_epochs: 5 # Save every 5 epochs (5000 iterations) - use_timestamp: true - - # Early stopping + save_every_n_epochs: 5 early_stopping: - enabled: true monitor: train_loss_total_epoch - patience: 100 # Patience in epochs - mode: min - min_delta: 1.0e-5 - check_finite: true + patience: 100 + min_delta: 1.0e-05 threshold: 0.01 divergence_threshold: 100.0 - test: data: - # Test on all available volumes (matches barcode INFERENCE IMAGE_NAME) - test_image: ["/projects/weilab/dataset/barcode/train_r2/DG_LZ58/0702-2-C4-DG-40X002_1-raw.tif"] - test_label: ["/projects/weilab/dataset/barcode/train_r2/DG_LZ58/0702-2-C4-DG-40X002_1-mask.tif"] - 
test_resolution: [40, 16, 16] # Isotropic resolution - -# Inference - MONAI SlidingWindowInferer for fiber segmentation (based on barcode-R-Base.yaml) + test_image: + - /projects/weilab/dataset/barcode/train_r2/DG_LZ58/0702-2-C4-DG-40X002_1-raw.tif + test_label: + - /projects/weilab/dataset/barcode/train_r2/DG_LZ58/0702-2-C4-DG-40X002_1-mask.tif + test_resolution: + - 40 + - 16 + - 16 inference: - - # MONAI SlidingWindowInferer parameters sliding_window: - window_size: [32, 256, 256] # Inference window size - stride: [16, 128, 128] # Stride for sliding window - blending: gaussian # Gaussian weighting for smooth blending + window_size: + - 32 + - 256 + - 256 + stride: + - 16 + - 128 + - 128 + blending: gaussian sigma_scale: 0.25 - padding_mode: reflect # Reflection-padding at volume boundaries - pad_size: [16, 32, 32] # PAD_SIZE (matches barcode-R-Base.yaml) - - # Test-Time Augmentation (TTA) - matches barcode-R-BCS.yaml + padding_mode: reflect + pad_size: + - 16 + - 32 + - 32 test_time_augmentation: enabled: true - flip_axes: null # Use all flip augmentations (AUG_NUM: None) - # Per-channel activations (aligned with barcode-R-BCS.yaml) - # OUTPUT_ACT: ["sigmoid", "sigmoid", "tanh"] - # Format: [[start_ch, end_ch, activation], ...] 
channel_activations: - - [0, 1, sigmoid] # Channel 0: Binary mask (sigmoid) - - [1, 2, sigmoid] # Channel 1: Contour (sigmoid) - - [2, 3, tanh] # Channel 2: SDT (tanh) - ensemble_mode: mean # AUG_MODE: "mean" - - # Decoding configuration (instance segmentation postprocessing) + - - 0 + - 1 + - sigmoid + - - 1 + - 2 + - sigmoid + - - 2 + - 3 + - tanh decoding: - - name: decode_binary_contour_distance_watershed - kwargs: - binary_threshold: [0.9, 0.85] - contour_threshold: [0.8, 1.1] - distance_threshold: [0.5, -0.5] - min_instance_size: 100 # Larger fibers (adjust based on data) - min_seed_size: 20 - prediction_scale: 1 - - # Evaluation + - name: decode_binary_contour_distance_watershed + kwargs: + binary_threshold: + - 0.9 + - 0.85 + contour_threshold: + - 0.8 + - 1.1 + distance_threshold: + - 0.5 + - -0.5 + min_instance_size: 100 + min_seed_size: 20 + prediction_scale: 1 evaluation: enabled: true - metrics: [adapted_rand] + metrics: + - adapted_rand diff --git a/tutorials/mito_betaseg.yaml b/tutorials/mito_betaseg.yaml index 8c5fcb12..869e8dbe 100644 --- a/tutorials/mito_betaseg.yaml +++ b/tutorials/mito_betaseg.yaml @@ -1,335 +1,258 @@ -# BetaSeg Dataset - 3D Mitochondria Instance Segmentation with MedNeXt -# multi-channel-task learning: Signed Distance Transform (SDT) + Affinity -# -# This config uses MedNeXt for mitochondria instance segmentation with SDT-based approach: -# - Output: Single channel SDT (tanh activation) - WeightedMSE loss -# - SDT encodes both foreground/background AND instance separation in one channel -# - Positive values = inside instances (distance to boundary) -# - Negative values = outside instances (distance to nearest instance) -# -# MedNeXt Configuration: -# - Deep supervision: CRITICAL for MedNeXt performance (5 scales) -# - Kernel size: 3 for better context (recommended for instance segmentation) -# - Size: S (~20-30M params) - small model for efficient training -# -# Instance Segmentation Pipeline: -# SDT prediction → Watershed on 
SDT seeds → Instance IDs -# -# BetaSeg Dataset: -# - High-resolution EM (16x16x16 nm/voxel isotropic) -# - Dense mitochondria with complex shapes -# - Challenging instance separation requiring precise SDT - +_base_: bases/mednext.yaml experiment_name: betaseg_mednext_s_sdt_affinity -description: BetaSeg 3D mitochondria instance segmentation with MedNeXt using SDT+affinity (1+6 channels) - -# System +description: BetaSeg 3D mitochondria instance segmentation with MedNeXt using SDT+affinity (1+6 channels) system: training: num_gpus: 1 num_cpus: 8 - num_workers: 8 # Parallel data loading - batch_size: 2 # Larger batch for single-channel output (vs multi-task) + num_workers: 8 + batch_size: 2 inference: - num_gpus: 1 num_cpus: 8 num_workers: 8 batch_size: 1 - # num_cpus: 1 - # num_workers: 1 - # batch_size: 1 - seed: 0 # 42->0, using Peng's nnUnet parameters firstly. - -# Model - MedNeXt for SDT-based mitochondria instance segmentation + seed: 0 model: - architecture: mednext # MedNeXt (SOTA for instance segmentation) - - # Input/output configuration - input_size: [128, 128, 128] # Isotropic patches for isotropic data (16×16×16 nm) - output_size: [128, 128, 128] - in_channels: 1 # Grayscale EM - out_channels: 7 # 6 affinity + 1 - - # MedNeXt architecture (optimized for instance segmentation) - mednext_size: S # S (~20-30M params) - small model for efficient training - mednext_kernel_size: 3 # Using kernel size 3 (from nnUNet/BANIS baseline). Larger kernels (5 or 7) may improve context but increase memory usage. 
- mednext_dim: "3d" # 3D convolutions - deep_supervision: true # CRITICAL for MedNeXt (5-scale deep supervision) - - - # Multi-task loss configuration + input_size: + - 128 + - 128 + - 128 + output_size: + - 128 + - 128 + - 128 + out_channels: 7 + mednext_size: S + mednext_kernel_size: 3 + mednext_dim: 3d + deep_supervision: true loss_functions: - - WeightedBCEWithLogitsLoss # Loss index 0: for affinity channels - - WeightedMSELoss # Loss index 1: for SDT channel - loss_weights: [1.0, 1.0] # Equal weighting: loss = affinity_loss + sdt_loss + - WeightedBCEWithLogitsLoss + - WeightedMSELoss + loss_weights: + - 1.0 + - 1.0 loss_kwargs: - - {} # BCEWithLogitsLoss: default (no pos_weight) - - {tanh: true} # WeightedMSELoss: tanh activation for [-1, 1] range - - # Multi-task channel mapping: [start_ch, end_ch, task_name, loss_indices] - # This tells the training loop which channels use which losses + - {} + - tanh: true multi_task_config: - - [0, 6, "affinity", [0]] # Channels 0-5: affinity → BCEWithLogitsLoss - - [6, 7, "sdt", [1]] # Channel 6: SDT → WeightedMSELoss(tanh=True) - - - + - - 0 + - 6 + - affinity + - - 0 + - - 6 + - 7 + - sdt + - - 1 data: - # Dataset configuration - BetaSeg training data - train_path: "/projects/weilab/qiongwang/datasets/betaseg/tif" - - # Training: 3 volumes (high_c1 moved to validation) + train_path: /projects/weilab/qiongwang/datasets/betaseg/tif train_image: - - "high_c3_im.tiff" - - "low_c1_im.tiff" - - "low_c2_im.tiff" + - high_c3_im.tiff + - low_c1_im.tiff + - low_c2_im.tiff train_label: - - "high_c3_mito.tiff" - - "low_c1_mito.tiff" - - "low_c2_mito.tiff" - train_resolution: [16, 16, 16] # 16nm x 16nm x 16nm isotropic - - # Validation: 1 volume (high_c1) - for monitoring generalization - val_path: "/projects/weilab/qiongwang/datasets/betaseg/tif" + - high_c3_mito.tiff + - low_c1_mito.tiff + - low_c2_mito.tiff + train_resolution: + - 16 + - 16 + - 16 + val_path: /projects/weilab/qiongwang/datasets/betaseg/tif val_image: - - 
"high_c1_im.tiff" + - high_c1_im.tiff val_label: - - "high_c1_mito.tiff" - # val_resolution defaults to train_resolution if not specified - - # Data loading optimization - use_preloaded_cache: false # Disabled to enable validation support - use_cache: true # Use MONAI caching instead (still fast!) - cache_rate: 1.0 # Cache 100% of data - persistent_workers: true # Keep workers alive between epochs - - # Patch configuration (isotropic cubic patches for isotropic data) - patch_size: [128, 128, 128] # repeat Peng - pad_size: [16, 16, 16] # from resolution - pad_mode: reflect # Reflection padding at boundaries - iter_num_per_epoch: 790 # Training iterations per epoch - # val_iter_num: auto-calculated based on validation volume size and patch size - - - - # Image normalization, Input Normalization, If the training loss is unstable, It might be that there are too many outliers. Try to clip the outliers more: 0.005 → 0.01, 0.995 → 0.99 + - high_c1_mito.tiff + use_preloaded_cache: false + use_cache: true + cache_rate: 1.0 + persistent_workers: true + patch_size: + - 128 + - 128 + - 128 + pad_size: + - 16 + - 16 + - 16 + pad_mode: reflect + iter_num_per_epoch: 790 image_transform: - normalize: "0-1" # Min-max normalization to [0, 1] - clip_percentile_low: 0.005 # Clip bottom 0.5% outliers (reduces noise impact) - clip_percentile_high: 0.995 # Clip top 0.5% outliers (reduces saturation artifacts) - - - # Multi-task label transformation - # This generates 7 channels from instance segmentation labels: - # - 6 channels: affinity maps (short + long range) - # - 1 channel: instance SDT - # Label transformation - Affinity maps + Signed Distance Transform for instance segmentation + clip_percentile_low: 0.005 + clip_percentile_high: 0.995 label_transform: targets: - # Target 1: Affinity maps (6 channels: 3 short-range + 3 long-range) - - name: affinity - kwargs: - offsets: - # Short-range affinities (offset = 1 voxel) - - "0-0-1" # x-direction, distance 1 - - "0-1-0" # y-direction, 
distance 1 - - "1-0-0" # z-direction, distance 1 - # Long-range affinities (offset = 10 voxels, matching BANIS --long_range 10) - - "0-0-10" # x-direction, distance 10 - - "0-10-0" # y-direction, distance 10 - - "10-0-0" # z-direction, distance 10 - # Total: 6 affinity channels (3 short + 3 long) - - # Target 2: Signed Distance Transform (1 channel) - - name: skeleton_aware_edt # [1][skeleton_aware_edt]New Version of SDT; [2][signed_distance] TRUE Signed Distance Transform (solves class imbalance!); [3][instance_edt] for edt only - kwargs: - resolution: [16, 16, 16] # Physical voxel resolution (z, y, x) - alpha: 0.8 # Affinity-based (alpha=1 for skeleton-aware distance) - bg_value: -1.0 # Background value for distance map - relabel: true - - - - # Augmentation - crucial for generalization + - name: affinity + kwargs: + offsets: + - 0-0-1 + - 0-1-0 + - 1-0-0 + - 0-0-10 + - 0-10-0 + - 10-0-0 + - name: skeleton_aware_edt + kwargs: + resolution: + - 16 + - 16 + - 16 + alpha: 0.8 + bg_value: -1.0 + relabel: true augmentation: - preset: "some" # some, none, all,new version of augmentation - + preset: some affine: enabled: true - prob: 0.5 # repeat Peng: affine: 0.5 - rotate_range: [0.2, 0.2, 0.2] - scale_range: [0.2, 0.2, 0.2] - shear_range: [0.5, 0.5, 0.5] - + prob: 0.5 + rotate_range: + - 0.2 + - 0.2 + - 0.2 + scale_range: + - 0.2 + - 0.2 + - 0.2 + shear_range: + - 0.5 + - 0.5 + - 0.5 intensity: - enabled: true # repeat Peng: intensity_aug: true + enabled: true gaussian_noise_prob: 0.3 - gaussian_noise_std: 0.5 # repeat Peng: noise_scale: 0.5 + gaussian_noise_std: 0.5 shift_intensity_prob: 0.3 shift_intensity_offset: 0.1 contrast_prob: 0.3 - contrast_range: [0.7, 1.4] - - + contrast_range: + - 0.7 + - 1.4 missing_section: enabled: true - prob: 0.05 # repeat Peng: drop_slice_prob: 0.05 + prob: 0.05 num_sections: 2 - misalignment: enabled: true - prob: 0.05 # repeat Peng: shift_slice_prob: 0.05 + prob: 0.05 displacement: 10 rotate_ratio: 0.0 - flip: enabled: true prob: 
0.5 - rotate: enabled: true prob: 0.5 - elastic: enabled: true prob: 0.3 - - - -# Optimizer - MedNeXt recommended settings optimization: - max_steps: 1000000 - gradient_clip_val: 1.0 - accumulate_grad_batches: 1 - precision: "16-mixed" # FP16 mixed precision (better GPU compatibility than bf16),"16-mixed" - + max_steps: 1000000 + accumulate_grad_batches: 1 + precision: 16-mixed optimizer: - name: AdamW - lr: 1e-3 # MedNeXt recommended: 1e-3 (constant LR) - weight_decay: 1e-2 # 1e-4 -> 1e-2, repeat Peng. 3e-5 in the paper. AdamW typically uses 1e-3 as the standard learning rate for vision transformers and ConvNeXt-style architectures, better for generalization. - betas: [0.9, 0.999] - eps: 1.0e-4 # 1.0e-8 -> 1.0e-4. 1e-8 cause nans in fp16 - - # Scheduler - Constant LR (MedNeXt recommendation) + lr: 1e-3 + weight_decay: 1e-2 + eps: 0.0001 scheduler: - # name: constant # Constant LR works best for MedNeXt (per paper) - name: CosineAnnealingLR #repeat Peng, schedular: true, --> CosineAnnealingLR(T_max=1_000_000) + name: CosineAnnealingLR t_max: 1000000 - min_lr: 0.0 - interval: step # Step every training step - frequency: 1 # Step every 1 training step - + min_lr: 0.0 + interval: step + frequency: 1 monitor: - # Loss monitoring and validation frequency - detect_anomaly: false logging: - # Scalar loss monitoring scalar: - loss: [train_loss_total_epoch, val_loss_total_epoch, train_loss_affinity_total, train_loss_sdt_total] - loss_every_n_steps: 100 # Log every 50 steps - val_check_interval: 1.0 # Validate every epoch - benchmark: true - - # Visualization - SDT predictions (train + validation) + loss: + - train_loss_total_epoch + - val_loss_total_epoch + - train_loss_affinity_total + - train_loss_sdt_total + loss_every_n_steps: 100 + val_check_interval: 1.0 images: - enabled: true - max_images: 10 # Show more samples for quality check - num_slices: 10 # More slices for 3D visualization - log_every_n_epochs: 5 # Visualize every 5 epochs - channel_mode: all # Show SDT 
channel - selected_channels: null - - # Checkpointing - Save best models based on validation loss + max_images: 10 + num_slices: 10 + log_every_n_epochs: 5 checkpoint: - monitor: val_loss_total # Monitor validation loss (prevents overfitting) - mode: min # Minimize validation loss - save_top_k: 10 # Keep top 5 checkpoints - save_last: true - save_every_n_epochs: 5 # Save checkpoint every 5 epochs + monitor: val_loss_total + save_top_k: 10 + save_every_n_epochs: 5 dirpath: outputs/betaseg_mednext_affinity_sdt/checkpoints/ - use_timestamp: true - - # Early stopping - Stop training when validation loss plateaus early_stopping: - enabled: true - monitor: val_loss_total # Monitor validation loss (prevents overfitting) - patience: 150 # More patience for SDT convergence - mode: min - min_delta: 1e-6 # Small delta for precise SDT - check_finite: true - threshold: 0.001 # 0.005 -> 0.001, 0.01 -> 0.005 + monitor: val_loss_total + patience: 150 + min_delta: 1e-6 + threshold: 0.001 divergence_threshold: 100.0 - -# Inference - MONAI SlidingWindowInferer for BetaSeg test: data: - test_image: - - "/projects/weilab/qiongwang/datasets/betaseg/tif/high_c2_im.tiff" - - "/projects/weilab/qiongwang/datasets/betaseg/tif/high_c4_im.tiff" - - "/projects/weilab/qiongwang/datasets/betaseg/tif/low_c3_im.tiff" - test_label: - - "/projects/weilab/qiongwang/datasets/betaseg/tif/high_c2_mito.tiff" - - "/projects/weilab/qiongwang/datasets/betaseg/tif/high_c4_mito.tiff" - - "/projects/weilab/qiongwang/datasets/betaseg/tif/low_c3_mito.tiff" - test_resolution: [16, 16, 16] + test_image: + - /projects/weilab/qiongwang/datasets/betaseg/tif/high_c2_im.tiff + - /projects/weilab/qiongwang/datasets/betaseg/tif/high_c4_im.tiff + - /projects/weilab/qiongwang/datasets/betaseg/tif/low_c3_im.tiff + test_label: + - /projects/weilab/qiongwang/datasets/betaseg/tif/high_c2_mito.tiff + - /projects/weilab/qiongwang/datasets/betaseg/tif/high_c4_mito.tiff + - 
/projects/weilab/qiongwang/datasets/betaseg/tif/low_c3_mito.tiff + test_resolution: + - 16 + - 16 + - 16 output_path: outputs/betaseg_mednext_affinity_sdt/results/ - - # Inference normalization (must match training normalization!) image_transform: - normalize: "0-1" # Min-max normalization [0,1] (same as training) - clip_percentile_low: 0.005 # Clip bottom 0.5% outliers (same as training) - clip_percentile_high: 0.995 # Clip top 0.5% outliers (same as training) - - # Decoding configuration (SDT → instances via watershed) + normalize: 0-1 + clip_percentile_low: 0.005 + clip_percentile_high: 0.995 decoding: - - name: decode_distance_watershed # ← Changed from decode_instance_binary_contour_distance - kwargs: - distance_channels: [6] # Use SDT channel only for now - distance_threshold: [0.5, 0] # Seeds: SDT>0.5, Foreground: SDT>0 - min_seed_size: 50 - min_instance_size: 100 - use_fast_edt: true # Enable fast EDT (10-50x speedup) - edt_parallel: 8 # Use 8 CPU cores - edt_anisotropy: [1.0, 1.0, 1.0] # BetaSeg: resolution: [16, 16, 16] - edt_downsample_factor: 1 # Full resolution (use 2 for large volumes) - - # Evaluation + - name: decode_distance_watershed + kwargs: + distance_channels: + - 6 + distance_threshold: + - 0.5 + - 0 + min_seed_size: 50 + min_instance_size: 100 + use_fast_edt: true + edt_parallel: 8 + edt_anisotropy: + - 1.0 + - 1.0 + - 1.0 + edt_downsample_factor: 1 evaluation: enabled: true - metrics: [adapted_rand, voi, instance_accuracy, instance_accuracy_detail] # Adapted Rand Score + VOI + Accuracy for instance segmentation - - + metrics: + - adapted_rand + - voi + - instance_accuracy + - instance_accuracy_detail inference: - # MONAI SlidingWindowInferer parameters sliding_window: - window_size: [128, 128, 128] # Match training patch size - # sw_batch_size: 4 # Process multiple patches per batch - sw_batch_size: 4 # Process multiple patches per batch - overlap: 0.5 # 50% overlap for smooth blending - blending: gaussian # Gaussian weighting for smooth 
blending - sigma_scale: 0.25 # Gaussian sigma scale - padding_mode: replicate # Replicate padding at volume boundaries - - # Test-Time Augmentation (TTA) for SDT + window_size: + - 128 + - 128 + - 128 + sw_batch_size: 4 + overlap: 0.5 + blending: gaussian + sigma_scale: 0.25 + padding_mode: replicate test_time_augmentation: - flip_axes: null # Use all 8 flip augmentations (2^3 for xyz) - rotation90_axes: null # No 90-degree rotations + rotation90_axes: null channel_activations: - - [0, 6, sigmoid] # Affinity channels: sigmoid activation - - [6, 7, tanh] + - - 0 + - 6 + - sigmoid + - - 6 + - 7 + - tanh select_channel: all - ensemble_mode: mean # Average predictions across augmentations - apply_mask: false # No mask for BetaSeg - - # Save intermediate predictions + apply_mask: false save_prediction: - enabled: true # Save SDT predictions before decoding - intensity_scale: -1 # Keep original scale (no rescaling) - intensity_dtype: float32 # Keep as float32 for SDT - output_formats: [h5, tiff] # Save in multiple formats: HDF5, TIFF, and NIfTI: [h5, tiff, nii.gz] - # Supported formats: h5, tiff, nii.gz, png - # Example configurations: - # [h5] - HDF5 only (smallest file size) - # [tiff] - TIFF only (compatible with ImageJ/Fiji) - # [h5, tiff] - Both HDF5 and TIFF - # [h5, tiff, nii.gz] - All three formats (current setting) + enabled: true + intensity_scale: -1 + intensity_dtype: float32 + output_formats: + - h5 + - tiff diff --git a/tutorials/mito_lucchi++.yaml b/tutorials/mito_lucchi++.yaml index 3500d80e..1253b9cf 100644 --- a/tutorials/mito_lucchi++.yaml +++ b/tutorials/mito_lucchi++.yaml @@ -1,264 +1,157 @@ -# Lucchi++ Mitochondria Segmentation -# Electron microscopy (EM) dataset with multiple architecture options -# -# ============================================================================ -# ARCHITECTURE SELECTION - Change 'architecture' to switch models: -# ============================================================================ -# -# monai_unet 
(recommended for MONAI) -# - MONAI's UNet with residual units -# - Supports any number of levels (uses filters directly) -# - No deep supervision -# - Recommended for: MONAI baseline, flexible architecture -# -# monai_basic_unet3d -# - MONAI's BasicUNet (always 6 levels, pads filters if < 6) -# - Simple, fast -# - Recommended for: Quick experiments -# -# rsunet -# - Residual Symmetric U-Net (EM-optimized) -# - No checkerboard artifacts (uses upsample+conv instead of transposed conv) -# - Anisotropic convolutions for EM data -# - Recommended for: Production EM segmentation, anisotropic data -# -# mednext -# - MedNeXt (MICCAI 2023, ConvNeXt-based) -# - State-of-the-art performance -# - Deep supervision for better training -# - Sizes: S (5.6M), B (10.5M), M (17.6M), L (61.8M) params -# - Recommended for: Best accuracy, sufficient GPU memory -# -# ============================================================================ - +_base_: bases/monai_unet.yaml experiment_name: lucchi++ description: Mitochondria segmentation on Lucchi++ EM dataset - -# System system: training: - num_gpus: 4 - num_cpus: 8 - num_workers: 8 # Set to 0 to avoid /dev/shm space issues (use in-process loading) batch_size: 16 inference: - num_gpus: 1 num_cpus: 1 - num_workers: 1 # Set to 0 to avoid /dev/shm space issues - batch_size: 1 # Reduced from 16 to avoid OOM (sw_batch_size will use this) - seed: 42 - -# Model Configuration + num_workers: 1 + batch_size: 1 model: - # ========== CHANGE THIS LINE TO SWITCH ARCHITECTURES ========== - architecture: monai_unet # Options: monai_unet, monai_basic_unet3d, rsunet, mednext - # ============================================================== - - # Common settings (used by all architectures) - input_size: [112, 112, 112] - output_size: [112, 112, 112] - in_channels: 1 - out_channels: 1 # Single channel with BCE loss (standard for EM) - filters: [32, 64, 128, 256] # 4-level encoder + input_size: + - 112 + - 112 + - 112 + output_size: + - 112 + - 112 + - 112 
+ out_channels: 1 dropout: 0.1 - - # MONAI BasicUNet-specific settings (ignored by other architectures) - upsample: deconv # Use trilinear upsampling (no transposed conv) - - # RSUNet-specific settings (ignored by other architectures) - rsunet_norm: batch # Batch normalization for RSUNet - - # MedNeXt-specific settings (ignored by other architectures) - mednext_size: S # S (5.6M), B (10.5M), M (17.6M), L (61.8M) - mednext_kernel_size: 3 # 3, 5, or 7 - deep_supervision: false # Enable deep supervision for MedNeXt - - # Loss configuration - WeightedBCEWithLogitsLoss + Dice for mitochondria segmentation - loss_functions: [WeightedBCEWithLogitsLoss, DiceLoss] - loss_weights: [1.0, 1.0] # Equal weighting for BCE and Dice + upsample: deconv + rsunet_norm: batch + mednext_size: S + mednext_kernel_size: 3 + deep_supervision: false + loss_functions: + - WeightedBCEWithLogitsLoss + - DiceLoss + loss_weights: + - 1.0 + - 1.0 loss_kwargs: - - {reduction: mean} # WeightedBCEWithLogitsLoss: average over batch - - {sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5} # DiceLoss with sigmoid (include_background ignored for single channel) - -# Data - Using automatic 80/20 train/val split (DeepEM-style) + - reduction: mean + - sigmoid: true + smooth_nr: 1e-5 + smooth_dr: 1e-5 data: - # Volume configuration train_image: datasets/lucchi++/train_im.h5 train_label: datasets/lucchi++/train_mito.h5 - train_resolution: [5, 5, 5] # Lucchi EM: 5nm isotropic resolution - use_preloaded_cache: true # Load volumes into memory for fast training - - # Patch configuration - patch_size: [112, 112, 112] # Isotropic patches for training - pad_size: [0, 0, 0] # No padding during training (not needed) - iter_num_per_epoch: 1280 # 1280 random crops per epoch - - # Image normalization - image_transform: - normalize: "0-1" # Min-max normalization to [0, 1] - clip_percentile_low: 0.0 # No clipping - clip_percentile_high: 1.0 - - # Augmentation - moderate set for 3D mitochondria segmentation - # Recommended for 
Lucchi++: geometric transforms + EM-specific augmentations + train_resolution: + - 5 + - 5 + - 5 + patch_size: + - 112 + - 112 + - 112 + pad_size: + - 0 + - 0 + - 0 augmentation: - preset: "some" # Enable only augmentations explicitly set to enabled=True - - # Standard geometric augmentations (safe for 3D EM) + preset: some flip: enabled: true prob: 0.5 - spatial_axis: [0, 1, 2] # Flip x/y/z - + spatial_axis: + - 0 + - 1 + - 2 rotate: enabled: true prob: 0.5 - spatial_axes: [0, 1, 2] # Rotate x/y/z - + spatial_axes: + - 0 + - 1 + - 2 affine: enabled: true - prob: 0.3 # Lower prob to avoid too aggressive transforms - rotate_range: [0.1, 0.1, 0.1] # Small rotations (~6°) - careful with Z-axis - scale_range: [0.05, 0.05, 0.05] # Small scaling (±5%) - shear_range: [0.05, 0.05, 0.05] # Small shearing - - # Intensity augmentations (important for EM data variability) + prob: 0.3 + rotate_range: + - 0.1 + - 0.1 + - 0.1 + scale_range: + - 0.05 + - 0.05 + - 0.05 + shear_range: + - 0.05 + - 0.05 + - 0.05 intensity: enabled: true - gaussian_noise_prob: 0.2 # Moderate noise + gaussian_noise_prob: 0.2 gaussian_noise_std: 0.03 shift_intensity_prob: 0.4 shift_intensity_offset: 0.1 contrast_prob: 0.4 - contrast_range: [0.8, 1.2] # Moderate contrast variation - - # EM-specific augmentations (highly recommended for EM data) + contrast_range: + - 0.8 + - 1.2 misalignment: enabled: true prob: 0.4 - displacement: 8 # Moderate displacement for small patches - rotate_ratio: 0.3 # Mix of translation and rotation - + displacement: 8 + rotate_ratio: 0.3 missing_section: enabled: true prob: 0.3 - num_sections: 2 # 1-2 missing sections (common in EM) - + num_sections: 2 motion_blur: enabled: true prob: 0.3 sections: 2 - kernel_size: 9 # Moderate blur - - # Avoid these for mitochondria segmentation: - # - elastic: Too aggressive for small structures - # - cut_blur/cut_noise: May be too aggressive - # - copy_paste/mixup: Not needed for binary segmentation - - -# Optimizer - Adam with 
conservative hyperparameters (proven for EM segmentation) + kernel_size: 9 optimization: - max_epochs: 1000 # Standard epochs for EM segmentation - gradient_clip_val: 0.5 # Conservative gradient clipping - accumulate_grad_batches: 1 - precision: "bf16-mixed" # BFloat16 mixed precision - + gradient_clip_val: 0.5 optimizer: - name: Adam # Standard Adam (not AdamW for EM tasks) - lr: 0.001 # Learning rate (1e-3 works well for all architectures) - weight_decay: 0.0 # No weight decay (not beneficial for EM) - betas: [0.9, 0.999] # Standard Adam betas - eps: 1.0e-8 # Numerical stability - - # Scheduler - ReduceLROnPlateau for adaptive learning + name: Adam + weight_decay: 0.0 scheduler: - name: ReduceLROnPlateau # Reduce LR when training loss plateaus - mode: min # Monitor minimum loss - factor: 0.5 # Reduce LR by 50% - patience: 50 # Wait 50 epochs before reducing - threshold: 1.0e-4 # Minimum change to qualify as improvement - min_lr: 1.0e-6 # Don't go below 1e-6 - monitor: train_loss_total_epoch # No validation; monitor training loss - + name: ReduceLROnPlateau + mode: min + factor: 0.5 + patience: 50 + threshold: 0.0001 + monitor: train_loss_total_epoch monitor: - # Loss monitoring and validation frequency - detect_anomaly: false logging: - # scalar loss - scalar: - loss: [train_loss_total_epoch] - loss_every_n_steps: 10 - val_check_interval: 1.0 - benchmark: true - - # visualization images: - enabled: true max_images: 8 num_slices: 2 - log_every_n_epochs: 1 # Log every N epochs (default: 1) - channel_mode: argmax # 'argmax', 'all', or 'selected' - selected_channels: null # Only used when channel_mode='selected' - - # Checkpointing - checkpoint: - mode: min - save_top_k: 1 - save_last: true - save_every_n_epochs: 10 - # checkpoint_filename: auto-generated from monitor metric (epoch={epoch:03d}-{monitor}={value:.4f}) - use_timestamp: true # Enable timestamped subdirectories (YYYYMMDD_HHMMSS) - - # Early stopping - Patient for convergence + channel_mode: argmax 
early_stopping: - enabled: true - monitor: train_loss_total_epoch - patience: 100 # Patient waiting for improvement - mode: min - min_delta: 1.0e-4 # Minimum delta for improvement - check_finite: true # Stop if monitored metric becomes NaN/inf - threshold: 0.02 # Stop if loss gets this low (excellent convergence for EM) - divergence_threshold: 2.0 # Stop if loss exceeds this (training collapse) - -# Inference - MONAI SlidingWindowInferer + patience: 100 + min_delta: 0.0001 + threshold: 0.02 + divergence_threshold: 2.0 inference: - # MONAI SlidingWindowInferer parameters sliding_window: - window_size: [112, 112, 112] # Patch size (matches training patches) - sw_batch_size: 1 # Process 1 patch at a time (memory optimization) - overlap: 0.25 # 25% overlap (reduced from 0.5 to save memory) - blending: gaussian # Gaussian weighting for smooth blending - sigma_scale: 0.25 # Larger sigma = smoother blending at boundaries - padding_mode: replicate # Replicate edge values (better than reflect for z=0) - - # Test-Time Augmentation (TTA) + window_size: + - 112 + - 112 + - 112 + sw_batch_size: 1 + overlap: 0.25 + padding_mode: replicate test_time_augmentation: - enabled: true # Enable TTA for improved predictions - flip_axes: all # No flip augmentation (set to null to disable flips) - # XY flips are safe for isotropic XY resolution, Z-flip avoided due to anisotropy - channel_activations: [[0, 1, 'sigmoid']] # Sigmoid activation for single-channel output - select_channel: null # Keep all channels (single channel output) - ensemble_mode: mean # Mean ensemble (smooth predictions) - # NOTE: Reduced TTA compared to original (4x instead of 8x) for faster inference - - # Save intermediate predictions (before decoding/postprocessing) - save_prediction: - enabled: true # Save intermediate predictions - intensity_scale: 255 # Scale predictions to [0, 255] for saving - intensity_dtype: uint8 # Save as uint8 - + flip_axes: all + channel_activations: + - - 0 + - 1 + - sigmoid test: data: 
test_image: datasets/lucchi++/test_im.h5 test_label: datasets/lucchi++/test_mito.h5 - test_resolution: [5, 5, 5] - - # Evaluation + test_resolution: + - 5 + - 5 + - 5 evaluation: - enabled: true # Use eval mode for BatchNorm - metrics: [jaccard] # Metrics to compute - - # NOTE: batch_size=1 for inference - # During training: batch_size controls how many random patches to load - # During inference: batch_size=1 means process one full volume at a time - # sw_batch_size (above) controls how many patches are processed per GPU forward pass + enabled: true + metrics: + - jaccard diff --git a/tutorials/mito_mitoEM.yaml b/tutorials/mito_mitoEM.yaml index 5af84fe0..5aeb2294 100644 --- a/tutorials/mito_mitoEM.yaml +++ b/tutorials/mito_mitoEM.yaml @@ -1,205 +1,147 @@ -# MitoEM Dataset - 3D Mitochondria Instance Segmentation with MedNeXt -# Single-task learning: Signed Distance Transform (SDT) only -# -# This config uses MedNeXt for mitochondria instance segmentation with SDT-based approach: -# - Output: Single channel SDT (tanh activation) - WeightedMSE loss -# - SDT encodes both foreground/background AND instance separation in one channel -# - Positive values = inside instances (distance to boundary) -# - Negative values = outside instances (distance to nearest instance) -# -# MedNeXt Configuration: -# - Deep supervision: CRITICAL for MedNeXt performance (5 scales) -# - Kernel size: 7x7x7 for better context (recommended for instance segmentation) -# - Size: M (17.6M params) - balanced capacity for 3D SDT learning -# -# Instance Segmentation Pipeline: -# SDT prediction → Watershed on SDT seeds → Instance IDs -# -# MitoEM Dataset: -# - High-resolution EM (30x8x8 nm/voxel anisotropic) -# - Dense mitochondria with complex shapes -# - Challenging instance separation requiring precise SDT - +_base_: bases/mednext.yaml experiment_name: mitoem_mednext_sdt description: MitoEM 3D mitochondria instance segmentation with MedNeXt using SDT only - -# System system: training: num_gpus: 4 
num_cpus: 8 - num_workers: 8 # Parallel data loading - batch_size: 16 # Larger batch for single-channel output (vs multi-task) + num_workers: 8 + batch_size: 16 inference: - num_gpus: 1 num_cpus: 1 num_workers: 1 batch_size: 1 seed: 42 - -# Model - MedNeXt for SDT-based mitochondria instance segmentation model: - architecture: mednext # MedNeXt (SOTA for instance segmentation) - - # Input/output configuration - input_size: [16, 256, 256] # Anisotropic patches matching data resolution - output_size: [16, 256, 256] - in_channels: 1 # Grayscale EM - out_channels: 1 # Single channel: SDT only - - # MedNeXt architecture (optimized for instance segmentation) - mednext_size: M # M (17.6M params) - balanced capacity for 3D SDT - mednext_kernel_size: 7 # 7x7x7 kernels for better context (RECOMMENDED) - mednext_dim: "3d" # 3D convolutions - deep_supervision: true # CRITICAL for MedNeXt (5-scale deep supervision) - - # Single-task loss configuration (SDT only) - loss_functions: [WeightedMSELoss] - loss_weights: [1.0] # Single loss for SDT prediction + input_size: + - 16 + - 256 + - 256 + output_size: + - 16 + - 256 + - 256 + out_channels: 1 + mednext_size: M + mednext_kernel_size: 7 + mednext_dim: 3d + deep_supervision: true + loss_functions: + - WeightedMSELoss + loss_weights: + - 1.0 loss_kwargs: - - {tanh: true} # WeightedMSELoss with tanh activation for SDT [-1, 1] range - + - tanh: true data: - # Dataset configuration - MitoEM training data - train_path: "/projects/weilab/dataset/mito/mitoEM/" - train_image: ["EM30-H/im_train_val.h5"] # MitoEM training volume - train_label: ["EM30-H/mito_train_val.h5"] # Instance labels - train_resolution: [30, 8, 8] # Anisotropic: 30nm (z) x 8nm (xy) - - # Data loading optimization - use_preloaded_cache: true # Pre-load entire volume into RAM (faster training) - cache_rate: 1.0 # Cache 100% of data - persistent_workers: true # Keep workers alive between epochs - - # Patch configuration (anisotropic to match data resolution) - 
patch_size: [16, 256, 256] # 16z x 256xy - matches anisotropic resolution - pad_size: [4, 16, 16] # Reflection padding for context - pad_mode: reflect # Reflection padding at boundaries - iter_num_per_epoch: 2000 # Training iterations per epoch - - # Image normalization + train_path: /projects/weilab/dataset/mito/mitoEM/ + train_image: + - EM30-H/im_train_val.h5 + train_label: + - EM30-H/mito_train_val.h5 + train_resolution: + - 30 + - 8 + - 8 + use_preloaded_cache: true + cache_rate: 1.0 + persistent_workers: true + patch_size: + - 16 + - 256 + - 256 + pad_size: + - 4 + - 16 + - 16 + pad_mode: reflect + iter_num_per_epoch: 2000 image_transform: - normalize: "0-1" # Min-max normalization to [0, 1] - clip_percentile_low: 0.005 # Clip outliers (0.5%) - clip_percentile_high: 0.995 # Clip outliers (0.5%) - - # Label transformation - SDT only for instance segmentation + clip_percentile_low: 0.005 + clip_percentile_high: 0.995 label_transform: targets: - - name: instance_edt # Signed distance transform (SDT) - kwargs: - mode: "3d" # 3D EDT computation - quantize: false # Continuous values (not quantized) - normalize: true # Normalize to [-1, 1] range - - # Augmentation - crucial for generalization + - name: instance_edt + kwargs: + mode: 3d + quantize: false + normalize: true augmentation: - preset: "all" # Legacy `enabled: true` -> modern preset mode - -# Optimizer - MedNeXt recommended settings + preset: all optimization: - max_epochs: 1000 # Extended training for quality instance segmentation - gradient_clip_val: 1.0 - accumulate_grad_batches: 4 # Effective batch size = 16 * 4 = 64 - precision: "16-mixed" # FP16 mixed precision (better GPU compatibility than bf16) - + max_epochs: 1000 + accumulate_grad_batches: 4 + precision: 16-mixed optimizer: - name: AdamW - lr: 0.001 # MedNeXt recommended: 1e-3 (constant LR) - weight_decay: 1e-4 # Regularization for better generalization - betas: [0.9, 0.999] - eps: 1.0e-8 - - # Scheduler - Constant LR (MedNeXt recommendation) + 
lr: 0.001 + weight_decay: 1e-4 + eps: 1.0e-08 scheduler: - name: constant # Constant LR works best for MedNeXt (per paper) - + name: constant monitor: - # Loss monitoring and validation frequency - detect_anomaly: false logging: - # Scalar loss monitoring scalar: - loss: [train_loss_total_epoch] - loss_every_n_steps: 50 # Log every 50 steps - val_check_interval: 1.0 # Validate every epoch - benchmark: true - - # Visualization - SDT predictions + loss: + - train_loss_total_epoch + loss_every_n_steps: 50 + val_check_interval: 1.0 images: - enabled: true - max_images: 4 # Show more samples for quality check - num_slices: 8 # More slices for 3D visualization - log_every_n_epochs: 5 # Visualize every 5 epochs - channel_mode: all # Show SDT channel - selected_channels: null - - # Checkpointing + max_images: 4 + num_slices: 8 + log_every_n_epochs: 5 checkpoint: - mode: min - save_top_k: 5 # Keep top 5 checkpoints - save_last: true - save_every_n_epochs: 25 # Save checkpoint every 25 epochs + save_top_k: 5 + save_every_n_epochs: 25 dirpath: outputs/mitoem_mednext_sdt/checkpoints/ - use_timestamp: true - - # Early stopping (patient for SDT learning) early_stopping: - enabled: true monitor: train_loss_total_epoch - patience: 150 # More patience for SDT convergence - mode: min - min_delta: 1e-6 # Small delta for precise SDT - check_finite: true + patience: 150 + min_delta: 1e-6 threshold: 0.01 divergence_threshold: 100.0 - -# Test data paths (schema-compliant location) test: data: test_image: /projects/weilab/dataset/mito/mitoEM/EM30-H/im_test.h5 test_label: /projects/weilab/dataset/mito/mitoEM/EM30-H/mito_test.h5 - test_resolution: [30, 8, 8] + test_resolution: + - 30 + - 8 + - 8 output_path: outputs/mitoem_mednext_sdt/results/ - -# Inference - MONAI SlidingWindowInferer for MitoEM inference: - # MONAI SlidingWindowInferer parameters sliding_window: - window_size: [16, 256, 256] # Match training patch size - sw_batch_size: 4 # Process multiple patches per batch - overlap: 
0.5 # 50% overlap for smooth blending - blending: gaussian # Gaussian weighting for smooth blending - sigma_scale: 0.25 # Gaussian sigma scale - padding_mode: replicate # Replicate padding at volume boundaries - - # Test-Time Augmentation (TTA) for SDT + window_size: + - 16 + - 256 + - 256 + sw_batch_size: 4 + overlap: 0.5 + blending: gaussian + sigma_scale: 0.25 + padding_mode: replicate test_time_augmentation: enabled: true - flip_axes: null # Use all 8 flip augmentations (2^3 for xyz) - rotation90_axes: null # No 90-degree rotations (not isotropic) + rotation90_axes: null channel_activations: - - [0, 1, tanh] # SDT channel: tanh activation [-1, 1] + - - 0 + - 1 + - tanh select_channel: all - ensemble_mode: mean # Average predictions across augmentations - apply_mask: false # No mask for MitoEM - - # Save intermediate predictions + apply_mask: false save_prediction: - enabled: true # Save SDT predictions before decoding - intensity_scale: -1 # Keep original scale (no rescaling) - intensity_dtype: float32 # Keep as float32 for SDT - - # Decoding configuration (SDT → instances via watershed) + enabled: true + intensity_scale: -1 + intensity_dtype: float32 decoding: - - name: decode_distance_watershed # Watershed on SDT only - kwargs: - distance_threshold: [0.5, 0] # Threshold range for seeds (positive SDT values) - min_instance_size: 100 # Minimum mitochondria size (voxels) - min_seed_size: 50 # Minimum seed size (voxels) - prediction_scale: 1 # No scaling - - # Evaluation + - name: decode_distance_watershed + kwargs: + distance_threshold: + - 0.5 + - 0 + min_instance_size: 100 + min_seed_size: 50 + prediction_scale: 1 evaluation: enabled: true - metrics: [adapted_rand, voi] # Adapted Rand Score + VOI for instance segmentation + metrics: + - adapted_rand + - voi diff --git a/tutorials/mito_mitolab.yaml b/tutorials/mito_mitolab.yaml index 65c773b2..ed18ce59 100644 --- a/tutorials/mito_mitolab.yaml +++ b/tutorials/mito_mitolab.yaml @@ -1,207 +1,145 @@ -# CEM-MitoLab 
mitochondria segmentation with 2D MedNeXt -# Large-scale EM dataset with pre-tiled images -# -# This config uses MedNeXt 2D (ConvNeXt-based) for the MitoLab dataset, -# which contains 21,871 pre-tiled image/mask pairs from various EM sources. -# Uses filename-based loading instead of volume cropping. - +_base_: bases/mednext.yaml experiment_name: cem-mitolab_mednext2d description: Mitochondria segmentation on CEM-MitoLab dataset using MedNeXt 2D - -# System system: training: num_gpus: 1 num_cpus: 1 - num_workers: 0 # DEBUG: Disable parallel data loading to debug - batch_size: 16 # Larger batch for 2D (21,871 samples total) + num_workers: 0 + batch_size: 16 inference: - num_gpus: 1 num_cpus: 4 num_workers: 4 batch_size: 32 seed: 42 - -# Model - MedNeXt 2D (ConvNeXt-based architecture) model: - architecture: mednext - in_channels: 1 - out_channels: 3 - - # MedNeXt 2D architecture configuration - mednext_size: S # S (Small): ~5.6M params - good for 2D - mednext_kernel_size: 3 # Start with 3x3 kernels (recommended) - mednext_dim: "2d" # 2D convolutions for 2D images - deep_supervision: true # STRONGLY RECOMMENDED for MedNeXt - - # Loss configuration for multi-task learning (3 output channels) - # Task 1 (channel 0): Binary mask - use DiceLoss + BCE - # Task 2 (channel 1): Boundary mask - use BCE - # Task 3 (channel 2): EDT - use MSELoss - loss_functions: [DiceLoss, BCEWithLogitsLoss, WeightedMSELoss] - loss_weights: [1.0, 0.5, 1.0] + out_channels: 3 + mednext_size: S + mednext_kernel_size: 3 + mednext_dim: 2d + deep_supervision: true + loss_functions: + - DiceLoss + - BCEWithLogitsLoss + - WeightedMSELoss + loss_weights: + - 1.0 + - 0.5 + - 1.0 loss_kwargs: - - {include_background: false, sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5} # DiceLoss for binary - - {} # BCEWithLogitsLoss for binary & boundary - - {tanh: true} # MSELoss for EDT - - # Multi-task configuration: [start_ch, end_ch, "task_name", [loss_indices]] - # Each task specifies which output channels and 
which losses to use + - include_background: false + sigmoid: true + smooth_nr: 1e-5 + smooth_dr: 1e-5 + - {} + - tanh: true multi_task_config: - - [0, 1, "binary", [0, 1]] # Channel 0: binary mask -> DiceLoss + BCE - - [1, 2, "boundary", [1]] # Channel 1: boundary mask -> BCE only - - [2, 3, "sdt", [2]] # Distance channel: MSE (with tanh activation) -# Data - Using JSON filename-based dataset + - - 0 + - 1 + - binary + - - 0 + - 1 + - - 1 + - 2 + - boundary + - - 1 + - - 2 + - 3 + - sdt + - - 2 data: - # Use filename-based dataset with JSON file lists - dataset_type: filename # Use MonaiFilenameDataset - - # JSON file with image/mask file lists + dataset_type: filename train_json: datasets/cem-mitolab/files.json - train_image_key: images # Key for image filenames in JSON - train_label_key: masks # Key for mask filenames in JSON - train_val_split: 0.9 # 90% train, 10% validation - - # Voxel resolution (physical dimensions in nm: y, x for 2D) - train_resolution: [0, 5, 5] # EM data: typically 5nm isotropic - - # Patch configuration (2D images are already tiles, no cropping needed) - # MitoLab images vary in size but are typically 224x224 or similar - # For 2D filename datasets: patch_size should be [H, W] not [Z, H, W] - patch_size: [224, 224] # 2D patch size (H, W) - pad_size: [0, 0] # No padding needed (images are pre-tiled) - iter_num_per_epoch: -1 # Use all samples per epoch - use_preloaded_cache: false # Images loaded on-demand - - # Image normalization + train_image_key: images + train_label_key: masks + train_val_split: 0.9 + train_resolution: + - 0 + - 5 + - 5 + patch_size: + - 224 + - 224 + pad_size: + - 0 + - 0 + iter_num_per_epoch: -1 + use_preloaded_cache: false image_transform: - normalize: "0-1" # Min-max normalization to [0, 1] clip_percentile_low: 0.0 clip_percentile_high: 1.0 label_transform: targets: - - name: binary - - name: instance_boundary - kwargs: - thickness: 1 - edge_mode: "seg-all" - mode: "2d" - - name: instance_edt - kwargs: - mode: "2d" 
- quantize: false - - # Augmentation + - name: binary + - name: instance_boundary + kwargs: + thickness: 1 + edge_mode: seg-all + mode: 2d + - name: instance_edt + kwargs: + mode: 2d + quantize: false augmentation: - preset: "all" # Legacy `enabled: true` -> modern preset mode - -# Optimizer - AdamW (MedNeXt paper recommendation) + preset: all optimization: - max_epochs: 100 # Fewer epochs needed with 21k samples - gradient_clip_val: 1.0 + max_epochs: 100 accumulate_grad_batches: 1 - precision: "bf16-mixed" # BFloat16 mixed precision - + precision: bf16-mixed optimizer: - name: AdamW - lr: 0.001 # MedNeXt paper: lr=1e-3 - weight_decay: 0.01 # L2 regularization - betas: [0.9, 0.999] - eps: 1.0e-8 - - # Scheduler - MedNeXt uses CONSTANT LR (no scheduler) - # For best performance with MedNeXt, use constant learning rate - # Using StepLR with gamma=1.0 keeps LR constant + lr: 0.001 + weight_decay: 0.01 + eps: 1.0e-08 scheduler: - name: StepLR # StepLR with no decay - step_size: 100000 # Very large step (never triggered) - gamma: 1.0 # No LR decay (constant LR) - + name: StepLR + step_size: 100000 + gamma: 1.0 monitor: - # Loss monitoring and validation frequency - detect_anomaly: false logging: - # scalar loss scalar: - loss: [train_loss_total_epoch] + loss: + - train_loss_total_epoch loss_every_n_steps: 50 - benchmark: true - - # visualization images: - enabled: true - max_images: 16 # More samples for diverse dataset - num_slices: 1 # 2D images (single slice) + max_images: 16 + num_slices: 1 log_every_n_epochs: 1 - channel_mode: all # Show all 3 channels for multi-task - selected_channels: null - - # Checkpointing checkpoint: monitor: val/loss - mode: min save_top_k: 3 - save_last: true save_every_n_epochs: 1 dirpath: outputs/cem-mitolab_mednext2d/checkpoints/ checkpoint_filename: epoch={epoch:03d}-val_loss={val/loss:.4f} - use_timestamp: true - - # Early stopping - early_stopping: - enabled: true + early_stopping: monitor: in_loss_total_epoch - patience: 20 # Less 
patience for large dataset - mode: min - min_delta: 1.0e-4 - check_finite: true + patience: 20 + min_delta: 0.0001 threshold: 0.05 divergence_threshold: 2.0 - -# Test data paths (schema-compliant location) test: data: - test_image: datasets/cem-mitolab/test.json # JSON file for filename-based dataset - test_label: null # Labels loaded from JSON - test_resolution: [0, 5, 5] + test_image: datasets/cem-mitolab/test.json + test_label: null + test_resolution: + - 0 + - 5 + - 5 output_path: outputs/cem-mitolab_mednext2d/results/ - -# Inference - For testing on held-out images inference: - # No sliding window needed - images are already tiles (direct inference) - - # Test-Time Augmentation (TTA) test_time_augmentation: - enabled: false # No TTA for pre-tiled images - flip_axes: null + enabled: false channel_activations: - - [0, 2, sigmoid] # Binary + boundary channels - - [2, 3, tanh] # Distance channel + - - 0 + - 2 + - sigmoid + - - 2 + - 3 + - tanh select_channel: null - ensemble_mode: mean - - # Save intermediate predictions save_prediction: enabled: true - intensity_scale: 255 # Scale to [0, 255] + intensity_scale: 255 intensity_dtype: uint8 - - # Evaluation evaluation: - enabled: false # Disabled for multi-task learning (metrics need channel selection) - metrics: [] # Would need: [dice, jaccard, precision, recall] - -# Notes: -# - MitoLab dataset contains diverse EM data from multiple sources -# - Images are pre-tiled at various sizes (typically 224x224 or similar) -# - Dataset is loaded via JSON file lists using MonaiFilenameDataset -# - No volume cropping needed - images are already segmented tiles -# - Train on 90% of 21,871 images (~19,684 train, ~2,187 val) -# - 2D model is faster and works well for pre-tiled data -# -# MedNeXt-specific notes: -# - MedNeXt 2D uses ConvNeXt architecture adapted for medical imaging -# - Deep supervision is STRONGLY RECOMMENDED (enabled in this config) -# - Uses constant LR (no scheduler) as per MedNeXt paper -# - Model size S 
(~5.6M params) is efficient for 2D images -# - Can use UpKern to upgrade from 3x3 to 5x5 kernels after training + enabled: false + metrics: [] diff --git a/tutorials/neuron_snemi.yaml b/tutorials/neuron_snemi.yaml index df277853..e4a4dc7e 100644 --- a/tutorials/neuron_snemi.yaml +++ b/tutorials/neuron_snemi.yaml @@ -1,139 +1,105 @@ -# RSUNet configuration for SNEMI3D (aug3-long variant) -# -# Faithfully reproduces Kisuk Lee's award-winning SNEMI3D model: -# - Architecture: Residual Symmetric 3D U-Net -# - Parameters: ~1.5M (original Caffe model) -# - Training: 500K-700K iterations (~5 days on 1x Titan X Pascal) -# - Key features: Anisotropic convolutions, 2D/3D hybrid, long-range affinities -# -# Reference: Lee et al. (2017) "Superhuman Accuracy on SNEMI3D" -# Original framework: Caffe -# This config: PyTorch Lightning + RSUNet - +_base_: bases/rsunet.yaml experiment_name: rsunet_snemi_aug3_long description: SNEMI3D neuron segmentation with RSUNet (long-range affinities, EM augmentations) - -# System system: training: num_gpus: 4 num_cpus: 8 num_workers: 8 - batch_size: 8 # Original: 1 patch per batch + batch_size: 8 inference: - num_gpus: 1 num_cpus: 1 num_workers: 4 batch_size: 1 seed: 42 - -# Model - RSUNet with SNEMI3D specifications model: - architecture: rsunet - in_channels: 1 - out_channels: 12 # 12 affinity maps (8 long-range + 4 z-direction) - - # Feature widths: 36 → 48 → 64 → 80 per scale - filters: [36, 48, 64, 80] - - # RSUNet-specific parameters - rsunet_norm: batch # Batch normalization (as in original) - rsunet_activation: elu # ELU activation (as in original) + out_channels: 12 + filters: + - 36 + - 48 + - 64 + - 80 + rsunet_norm: batch + rsunet_activation: elu rsunet_num_groups: 8 - - # Anisotropic downsampling - NO downsampling in Z - # Original: max-pool (2×2×1) and transposed conv (2×2×1) rsunet_down_factors: - - [1, 2, 2] # Level 1: preserve Z resolution - - [1, 2, 2] # Level 2 - - [1, 2, 2] # Level 3 - - # 2D/3D hybrid convolutions - # 
Original: 3×3×1 at finest scale, then 3×3×3 at coarser scales - rsunet_depth_2d: 1 # First layer uses 2D convolutions - rsunet_kernel_2d: [1, 3, 3] # 2D kernel: 3×3×1 - - # Deep supervision disabled (original doesn't use it) + - - 1 + - 2 + - 2 + - - 1 + - 2 + - 2 + - - 1 + - 2 + - 2 + rsunet_depth_2d: 1 + rsunet_kernel_2d: + - 1 + - 3 + - 3 deep_supervision: false - - # Loss configuration - # Original: Binomial cross-entropy with class rebalancing - loss_functions: [WeightedBCEWithLogitsLoss] - loss_weights: [1.0] + loss_functions: + - WeightedBCEWithLogitsLoss + loss_weights: + - 1.0 loss_kwargs: - - pos_weight: 10.0 # Class rebalancing (adjust based on dataset statistics) - - # Note: Long-range affinity maps (8 of 12 outputs) used only during training - # Inference uses only nearest-neighbor affinities (4 outputs) - -# Data + - pos_weight: 10.0 data: - # SNEMI3D dataset paths train_image: datasets/SNEMI/train-input.tif train_label: datasets/SNEMI/train-labels.tif - - # Training patch size: 16 × 160 × 160 voxels (Z, Y, X) - anisotropic - # Note: Changed Z from 18 to 16 (cleanly divisible by 2^n) to avoid rounding issues - patch_size: [16, 160, 160] - pad_size: [0, 0, 0] # No padding (same convolutions) - iter_num_per_epoch: 1280 # 700K iterations (upper bound) + patch_size: + - 16 + - 160 + - 160 + pad_size: + - 0 + - 0 + - 0 + iter_num_per_epoch: 1280 use_preloaded_cache: true - - # Image normalization image_transform: - normalize: "0-1" # Min-max normalization clip_percentile_low: 0.0 clip_percentile_high: 1.0 - - # Multi-channel label transformation - # Transforms instance segmentation labels into multiple output channels label_transform: - erosion: 1 # Border erosion (seg_widen_border with 3×3×1 kernel) - # Marks boundary voxels as background (Kisuk Lee's preprocessing) - - # Affinity map generation (for long-range connectivity) - # Offsets in (z, y, x) format - SNEMI3D specification + erosion: 1 targets: - - name: affinity - kwargs: - offsets: - # 
Short-range affinities - - "0-0-1" # x-direction, distance 1 - - "0-1-0" # y-direction, distance 1 - - "1-0-0" # z-direction, distance 1 - - "2-0-0" # z-direction, distance 2 - - "3-0-0" # z-direction, distance 3 - - "4-0-0" # z-direction, distance 4 - # Long-range affinities (exponential spacing: 3, 9, 27) - - "0-0-3" # x-direction, distance 3 - - "0-0-9" # x-direction, distance 9 - - "0-0-27" # x-direction, distance 27 - - "0-3-0" # y-direction, distance 3 - - "0-9-0" # y-direction, distance 9 - - "0-27-0" # y-direction, distance 27 - # Total: 12 affinity channels (6 short + 6 long) - - # Augmentation + - name: affinity + kwargs: + offsets: + - 0-0-1 + - 0-1-0 + - 1-0-0 + - 2-0-0 + - 3-0-0 + - 4-0-0 + - 0-0-3 + - 0-0-9 + - 0-0-27 + - 0-3-0 + - 0-9-0 + - 0-27-0 augmentation: - preset: "some" - - # Standard augmentations + preset: some flip: enabled: true prob: 0.5 - spatial_axis: [0, 1, 2] # Flip x/y/z - + spatial_axis: + - 0 + - 1 + - 2 rotate: enabled: true prob: 0.5 - max_angle: 90.0 # 90° rotations - + max_angle: 90.0 elastic: enabled: true prob: 0.7 - sigma_range: [4.0, 8.0] - magnitude_range: [8.0, 16.0] - + sigma_range: + - 4.0 + - 8.0 + magnitude_range: + - 8.0 + - 16.0 intensity: enabled: true gaussian_noise_prob: 0.3 @@ -141,157 +107,104 @@ data: shift_intensity_prob: 0.7 shift_intensity_offset: 0.2 contrast_prob: 0.7 - contrast_range: [0.5, 1.5] - - # EM-specific augmentations + contrast_range: + - 0.5 + - 1.5 misalignment: enabled: true prob: 0.5 - displacement: 17 # 0-17 pixels (original spec) - rotate_ratio: 0.5 # Mix of slip-type and translation-type - + displacement: 17 + rotate_ratio: 0.5 missing_section: enabled: true prob: 0.3 - num_sections: 5 # Max 5 slices - + num_sections: 5 motion_blur: enabled: true prob: 0.3 - sections: 2 # Number of sections to blur - kernel_size: 11 # Kernel size - + sections: 2 + kernel_size: 11 cut_noise: enabled: false - cut_blur: enabled: false - missing_parts: enabled: false - -# Optimization -# Original: Adam with 
α=0.01, β₁=0.9, β₂=0.999, ε=0.01 optimization: - max_epochs: 500 # ~700K iterations = many epochs - gradient_clip_val: 0.0 # No gradient clipping in original + max_epochs: 500 + gradient_clip_val: 0.0 accumulate_grad_batches: 1 - precision: "bf16-mixed" # BFloat16 mixed precision - - # Performance + precision: bf16-mixed deterministic: false benchmark: true - optimizer: - name: Adam # Adam (not AdamW) - lr: 0.01 # α = 0.01 (original) - betas: [0.9, 0.999] # β₁=0.9, β₂=0.999 - eps: 0.01 # ε = 0.01 (original, unusually high) - weight_decay: 0.0 # No weight decay in original - - # Learning rate scheduler - # Original: halve α when validation loss plateaus, ≤4 times + name: Adam + lr: 0.01 + eps: 0.01 + weight_decay: 0.0 scheduler: name: ReduceLROnPlateau mode: min - factor: 0.5 # Halve learning rate - patience: 10000 # Check every N iterations - min_lr: 0.0000625 # 0.01 / (2^4) = min LR after 4 halvings + factor: 0.5 + patience: 10000 + min_lr: 6.25e-05 threshold: 0.0001 cooldown: 0 - monitor: train_loss_total_epoch # Monitor validation loss - - -# Monitoring + monitor: train_loss_total_epoch monitor: - detect_anomaly: false - logging: scalar: - loss: [train_loss_total_epoch] + loss: + - train_loss_total_epoch loss_every_n_steps: 10 val_check_interval: 1.0 - benchmark: true - images: - enabled: true - max_images: 8 # Increased for multi-channel visualization + max_images: 8 num_slices: 8 log_every_n_epochs: 1 - channel_mode: all # 'argmax', 'all', or 'selected' - selected_channels: null # Only used when channel_mode='selected' - + channel_mode: all checkpoint: monitor: val/loss - mode: min save_top_k: 3 - save_last: true - save_every_n_epochs: 10 # Save frequently (long training) + save_every_n_epochs: 10 dirpath: outputs/rsunet_snemi_aug3_long/checkpoints/ checkpoint_filename: epoch={epoch:03d}-step={step:07d}-val_loss={val/loss:.4f} - use_timestamp: true - early_stopping: - enabled: false # Train for full duration - -# Test data paths (schema-compliant location) + 
enabled: false test: data: test_image: datasets/SNEMI/train-input.tif test_label: datasets/SNEMI/train-labels.tif - test_resolution: [30, 6, 6] + test_resolution: + - 30 + - 6 + - 6 output_path: outputs/rsunet_snemi_aug3_long/results/ - -# Inference inference: - # MONAI SlidingWindowInferer parameters sliding_window: - window_size: [16, 180, 180] # Patch size extracted from volume - sw_batch_size: 4 # Number of patches processed simultaneously - overlap: 0.5 # 50% overlap between patches - blending: gaussian # Gaussian weighting for smooth blending - sigma_scale: 0.25 # Gaussian sigma scale - padding_mode: reflect # Reflection-padding at volume boundaries - - # Test-Time Augmentation (TTA) + window_size: + - 16 + - 180 + - 180 + sw_batch_size: 4 + overlap: 0.5 + blending: gaussian + sigma_scale: 0.25 + padding_mode: reflect test_time_augmentation: enabled: true - flip_axes: null # Flip strategy: "all" (8 flips), null (no aug), or custom list - channel_activations: [[0, 12, softmax]] - select_channel: [1] # Channel selection: [1] (foreground only), null (all) - ensemble_mode: mean # Ensemble strategy: 'mean', 'min', 'max' - - # Save intermediate predictions + flip_axes: null + channel_activations: + - - 0 + - 12 + - softmax + select_channel: + - 1 save_prediction: enabled: true intensity_scale: 255 intensity_dtype: uint8 - - # Evaluation evaluation: - enabled: false # Use eval mode for BatchNorm - metrics: [adapted_rand] # Metrics to compute - - -# Notes on reproducing the original model: -# -# 1. Long-range affinities: -# - Output 12 affinity maps: -# * x/y directions: distances 1, 3, 9, 27 voxels (4 maps each = 8 total) -# * z direction: distances 1, 2, 3, 4 voxels (4 maps) -# - Long-range maps (8 of 12) used ONLY during training -# - Inference uses only nearest-neighbor (4 maps) -# -# 2. 
Class rebalancing: -# - Original uses weighted BCE to handle class imbalance -# - PyTC: use pos_weight parameter in BCEWithLogitsLoss -# - Calculate from training data: pos_weight = neg_samples / pos_samples -# -# 3. Initialization: -# - Original: He et al. (2015) method -# - PyTC RSUNet: already uses Kaiming (He) initialization -# -# 4. Training duration: -# - Original: ~5 days on 1x Titan X Pascal -# - Modern GPUs (V100/A100): ~2-3 days expected -# - 500K-700K iterations with batch_size=1 + enabled: false + metrics: + - adapted_rand diff --git a/tutorials/nuc_nucmm-z.yaml b/tutorials/nuc_nucmm-z.yaml index 387deebb..2aedacac 100644 --- a/tutorials/nuc_nucmm-z.yaml +++ b/tutorials/nuc_nucmm-z.yaml @@ -1,201 +1,141 @@ -# NucMM Zebrafish nucleus segmentation with MONAI Residual UNet -# Multi-task learning: Binary + Contour + Distance -# -# This config uses MONAI's UNet with residual units for nucleus segmentation -# with multi-task learning to predict: -# - Channel 0: Binary masks (sigmoid activation) -# - Channel 1: Contour maps (sigmoid activation) -# - Channel 2: Distance transforms (tanh activation) -# -# Based on NucMM-Zebrafish-UNet-BCD.yaml configuration from pytorch_connectomics_v1. 
-# The multi-task setup uses different loss functions for each channel: -# - Binary & Contour: DiceLoss + BCEWithLogitsLoss -# - Distance: WeightedMSE - +_base_: bases/monai_unet.yaml experiment_name: nucmm_zebrafish_monai_unet description: Zebrafish nucleus segmentation on NucMM dataset using MONAI Residual UNet with multi-task learning - -# System system: training: - num_gpus: 4 - num_cpus: 8 - num_workers: 8 # More workers for parallel data loading (6x speedup) - batch_size: 32 # Larger batch for smaller patches + batch_size: 32 inference: - num_gpus: 1 num_cpus: 4 num_workers: 4 batch_size: 4 - seed: 42 - -# Model - MONAI UNet with residual units for multi-task learning model: - architecture: monai_unet - input_size: [64, 64, 64] # 64x64x64 input patches - output_size: [64, 64, 64] # 64x64x64 output patches - in_channels: 1 - out_channels: 3 # 3 channels: binary, contour, distance - - # UNet architecture configuration (optimized for 64^3) - filters: [32, 64, 128, 256] # 4 levels for 64^3 (64->32->16->8->4) - strides: [2, 2, 2] # 3 downsampling levels - num_res_units: 2 # Residual units per block - kernel_size: 3 # Convolution kernel size + input_size: + - 64 + - 64 + - 64 + output_size: + - 64 + - 64 + - 64 + out_channels: 3 + strides: + - 2 + - 2 + - 2 + num_res_units: 2 + kernel_size: 3 norm: batch - dropout: 0.0 # No dropout for nucleus segmentation - - # Multi-task loss configuration - loss_functions: [DiceLoss, BCEWithLogitsLoss, WeightedMSELoss] - loss_weights: [1.0, 0.5, 2.0] # Binary: Dice+BCE, Contour: Dice+BCE, Distance: MSE + dropout: 0.0 + loss_functions: + - DiceLoss + - BCEWithLogitsLoss + - WeightedMSELoss + loss_weights: + - 1.0 + - 0.5 + - 2.0 loss_kwargs: - - {sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5} # DiceLoss for binary - - {} # BCEWithLogitsLoss for binary - - {tanh: true} # WeightedMSELoss for distance (with tanh activation) - - # Multi-task configuration - # Format: [[start_ch, end_ch, target_name, loss_indices], ...] 
+ - sigmoid: true + smooth_nr: 1e-5 + smooth_dr: 1e-5 + - {} + - tanh: true multi_task_config: - - [0, 1, "label", [0, 1]] # Original labels: Dice + BCE - - [1, 2, "boundary", [0, 1]] # Boundary channel: Dice + BCE - - [2, 3, "edt", [2]] # Distance channel: MSE - -# Data - NucMM Zebrafish dataset (multiple volumes) + - - 0 + - 1 + - label + - - 0 + - 1 + - - 1 + - 2 + - boundary + - - 0 + - 1 + - - 2 + - 3 + - edt + - - 2 data: - # Volume configuration - supports multiple files via glob pattern or list - train_image: "datasets/NucMM-Z/Image/train/img_*.h5" # 27 training volumes - train_label: "datasets/NucMM-Z/Label/train/seg_*.h5" # 27 training labels - train_resolution: [1.0, 1.0, 1.0] # NucMM: 1.0 isotropic resolution - use_preloaded_cache: true # Pre-load all volumes into memory (MAJOR speedup) - cache_rate: 1.0 # Cache all volumes in RAM - persistent_workers: true # Keep workers alive between epochs (avoid restart overhead) - - # Patch configuration - patch_size: [64, 64, 64] # 64x64x64 training patches - pad_size: [16, 16, 16] # Reflection padding for context - pad_mode: reflect # Reflection padding at boundaries - iter_num_per_epoch: 1280 # More iterations for smaller patches - - # Image normalization - image_transform: - normalize: "0-1" # Min-max normalization to [0, 1] - clip_percentile_low: 0.0 # No clipping - clip_percentile_high: 1.0 - - # Label transformation for multi-task learning + train_image: datasets/NucMM-Z/Image/train/img_*.h5 + train_label: datasets/NucMM-Z/Label/train/seg_*.h5 + train_resolution: + - 1.0 + - 1.0 + - 1.0 + cache_rate: 1.0 + persistent_workers: true + patch_size: + - 64 + - 64 + - 64 + pad_size: + - 16 + - 16 + - 16 + pad_mode: reflect label_transform: targets: - - name: binary # Channel 0: foreground mask - - name: instance_boundary # Channel 1: contour map - kwargs: - thickness: 5 - edge_mode: "seg-no-bg" - mode: "3d" - - name: instance_edt # Channel 2: distance transform (bbox-optimized) - kwargs: - mode: "3d" # 2D EDT with 
per-instance bounding box optimization - quantize: false - - # Augmentation + - name: binary + - name: instance_boundary + kwargs: + thickness: 5 + edge_mode: seg-no-bg + mode: 3d + - name: instance_edt + kwargs: + mode: 3d + quantize: false augmentation: - preset: "all" # Legacy `enabled: true` -> modern preset mode - -# Optimizer - AdamW with NucMM-specific hyperparameters + preset: all optimization: - max_epochs: 1000 gradient_clip_val: 1.0 - accumulate_grad_batches: 1 - precision: "bf16-mixed" # BFloat16 mixed precision - optimizer: name: AdamW - lr: 0.001 # Lower LR for smaller patches (64^3) weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1.0e-8 - - # Scheduler - Cosine annealing with warmup scheduler: name: CosineAnnealingLR - warmup_epochs: 5 # Shorter warmup for 64^3 - warmup_start_lr: 1.0e-5 - min_lr: 1.0e-6 + warmup_epochs: 5 + warmup_start_lr: 1.0e-05 t_max: 995 - monitor: - # Loss monitoring and validation frequency - detect_anomaly: false logging: - # scalar loss - scalar: - loss: [train_loss_total_epoch] - loss_every_n_steps: 10 - val_check_interval: 1.0 - benchmark: true - - # visualization images: - enabled: true max_images: 2 num_slices: 4 - log_every_n_epochs: 1 - channel_mode: all # Show all 3 channels for multi-task - selected_channels: null - - # Checkpointing - checkpoint: - mode: min - save_top_k: 1 - save_last: true - save_every_n_epochs: 10 - use_timestamp: true - - # Early stopping - early_stopping: - enabled: true - monitor: train_loss_total_epoch + channel_mode: all + early_stopping: patience: 300 - mode: min - min_delta: 1.0e-5 - check_finite: true + min_delta: 1.0e-05 threshold: 0.01 divergence_threshold: 100.0 - -# Test data paths (schema-compliant location) test: data: test_image: datasets/NucMM-Z/Image/val/img_0000_0640_0896.h5 test_label: datasets/NucMM-Z/Label/val/seg_0000_0640_0896.h5 - test_resolution: [1.0, 1.0, 1.0] - -# Inference - MONAI SlidingWindowInferer for NucMM + test_resolution: + - 1.0 + - 1.0 + - 1.0 inference: - # 
MONAI SlidingWindowInferer parameters sliding_window: - window_size: [64, 64, 64] # Match training patch size - sw_batch_size: 8 # More patches for smaller window - overlap: 0.5 # 50% overlap between patches - blending: gaussian # Gaussian weighting for smooth blending - sigma_scale: 0.25 - padding_mode: reflect # Reflection-padding at volume boundaries - - # Test-Time Augmentation (TTA) + window_size: + - 64 + - 64 + - 64 + sw_batch_size: 8 + overlap: 0.5 + padding_mode: reflect test_time_augmentation: - enabled: true - flip_axes: null # No augmentation for NucMM + flip_axes: null channel_activations: - - [0, 2, sigmoid] # Binary + boundary channels - - [2, 3, tanh] # Distance channel - select_channel: null # Use all channels - ensemble_mode: mean - - # Save intermediate predictions - save_prediction: - enabled: true - intensity_scale: 255 - intensity_dtype: uint8 - - # Evaluation + - - 0 + - 2 + - sigmoid + - - 2 + - 3 + - tanh evaluation: enabled: true - metrics: [jaccard, dice] # Multiple metrics for nucleus segmentation + metrics: + - jaccard + - dice