Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ jobs:
pip install pytest pytest-cov
pip install -e .

- name: Validate tutorial configs
run: |
python scripts/validate_tutorial_configs.py

- name: Run tests
run: |
pytest tests/ -v --cov=connectomics --cov-report=xml
Expand Down
70 changes: 70 additions & 0 deletions AGENT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# AGENT.md

This file provides instructions for Codex and similar coding agents working in this repository.

## Project

PyTorch Connectomics (PyTC) is a Hydra/OmegaConf + PyTorch Lightning + MONAI codebase for EM segmentation.
Primary entry point: `scripts/main.py`.

## Mandatory Guardrails

- Preserve user-facing behavior unless explicitly requested.
- Preserve `scripts/main.py` CLI arguments and mode semantics.
- Preserve Hydra/OmegaConf key structure and override behavior.
- Prefer small, reviewable refactors over large rewrites.
- Do not add new runtime dependencies unless explicitly approved.

## Environment

- Use conda env `pytc` for all validation.
- Do not install into `base`.
- Prefer `conda run -n pytc <command>` for deterministic execution.

## Required Verification

Run these after meaningful changes (or explain why not possible):

- `conda run -n pytc python scripts/main.py --demo`
- `conda run -n pytc pytest -q`

For targeted changes, run focused tests plus the demo smoke test.

## Lint/Type Checks

Use changed-file scope unless specifically fixing global style/type debt:

- `black --check <changed_py_files>`
- `isort --check-only <changed_py_files>`
- `flake8 --max-line-length=100 <changed_py_files>`
- `mypy --config-file .github/mypy_changed.ini <changed_py_files>`

Note: repository-wide `black --check connectomics/` is not currently clean.

## Code Change Style

- Keep diffs minimal and localized.
- Avoid changing config keys and defaults without a clear migration plan.
- Keep module boundaries explicit (config, data, training, decoding).
- Add or update tests when behavior or contracts are touched.

## Git and PR Expectations

- Commit logically (one milestone/concern per commit).
- In PR descriptions, include:
- what changed
- why
- exact validation commands run
- key results

## Repository Hotspots

- Config system: `connectomics/config/`
- Data pipeline: `connectomics/data/`
- Lightning runtime: `connectomics/training/lit/`
- Decoding/postprocess: `connectomics/decoding/`

## Safety

- Do not use destructive git commands (`reset --hard`, checkout discards) unless explicitly requested.
- If unexpected working-tree changes appear, pause and confirm intent before touching them.
10 changes: 6 additions & 4 deletions connectomics/config/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from connectomics.config.hydra_utils import load_config

# Load configuration from YAML
cfg = load_config("tutorials/monai_lucchi++.yaml")
cfg = load_config("tutorials/mito_lucchi++.yaml")

