Commit afdc0a7

Donglai Wei and claude committed
Fix lazy normalization; align BANIS inference with lib/banis
- `smart_normalize` now parses `divide-K` mode strings. The lazy inference path calls `smart_normalize` directly with the raw mode and was silently no-op'ing on `divide-255`, sending uint8 [0, 255] inputs to a model trained on [0, 1]. Symptom: all-near-1 sigmoid predictions, BCE ~3.4 on training-distribution crops despite a reported train loss of ~0.14.
- `_resolve_affinity_inference_crop` short-circuits for `affinity_mode=banis` so predictions keep the full input shape, matching lib/banis (the cc3d source-stored decoder bounds-checks; trailing-edge values are harmless).
- `base_banis.yaml`: `window_size=144` gives boundary affinities +16 voxels of real surrounding-volume context via MONAI's gaussian blending (cleaner than per-window `target_context` oversample + central crop). Keep `snap_to_edge` for the lazy path.
- Add `.claude/banis/inference.md` comparing pytc vs lib/banis inference; reorganize tutorial yamls (drop 40nm/v0 stubs; add base_banis_v1/v2, erosion2 variants, and base_banis_chunk).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent b20a08d commit afdc0a7

77 files changed

Lines changed: 2197 additions & 1308 deletions

File tree

File renamed without changes.

.claude/banis/inference.md

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Inference: `base_banis.yaml` vs `lib/banis`

Comparison of pytc whole-volume affinity inference (`tutorials/neuron_nisb/base_banis.yaml`) against the reference `lib/banis/inference.py` + `lib/banis/BANIS.py`.

## Match

| Aspect | Both |
| --- | --- |
| Window size | 128³ |
| Overlap | 50% (BANIS: `small_size // 2 = 64` shift) |
| Activation | `scale_sigmoid(x) = sigmoid(0.2·x)` on output channels |
| Precision | fp16 autocast |
| Input normalization | divide-255, XYZ layout, no transpose |
| Decoded channels | first 3 (short-range) via `select_channel: [0, 1, 2]` |
| Decoder | connected components, 6-connectivity, source-stored edges (`edge_offset: 0`) |
| TTA | disabled (BANIS has no flip/rotate TTA) |
| Whole-volume strategy | both load the full image, then patch |
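The shared activation and decoder rows can be illustrated with a toy decode. This is a sketch, not the lib/banis code: the foreground rule (mean short-range affinity above threshold) is a simplification of the source-stored edge decode, and scipy's `label` with a 6-connected structuring element stands in for the `cc3d` decoder both pipelines actually use.

```python
import numpy as np
from scipy.ndimage import label, generate_binary_structure

def scale_sigmoid(x: np.ndarray) -> np.ndarray:
    """The shared activation: sigmoid applied to logits scaled by 0.2."""
    return 1.0 / (1.0 + np.exp(-0.2 * x))

def decode_cc(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Toy affinity -> segmentation decode for (C, Z, Y, X) logits.

    Only the first 3 (short-range) channels are used, mirroring
    `select_channel: [0, 1, 2]`; components via 6-connectivity.
    """
    aff = scale_sigmoid(logits[:3])
    mask = aff.mean(axis=0) > threshold          # simplified foreground rule
    six_conn = generate_binary_structure(3, 1)   # 6-connectivity in 3-D
    labels, _ = label(mask, structure=six_conn)
    return labels

rng = np.random.default_rng(0)
seg = decode_cc(rng.normal(size=(6, 16, 16, 16)).astype(np.float32))
```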
## Differences

1. **Blending weight**
   - pytc: `blending: gaussian, sigma_scale=0.25` (Gaussian importance map).
   - BANIS: `distance_transform_cdt` of a zero-padded ones cube → L1 chamfer distance from the surface (zero on faces, max in center). `lib/banis/inference.py:209-210`.

2. **Boundary handling**
   - pytc: `padding_mode: replicate` — MONAI pads the *whole* volume up front so windows align.
   - BANIS: no padding. `get_offsets` always sets the final offset to `big_size - small_size`, so every window fits fully inside the volume. `lib/banis/inference.py:189-191`.

3. **Threshold**
   - pytc: fixed `threshold: 0.5`.
   - BANIS: sweeps over `eval_ranges = sigmoid(0.2 · range(-1, 12))` ≈ `[0.45, 0.55, 0.65, 0.73, 0.80, 0.85, 0.89, 0.91, ...]`, picks the val-best by NERL, and reuses it on test. `lib/banis/BANIS.py:439`, `lib/banis/BANIS.py:209-211`.

4. **Patch grid**
   - pytc: regular MONAI grid at 50% stride.
   - BANIS: base grid + 7 shifted sets (all combinations of `+small_size//2` per axis), unioned and de-duplicated. `lib/banis/inference.py:154-174`. Slightly more centers near boundaries.

5. **Stored prediction channels**
   - pytc: short-range only (`select_channel: [0, 1, 2]`).
   - BANIS: all 6 channels written to `pred_aff_*.zarr`; decoding still reads `[:3]`. `lib/banis/BANIS.py:199-200`, `lib/banis/BANIS.py:217`.
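The BANIS blending weight in item 1 can be reconstructed from its description. This is a sketch under stated assumptions, not the lib/banis source: it pads a cube of ones with zeros, takes the L1 (taxicab) chamfer distance, and crops the pad back off (whether lib/banis crops, or keeps the zero faces, is an assumption here — hence faces at weight 1 rather than 0).

```python
import numpy as np
from scipy.ndimage import distance_transform_cdt

def banis_blend_weight(small_size: int = 128) -> np.ndarray:
    """L1 chamfer-distance blending weight: small on the window faces,
    maximal at the center, in contrast to MONAI's gaussian map."""
    cube = np.pad(np.ones((small_size,) * 3, dtype=np.uint8), 1)
    dist = distance_transform_cdt(cube, metric="taxicab")
    return dist[1:-1, 1:-1, 1:-1].astype(np.float32)

# tiny window for illustration; the real pipelines use small_size=128
w = banis_blend_weight(small_size=8)
```

For an 8³ cube the weight ramps linearly from 1 on each face to 4 at the center, so overlapping windows are averaged with center-heavy weights much like a gaussian taper.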
## To match BANIS exactly

- Replace `blending: gaussian` with a custom L1-distance window (or accept gaussian as a near-equivalent at 50% overlap).
- Drop `padding_mode: replicate` and use BANIS-style snap-to-edge offsets (last offset = `image_size - roi_size`).
- Run a decoding threshold sweep over BANIS' `eval_ranges` and pick the best by NERL on val before testing.

Items 1–2 are cosmetic at 50% overlap; #3 is the main accuracy lever.
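The threshold sweep (the main accuracy lever) can be sketched as follows. The candidate list follows the `eval_ranges` formula quoted above; `score_fn` is a hypothetical stand-in for decode-plus-NERL on the cached val prediction (NERL: higher is better), not an existing pytc or BANIS API.

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

# BANIS-style candidate thresholds: scale_sigmoid over integer logits -1..11
eval_ranges = sigmoid(0.2 * np.arange(-1, 12))

def pick_threshold(score_fn) -> float:
    """Sweep candidate thresholds on val, return the best-scoring one.

    The real pipeline would decode the cached val prediction at each
    threshold and evaluate NERL; the chosen value is reused on test.
    """
    scores = [score_fn(t) for t in eval_ranges]
    return float(eval_ranges[int(np.argmax(scores))])

# toy score peaking near 0.73; the sweep picks the closest candidate
best = pick_threshold(lambda t: -(t - 0.73) ** 2)
```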
## Boundary handling in pytc

Two paths matter:

- **Lazy sliding-window path** (`connectomics/inference/lazy.py`, used when `inference.sliding_window.lazy_load=true`). Honors `snap_to_edge: true` (last window at `image_size - roi_size`, no whole-volume padding) and per-window `target_context` (read `roi + 2·context`, predict, central-crop). `base_banis.yaml` uses this path.
- **Eager MONAI path** (`connectomics/inference/sliding.py`, MONAI's `SlidingWindowInferer`). Vanilla MONAI; it ignores `snap_to_edge` / `target_context`. For BANIS-flavored boundary context here, just bump `window_size` above the training patch — see below.
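The snap-to-edge offset logic shared by the lazy path and BANIS' `get_offsets` can be sketched in one dimension (a reconstruction from the description above, not the lib/banis source):

```python
def snap_to_edge_offsets(image_size: int, roi_size: int, stride: int) -> list[int]:
    """Window start offsets along one axis: regular stride, with the final
    offset snapped to `image_size - roi_size` so the last window stays
    fully inside the volume instead of relying on whole-volume padding."""
    offsets = list(range(0, image_size - roi_size + 1, stride))
    last = image_size - roi_size
    if offsets[-1] != last:
        offsets.append(last)
    return offsets

# 300-voxel axis, 128 window, 50% overlap: the last window is snapped to 172,
# so the model never sees padded data (the trade-off is extra overlap there)
offs = snap_to_edge_offsets(image_size=300, roi_size=128, stride=64)
```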
## The `window_size = roi + extra` hack

Instead of per-window `target_context` oversample + central crop (extra code, extra forward passes), set `window_size` larger than the training patch and rely on the default gaussian blending to de-emphasize the outer band:

```yaml
sliding_window:
  window_size: [144, 144, 144]  # 128 (training) + 16 context per axis
  blending: gaussian
  sigma_scale: 0.25
  overlap: 0.5
```

- Interior windows naturally pick up real surrounding-volume voxels in the +16 band.
- The default gaussian (`sigma_scale=0.125–0.25`) gives the outer band ~5× less weight than the central edge — a soft taper, no hard mask, no boundary coverage hole.
- `window_size` must be a multiple of the model's downsample stride. MedNeXt-S has 4 stages → 144 ✓ (128 + 16); 138 (the BANIS training oversample) ✗.
- Expect ~2× per-patch GPU memory at 144 vs 128. Verify it fits with fp16.

This works for both the lazy and eager paths and replaces the need for an inference-time `target_context` config.
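The taper and the divisibility constraint above can be checked numerically. This sketch builds a 1-D slice of a MONAI-style gaussian importance map directly (a gaussian centered on the window with `sigma = sigma_scale * size`); MONAI constructs the real map by gaussian-filtering a center impulse, which is equivalent up to normalization.

```python
import numpy as np

def gaussian_importance_1d(size: int, sigma_scale: float = 0.25) -> np.ndarray:
    """1-D gaussian importance profile across a sliding window."""
    center = (size - 1) / 2.0
    sigma = sigma_scale * size
    x = np.arange(size, dtype=np.float64)
    return np.exp(-0.5 * ((x - center) / sigma) ** 2)

# 144 = 128 training patch + 16 context; the outer band is the first/last
# 8 voxels, which get less blending weight than the 128-patch edge (index 8)
w = gaussian_importance_1d(144)

# divisibility: 4 downsampling stages -> stride 2**4 = 16,
# so 144 is a valid window size and 138 is not
stride = 2 ** 4
```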
## What's still BANIS-specific

`snap_to_edge: true` (in the yaml) only affects the lazy path and is the BANIS-faithful behavior — the model never sees padded volume data. The eager path uses MONAI's whole-volume padding, which is functionally close at 50% overlap.

.claude/refactor/config.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -77,7 +77,7 @@ Profiles are named YAML snippets in `tutorials/bases/*.yaml`, resolved pre-conve
 | `optimizer_profiles` | `{stage}.optimization.profile` | `{stage}.optimization` |
 | `loss_profiles` | `{stage}.model.loss.profile` | `{stage}.model.loss.losses` |
 | `label_profiles` | `{stage}.data.label_transform.profile` | `{stage}.data.label_transform` |
-| `decoding_profiles` | `{stage}.inference.decoding_profile` | `{stage}.inference.decoding` |
+| `decoding_templates` | list refs under `{stage}.decoding` | `{stage}.decoding` |
 | `activation_profiles` | `{stage}.inference.test_time_augmentation.activation_profile` | `{stage}.inference.test_time_augmentation.channel_activations` |

 Selectors are only accepted at canonical paths; non-canonical paths raise `ValueError`.
```
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Inference / Decoding Split Refactor

## Problem

The current test path mixes three responsibilities:

- deep learning inference,
- prediction artifact storage,
- decoding/postprocessing/evaluation.

This is most visible in chunked inference: `run_chunked_affinity_cc_inference` predicts a chunk, immediately decodes it, stitches labels, and writes the final segmentation. That is memory efficient, but it cannot reproduce whole-volume decoding exactly, because connected components are solved per chunk and then stitched heuristically.
## Target Design

Treat model inference and decoding as separate stages.

1. Model inference writes a raw prediction artifact.
   The artifact is file-backed, chunked, and has a stable layout: `(C, Z, Y, X)` for one volume after inference-time crop/channel selection.

2. Decoding consumes a raw prediction artifact and writes a segmentation artifact.
   It should not require model construction, checkpoint loading, or GPU setup.

3. The combined test path remains a convenience wrapper.
   It can run inference, then optionally decode the just-written artifact.

4. Evaluation should become its own top-level stage.
   It should not live under `decoding`, because metrics consume decoded artifacts and labels regardless of which decoder or cache produced them. The config tree now has a dedicated top-level/default/test/tune `evaluation` section.
## Config Contract

`inference.decode_after_inference`

- `true`: current convenience behavior; decode after prediction.
- `false`: stop after writing raw predictions.

`inference.chunking.output_mode`

- `decoded`: current streaming chunk decode/stitch behavior.
- `raw_prediction`: stream chunked model predictions into one raw prediction HDF5, then optionally run the normal whole-volume decoding path.

The existing decode-only mode remains:

```yaml
inference:
  saved_prediction_path: /path/to/raw_prediction.h5

decoding:
  - name: decode_affinity_cc
    kwargs:
      threshold: 0.7
      backend: numba
      edge_offset: 0
```
## Implementation Plan

1. Add schema fields for `decode_after_inference` and chunked `output_mode`.
2. Split the chunked inference code into two entry points: `run_chunked_prediction_inference` for raw prediction writing, and `run_chunked_affinity_cc_inference` for the existing streamed decode/stitch.
3. Route `test_pipeline` chunked mode based on `chunking.output_mode`.
4. For `raw_prediction`, write the raw file first. If `decode_after_inference=true`, load that file and reuse the standard decode/postprocess/save/evaluate path.
5. Keep decode-only via `inference.saved_prediction_path` as the standalone decoding entry for now. A future CLI can expose it as `--mode decode`.
## Implemented

- `decoding` is a top-level/default/test/tune stage section.
- `evaluation` is a top-level/default/test/tune stage section.
- Tutorial YAMLs use `default.decoding`/`test.decoding` and `default.evaluation`/`test.evaluation` instead of nested inference sections.

## Follow-Ups

- Store prediction artifact metadata such as channel order, crop, activation, checkpoint, and value scale in a small sidecar or HDF5 attrs.
- Add lazy/blockwise decode readers for decoders that can operate without materializing the full prediction volume.
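The metadata follow-up can be sketched with the "small sidecar" option (the HDF5-attrs variant would store the same keys as dataset attrs). All key names and file names here are illustrative, not an existing pytc schema:

```python
import json
import tempfile
from pathlib import Path

import numpy as np

out_dir = Path(tempfile.mkdtemp())
# raw prediction artifact in the stable (C, Z, Y, X) layout
pred = np.random.rand(3, 64, 64, 64).astype(np.float16)

np.save(out_dir / "raw_prediction.npy", pred)
(out_dir / "raw_prediction.json").write_text(json.dumps({
    "layout": "CZYX",
    "channels": [0, 1, 2],          # select_channel already applied
    "activation": "scale_sigmoid",  # record whether it was applied
    "value_scale": [0.0, 1.0],
    "checkpoint": "best.ckpt",
}))

# a decode-only stage can reopen the artifact with no model/GPU setup
meta = json.loads((out_dir / "raw_prediction.json").read_text())
aff = np.load(out_dir / "raw_prediction.npy")[: len(meta["channels"])]
```

Recording the activation and value scale is what lets a later decode stage tell raw logits apart from already-activated affinities without re-running the model.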

CLAUDE.md

Lines changed: 4 additions & 4 deletions
```diff
@@ -289,7 +289,7 @@ configs/ # Canonical shared YAML registries
 │ ├── label_profiles.yaml # Label-transform presets
 │ └── ... # system, dataloader, augmentation, pipeline, tune
 └── templates/ # Explicit list-item templates
-    └── decoding_templates.yaml # `inference.decoding` templates (`template: ...`)
+    └── decoding_templates.yaml # top-level `decoding` templates (`template: ...`)

 tutorials/ # Example configurations
 ├── misc/ # Miscellaneous experiments
@@ -316,15 +316,15 @@ The project uses **Hydra/OmegaConf** with dataclass-based configs for type safet
 Canonical YAML layout:

 - `connectomics/config/profiles/*.yaml`: section-level registries selected by `*.profile`
-- `connectomics/config/templates/*.yaml`: explicit list-item templates, currently for `inference.decoding`
+- `connectomics/config/templates/*.yaml`: explicit list-item templates, currently for top-level `decoding`
 - `tutorials/*.yaml`: runnable experiments that `_base_` the shared registries

 Canonical merge semantics:

 - Profile payloads are merged into the target section as the base config.
 - Explicit keys override profile keys.
 - Explicit lists replace profile lists; list overrides are not additive.
-- Canonical decoding expansion is explicit `template:` inside `inference.decoding`.
+- Canonical decoding expansion is explicit `template:` inside top-level `decoding`.
 - Do not introduce `decoding_profile` or `- profile: decoding_*` usages.

 **Config File Example** (`tutorials/lucchi.yaml`):
@@ -757,7 +757,7 @@ All previously identified technical debt items have been addressed. Below is the
 31. ~~Pass-through `create_volume_data_dicts()`~~ ✅ (removed)
 32. ~~Python 2 `__future__` imports~~ ✅ (removed)
 33. ~~`cfg.inference.*` references~~ ✅ (valid InferenceConfig in TestConfig, not legacy)
-34. ~~Legacy `test.decoding` fallback~~ ✅ (uses `inference.decoding` directly)
+34. ~~Legacy `test.decoding` fallback~~ ✅ (uses top-level `decoding` directly)
 35. ~~Unnecessary try-except for RSUNet import~~ ✅ (removed)
 36. ~~Hardcoded architecture list~~ ✅ (queries registry dynamically)
 37. ~~Duplicate `_to_plain_dict`/`_as_dict`~~ ✅ (consolidated to `config/dict_utils.py`)
```

README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -186,14 +186,14 @@ just test lucchi++ outputs/lucchi++/$EXPERIMENT_DATE/checkpoints/best.ckpt

 - `tutorials/*.yaml`: runnable experiment configs
 - `connectomics/config/profiles/*.yaml`: section-level registries selected by `*.profile`
-- `connectomics/config/templates/*.yaml`: explicit list-item templates, currently used for `inference.decoding`
+- `connectomics/config/templates/*.yaml`: explicit list-item templates, currently used for top-level `decoding`

 Merge rule:

 - Profile payloads are merged into the target section as the base config.
 - Explicit keys in the tutorial/config override profile keys.
 - Explicit lists replace profile lists; they are not additive.
-- Canonical decoding syntax is explicit list expansion, for example `inference.decoding: [{template: decoding_waterz}]`.
+- Canonical decoding syntax is explicit list expansion, for example `decoding: [{template: decoding_waterz}]`.

 ---
```

connectomics/config/pipeline/config_io.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -393,6 +393,9 @@ def validate_config(cfg: Config) -> None:
     axes = str(getattr(chunking_cfg, "axes", "all")).lower()
     if axes not in {"all", "z"}:
         raise ValueError("inference.chunking.axes must be 'all' or 'z'")
+    output_mode = str(getattr(chunking_cfg, "output_mode", "decoded")).lower()
+    if output_mode not in {"decoded", "raw_prediction"}:
+        raise ValueError("inference.chunking.output_mode must be 'decoded' or 'raw_prediction'")
     chunk_size = getattr(chunking_cfg, "chunk_size", None)
     if not chunk_size or len(chunk_size) != 3:
         raise ValueError("inference.chunking.chunk_size must be a length-3 ZYX list")
@@ -615,7 +618,7 @@ def _validate_label_channel_capacity(selector_value: Any, *, path: str) -> None:
         )

     # 2d) Decoding kwargs channel selectors (*_channels)
-    decoding_cfg = getattr(cfg.inference, "decoding", None)
+    decoding_cfg = getattr(cfg, "decoding", None)
     decode_has_channel_selection = False
     decode_output_head = None
     decode_available_channels = out_channels
@@ -660,10 +663,10 @@ def _validate_label_channel_capacity(selector_value: Any, *, path: str) -> None:
             continue
         min_channels = infer_min_required_channels(
             value,
-            context=f"inference.decoding[{i}].kwargs.{key}",
+            context=f"decoding[{i}].kwargs.{key}",
         )
         if min_channels is not None:
-            path = f"inference.decoding[{i}].kwargs.{key}"
+            path = f"decoding[{i}].kwargs.{key}"
             if model_heads and decode_has_channel_selection:
                 if min_channels > decode_available_channels:
                     raise ValueError(
```

connectomics/config/pipeline/profile_engine.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -467,6 +467,7 @@ def apply(self, yaml_conf: DictConfig) -> DictConfig:
 _STAGE_TRAIN = "train"
 _STAGE_TEST = "test"
 _STAGE_TUNE = "tune"
+_STAGE_ROOT = ""


 def _stage_path(stage: str, rel_path: str) -> str:
@@ -606,8 +607,8 @@ def _build_reference_profile_specs() -> List[Tuple[str, List[str]]]:
 _LIST_REFERENCE_FAMILIES: List[Tuple[str, Tuple[str, ...], str, str]] = [
     (
         "decoding_templates",
-        (_STAGE_DEFAULT, _STAGE_TUNE, _STAGE_TEST),
-        "inference.decoding",
+        (_STAGE_ROOT, _STAGE_DEFAULT, _STAGE_TUNE, _STAGE_TEST),
+        "decoding",
         "decoding",
     ),
 ]
```