# Access configuration sections
print(f"Model architecture: {cfg.model.architecture}")
Expand All @@ -37,9 +37,10 @@
"""

from __future__ import annotations
from dataclasses import dataclass, field, is_dataclass
from typing import Dict, List, Optional, Tuple, Any, Union

import inspect
from dataclasses import dataclass, field, is_dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

# Note: MISSING can be imported from omegaconf if needed for required fields

Expand Down Expand Up @@ -946,7 +947,8 @@ class SavePredictionConfig:
enabled: Enable saving intermediate predictions (default: True)
intensity_scale: Scale factor for predictions (e.g., 255 for uint8 visualization)
intensity_dtype: Data type for saved predictions (e.g., 'uint8', 'float32')
output_formats: List of output formats to save predictions in (e.g., ['h5', 'tiff', 'nii.gz'])
output_formats: List of output formats to save predictions
(e.g., ['h5', 'tiff', 'nii.gz'])
Supported formats: 'h5', 'tiff', 'nii', 'nii.gz', 'png'
Default: ['h5', 'nii.gz'] for backward compatibility
"""
Expand Down
129 changes: 93 additions & 36 deletions connectomics/config/hydra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,74 @@
"""

from __future__ import annotations

from pathlib import Path
from typing import Optional, Union, Dict, Any, List
from omegaconf import OmegaConf, DictConfig
from typing import Any, Dict, List, Optional, Tuple, Union

from omegaconf import DictConfig, ListConfig, OmegaConf

from .hydra_config import Config


def _normalize_base_paths(base_field: Any, config_path: Path) -> List[Path]:
"""Normalize `_base_` field to an ordered list of absolute paths."""
if base_field is None:
return []

if isinstance(base_field, (str, Path)):
base_entries = [str(base_field)]
elif isinstance(base_field, (list, tuple, ListConfig)):
base_entries = [str(item) for item in base_field]
else:
raise TypeError(
f"Invalid _base_ value in {config_path}: expected string or list, "
f"got {type(base_field)}"
)

resolved_paths: List[Path] = []
for base_entry in base_entries:
base_path = Path(base_entry)
if not base_path.is_absolute():
base_path = (config_path.parent / base_path).resolve()
if not base_path.exists():
raise FileNotFoundError(
f"Base config not found: {base_entry} (resolved to {base_path}) in {config_path}"
)
resolved_paths.append(base_path)

return resolved_paths


def _load_config_with_bases(config_path: Path, loading_stack: Tuple[Path, ...] = ()) -> DictConfig:
"""Load YAML config recursively with `_base_` inheritance."""
config_path = config_path.resolve()
if config_path in loading_stack:
cycle = " -> ".join(str(p) for p in (*loading_stack, config_path))
raise ValueError(f"Detected cyclic _base_ config inheritance: {cycle}")

if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")

yaml_conf = OmegaConf.load(config_path)
if yaml_conf is None:
yaml_conf = OmegaConf.create({})
if not isinstance(yaml_conf, DictConfig):
raise TypeError(
f"Config root must be a mapping in {config_path}, got {type(yaml_conf)} instead"
)

base_field = yaml_conf.get("_base_", None)
if "_base_" in yaml_conf:
del yaml_conf["_base_"]

merged_base = OmegaConf.create({})
for base_path in _normalize_base_paths(base_field, config_path):
base_conf = _load_config_with_bases(base_path, (*loading_stack, config_path))
merged_base = OmegaConf.merge(merged_base, base_conf)

return OmegaConf.merge(merged_base, yaml_conf)


def load_config(config_path: Union[str, Path]) -> Config:
"""
Load configuration from YAML file.
Expand All @@ -22,12 +83,8 @@ def load_config(config_path: Union[str, Path]) -> Config:
Returns:
Config object with defaults merged
"""
config_path = Path(config_path)
if not config_path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")

# Load YAML
yaml_conf = OmegaConf.load(config_path)
config_path = Path(config_path).resolve()
yaml_conf = _load_config_with_bases(config_path)

# Merge with structured config defaults
default_conf = OmegaConf.structured(Config)
Expand Down Expand Up @@ -190,7 +247,7 @@ def validate_config(cfg: Config) -> None:
if cfg.optimization.max_epochs <= 0:
raise ValueError("optimization.max_epochs must be positive when max_steps is not set")
# If max_steps is set, max_epochs can be anything (will be overridden to -1 in trainer)

if cfg.optimization.gradient_clip_val < 0:
raise ValueError("optimization.gradient_clip_val must be non-negative")
if cfg.optimization.accumulate_grad_batches <= 0:
Expand Down Expand Up @@ -299,9 +356,11 @@ def _combine_path(

# Handle list of paths
if isinstance(file_path, list):
result = []
result: List[str] = []
for p in file_path:
resolved = _combine_path(base_path, p)
if resolved is None:
continue
# If resolved is a list (from glob expansion), extend
if isinstance(resolved, list):
result.extend(resolved)
Expand Down Expand Up @@ -375,42 +434,40 @@ def _combine_path(
cfg.data.train_image = _combine_path(train_base, cfg.data.train_image)
cfg.data.train_label = _combine_path(train_base, cfg.data.train_label)
cfg.data.train_mask = _combine_path(train_base, cfg.data.train_mask)
cfg.data.train_json = _combine_path(train_base, cfg.data.train_json)
train_json_resolved = _combine_path(train_base, cfg.data.train_json)
if isinstance(train_json_resolved, list):
cfg.data.train_json = train_json_resolved[0] if train_json_resolved else None
else:
cfg.data.train_json = train_json_resolved

# Resolve validation paths (always expand globs, use val_path as base if available)
val_base = cfg.data.val_path if cfg.data.val_path else ""
cfg.data.val_image = _combine_path(val_base, cfg.data.val_image)
cfg.data.val_label = _combine_path(val_base, cfg.data.val_label)
cfg.data.val_mask = _combine_path(val_base, cfg.data.val_mask)
cfg.data.val_json = _combine_path(val_base, cfg.data.val_json)
val_json_resolved = _combine_path(val_base, cfg.data.val_json)
if isinstance(val_json_resolved, list):
cfg.data.val_json = val_json_resolved[0] if val_json_resolved else None
else:
cfg.data.val_json = val_json_resolved

# Resolve test data paths (cfg.test.data.test_*)
if hasattr(cfg, "test") and hasattr(cfg.test, "data"):
test_base = (
cfg.test.data.test_path
if hasattr(cfg.test.data, "test_path") and cfg.test.data.test_path
else ""
)
if hasattr(cfg.test.data, "test_image"):
cfg.test.data.test_image = _combine_path(test_base, cfg.test.data.test_image)
if hasattr(cfg.test.data, "test_label"):
cfg.test.data.test_label = _combine_path(test_base, cfg.test.data.test_label)
if hasattr(cfg.test.data, "test_mask"):
cfg.test.data.test_mask = _combine_path(test_base, cfg.test.data.test_mask)
if cfg.test is not None:
test_data = cfg.test.data
test_path_value = getattr(test_data, "test_path", "")
test_base = test_path_value if isinstance(test_path_value, str) else ""
test_data.test_image = _combine_path(test_base, test_data.test_image)
test_data.test_label = _combine_path(test_base, test_data.test_label)
test_data.test_mask = _combine_path(test_base, test_data.test_mask)

# Resolve tuning data paths (cfg.tune.data.tune_*)
if hasattr(cfg, "tune") and hasattr(cfg.tune, "data"):
tune_base = (
cfg.tune.data.test_path
if hasattr(cfg.tune.data, "test_path") and cfg.tune.data.test_path
else ""
)
if hasattr(cfg.tune.data, "tune_image"):
cfg.tune.data.tune_image = _combine_path(tune_base, cfg.tune.data.tune_image)
if hasattr(cfg.tune.data, "tune_label"):
cfg.tune.data.tune_label = _combine_path(tune_base, cfg.tune.data.tune_label)
if hasattr(cfg.tune.data, "tune_mask"):
cfg.tune.data.tune_mask = _combine_path(tune_base, cfg.tune.data.tune_mask)
if cfg.tune is not None:
tune_data = cfg.tune.data
tune_path_value = getattr(tune_data, "test_path", "")
tune_base = tune_path_value if isinstance(tune_path_value, str) else ""
tune_data.tune_image = _combine_path(tune_base, tune_data.tune_image)
tune_data.tune_label = _combine_path(tune_base, tune_data.tune_label)
tune_data.tune_mask = _combine_path(tune_base, tune_data.tune_mask)

return cfg

Expand Down
Loading