Commit 6b5d55c

Author: Donglai Wei (committed)
Message: Add BANIS affinity reproduction and chunked inference
Parent: 30bdfc3

87 files changed: 6278 additions & 604 deletions


.claude/benchmark/SNEMI.md

Lines changed: 2 additions & 2 deletions

```diff
@@ -158,14 +158,14 @@ No per-tutorial loss config needed — the pipeline profile handles it.
 
 | | DeepEM | PyTC (before fix) | PyTC (after fix) |
 |---|---|---|---|
-| **Border handling** | Per-channel `get_pair` crop + mask | Uniform `deepem_crop` (max-offset spatial crop) | Per-channel valid mask ✅ |
+| **Border handling** | Per-channel `get_pair` crop + mask | Old uniform max-offset spatial crop | `affinity_mode=deepem` per-channel valid mask ✅ |
 | **Augment padding on labels** | Mask propagated through augmentation | `RandAffined` reflection padding on labels → false affinities | Per-channel mask excludes border artifacts ✅ |
 
 **Problem:** Two interacting issues caused border artifacts in affinity targets:
 
 1. **Reflection padding on labels during augmentation**: `RandAffined` and `RandElasticd` use `padding_mode="reflection"` for all keys including labels. When spatial transforms rotate/scale/shear a patch, border pixels are filled with reflected label values. Computing affinity from these reflected labels creates false affinities — especially visible for long-range channels like ch11 (offset `0-27-0`) where the reflected region spans 27 voxels.
 
-2. **Uniform spatial crop vs per-channel masking**: The old `deepem_crop` computed the **union** of all offsets' invalid borders and uniformly cropped all channels to this smallest valid region. For the SNEMI 12-channel offsets, this meant cropping (4, 27, 27) from all channels — even short-range channels that only need 1 voxel cropped. This wasted ~35% of training data for short-range channels.
+2. **Uniform spatial crop vs per-channel masking**: The old crop path computed the **union** of all offsets' invalid borders and uniformly cropped all channels to this smallest valid region. For the SNEMI 12-channel offsets, this meant cropping (4, 27, 27) from all channels — even short-range channels that only need 1 voxel cropped. This wasted ~35% of training data for short-range channels.
 
 **DeepEM's approach**: In DeepEM, `get_pair(arr, edge)` extracts two aligned crops per channel, computing affinity only in the overlap region. A separate mask (propagated through augmentation) excludes padded regions from the loss. Each channel has its own valid region.
```
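The reflection-padding issue described in this diff is easy to reproduce in one dimension. The sketch below is illustrative only, not the pipeline's code; plain `np.pad` stands in for `RandAffined`'s `padding_mode="reflection"`:

```python
import numpy as np

# Toy 1D segmentation labels; reflection padding mimics what a spatial
# transform's padding_mode="reflection" does at patch borders.
labels = np.array([1, 1, 2, 2, 3])
padded = np.pad(labels, (0, 2), mode="reflect")  # reflected tail repeats real labels

# Offset-1 affinity: 1 where neighboring voxels share a label.
aff = (padded[:-1] == padded[1:]).astype(np.uint8)

# Entries near the end compare reflected copies of real labels, so a
# spurious "connected" edge appears inside the padded border region.
print(aff)
```

With a 27-voxel offset, as in the ch11 example above, the contaminated border region grows accordingly.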

.claude/refactor/affinity.md (new file)

Lines changed: 68 additions & 0 deletions
# Affinity Modes

Affinity targets must declare one explicit convention:

```yaml
data:
  label_transform:
    targets:
      - name: affinity
        kwargs:
          offsets: ["0-0-1", "0-1-0", "1-0-0"]
          affinity_mode: deepem  # or banis
```

There is no legacy crop flag. The mode controls target voxel placement,
valid-border masking, visualization crop, and test-time prediction crop.

## Modes

| Mode | Edge storage | Valid side for positive offsets | Intended use |
| --- | --- | --- | --- |
| `deepem` | Destination voxel `v + offset` | Leading border invalid | DeepEM/SNEMI-style targets, zwatershed/ABISS destination-index affinity |
| `banis` | Source voxel `v` | Trailing border invalid | BANIS-compatible targets and source-index connected-component decoding |

For a positive x offset `0-0-1`, `deepem` produces valid affinities at
`x >= 1`; `banis` produces valid affinities at `x < W - 1`.
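The two conventions can be sketched as mask construction. This is a hypothetical helper, not this library's API; only the slice logic follows the table above:

```python
import numpy as np

def valid_mask(shape, offset, mode):
    """Hypothetical helper: True where an affinity edge is well-defined.

    shape: (Z, Y, X); offset: per-axis displacement, e.g. (0, 0, 1);
    mode: "deepem" (edge stored at destination) or "banis" (stored at source).
    """
    mask = np.ones(shape, dtype=bool)
    for ax, o in enumerate(offset):
        if o == 0:
            continue
        sl = [slice(None)] * len(shape)
        if mode == "deepem":
            # Destination-indexed: the leading |o| voxels have no in-volume source.
            sl[ax] = slice(0, o) if o > 0 else slice(shape[ax] + o, None)
        else:
            # banis, source-indexed: the trailing |o| voxels have no in-volume destination.
            sl[ax] = slice(shape[ax] - o, None) if o > 0 else slice(0, -o)
        mask[tuple(sl)] = False
    return mask
```

For offset `0-0-1` this reproduces the rule above: `deepem` invalidates `x == 0`, `banis` invalidates `x == W - 1`.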
## Training

Training does not crop every affinity channel to the largest common valid
interior. It keeps prediction and target shapes unchanged and applies a
per-channel valid mask before loss evaluation. This preserves short-range edge
supervision while excluding convention-dependent padded borders.

Mixed affinity modes in one stacked label tensor are rejected. If multiple
affinity target groups are ever needed, they must share the same
`affinity_mode`.
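The per-channel masking step might look roughly like this. A minimal sketch assuming a BCE-style affinity loss; the function name and signature are illustrative, not the loss orchestrator's actual API:

```python
import torch
import torch.nn.functional as F

def masked_affinity_loss(pred, target, valid):
    """Average loss over valid voxels only, per channel.

    pred, target: (C, Z, Y, X) float tensors; valid: (C, Z, Y, X) bool,
    one convention-dependent valid mask per offset channel.
    """
    per_voxel = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
    weight = valid.to(per_voxel.dtype)
    # Normalize by the valid-voxel count so short-range channels (large valid
    # regions) and long-range channels (small valid regions) contribute fairly.
    return (per_voxel * weight).sum() / weight.sum().clamp(min=1.0)
```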
## Inference And Decoding

Test-time affinity crops use the same mode as training:

```python
compute_affinity_crop_pad(offsets, affinity_mode="deepem")
compute_affinity_crop_pad(offsets, affinity_mode="banis")
```

The crop is resolved after `inference.select_channel` and output-head target
slices. If decoding keeps only short-range channels from a larger affinity
target stack, the automatic affinity crop must use only those selected offsets.

`decode_affinity_cc` has an independent `edge_offset` knob for the numba
backend:

| Target mode | `decode_affinity_cc.kwargs.edge_offset` |
| --- | --- |
| `deepem` | `1` |
| `banis` | `0` |

The `cc3d` backend ignores directed edge placement and only thresholds
foreground connectivity, so this matters mainly for `backend: numba`.
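Read together with the table, a decode step for `deepem`-mode targets could be configured like the following sketch. The `inference.decoding` placement is an assumption; only `decode_affinity_cc`, `backend`, and `edge_offset` come from this document:

```yaml
inference:
  decoding:                     # assumed location of decode steps
    - name: decode_affinity_cc
      kwargs:
        backend: numba          # edge_offset only matters for numba
        edge_offset: 1          # deepem targets store edges at the destination voxel
```

For `banis`-mode targets the same step would use `edge_offset: 0`.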
## Config Policy

Use `affinity_mode: deepem` for DeepEM/SNEMI/LiConn-style configs.

Use `affinity_mode: banis` for BANIS/NISB reproduction configs and any config
whose target should match `lib/banis/data.py::comp_affinities`.

.claude/refactor/training.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -63,7 +63,7 @@ All 12 issues across 4 priority levels have been resolved. 190/191 unit tests pa
 #### P3.2: Affinity decoupling in orchestrator.py
 - **File**: `training/loss/orchestrator.py`
 - **Issue**: Direct import of `data.process.affinity` created tight cross-package coupling
-- **Fix**: Dependency injection via constructor parameters (`affinity_crop_enabled_fn`, `crop_spatial_fn`, `resolve_affinity_offsets_fn`) with lazy-import bridge functions as defaults. No behavioral change
+- **Fix**: Dependency injection via constructor parameters (`resolve_affinity_mode_fn`, `resolve_affinity_offsets_fn`) with lazy-import bridge functions as defaults. Affinity target handling now routes through explicit `affinity_mode`.
 
 #### P3.3: Logging migration (print -> logging)
 - **Files**: All files in `training/` except `debugging.py`
```

.claude/reference/snemi_old.md (new file)

Lines changed: 65 additions & 0 deletions

# snemi_old: Segmentation Post-Processing Reference

Summary of useful functions from `lib/snemi_old/*.py` for improving segmentation.

## Key Post-Processing Strategies

### 1. Segment Classification (T_pytc_v2.py, T_snemi220416.py)
- **Border-touching**: `bb[:,1::2] == 0` or `bb[:,2] == D-1` etc. Segments touching the volume boundary are unreliable for merge analysis.
- **Interior segments**: `num_border == 0` — candidates for orphan merge.
- **Singletons**: `bb[:,1] == bb[:,2]` — single-slice segments. The SNEMI test volume had 1059/1651 singletons. High error rate.
- **Disconnected components**: `cc3d.connected_components` per segment — remove all but the largest component.
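The disconnected-component cleanup above can be sketched with `scipy.ndimage.label` standing in for `cc3d`. This is illustrative only, not the original `T_pytc_v2.py` code, and the connectivity structure may differ from what those scripts used:

```python
import numpy as np
from scipy import ndimage

def keep_largest_component(seg):
    """For each nonzero segment id, keep only its largest connected piece."""
    out = np.zeros_like(seg)
    for sid in np.unique(seg):
        if sid == 0:
            continue  # 0 is background
        cc, n = ndimage.label(seg == sid)  # connected components of this id
        if n == 0:
            continue
        sizes = np.bincount(cc.ravel())
        sizes[0] = 0  # never pick the background component
        out[cc == sizes.argmax()] = sid
    return out
```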
### 2. Orphan Detection & Merge (T_snemi220416.py opt=='0.32', T_pytc_v2.py opt=='2.22')
Criteria:
1. Segment touches ≤1 boundary
2. Not connected across z-slices
3. Has a single dominant neighbor in z±1
4. Size-based IoU > 0.6 with neighbor

### 3. Oracle Merge Analysis (T_pytc_v2.py opt=='2.211')
- Map each predicted segment to its best GT match via max IoU
- Group predicted segments by GT label
- Segments mapped to the same GT = should be merged
- Typical result: 190 oracle merges, ARE 0.048 → 0.025

### 4. Morphological Refinement (T_pytc.py)
- `seg_postprocess()`: 2D constrained watershed (`mahotas.cwatershed`) per slice
- Optional Sobel edge guidance from the raw image
- ~0.008-0.015 error reduction

### 5. Multi-Stage Waterz (T_snemi220416.py, T_waterz.py)
Best parameters found:
- `merge_function: aff85_his256`
- `aff_threshold: [0.1, 0.9]`
- `threshold: 0.4-0.7`
- `dust_merge_size: 800 * rr²` (resolution-dependent)
- `dust_merge_affinity: 0.3-0.5`

### 6. Consistency Checking (T_consistency.py)
- Track segment IDs across z-slices
- Count max consecutive occurrences
- Segments with ≤2 consecutive slices = likely noise
- Abrupt size changes = potential errors

### 7. Skeleton Analysis (T_yulun_skel.py, T_skel.py)
- kimimaro TEASAR: `scale=4, const=500, anisotropy=(30,6,6)`
- Cable-length filtering: long axons (≥5000µm) vs short fragments (<1000µm)
- ERL (skeleton-based) metric as an alternative to pixel-based ARE
- Oracle skeletonization bridges false splits

## Practical Improvement Hierarchy

1. **Remove single-slice dust** — low risk, removes noise
2. **cc3d disconnect removal** — keep the largest component per segment
3. **Orphan merge** — segments with bbox fully inside another
4. **IoU-based cross-slice merge** — cautious, only at bbox endpoints
5. **Morphological refinement** — cwatershed per slice
6. **Skeleton-guided merge** — use cable length to validate merges

## Key Files
- `T_pytc_v2.py` — comprehensive pipeline (merge, split, oracle)
- `T_snemi220416.py` — waterz params, orphan detection
- `T_consistency.py` — cross-slice tracking
- `T_yulun_iou.py` — IoU computation, adapted_rand
- `T_yulun_skel.py` — skeleton analysis

connectomics/config/pipeline/config_io.py

Lines changed: 77 additions & 42 deletions

```diff
@@ -20,7 +20,7 @@
 from ...data.processing.build import count_stacked_label_transform_channels
 from ...models.architectures.registry import get_architecture_info
 from ...utils.channel_slices import infer_min_required_channels
-from ...utils.model_outputs import resolve_configured_output_head
+from ...utils.model_outputs import resolve_configured_output_head, resolve_output_heads
 from ..schema import Config
 from ..schema.root import MergeContext
 from .profile_engine import _YAML_PROFILE_ENGINE
```
```diff
@@ -329,11 +329,19 @@ def validate_config(cfg: Config) -> None:
             f"model.primary_head='{primary_head}' is not present in model.heads "
             f"({sorted(model_heads.keys())})."
         )
-    if inference_head is not None and inference_head not in model_heads:
-        raise ValueError(
-            f"inference.head='{inference_head}' is not present in model.heads "
-            f"({sorted(model_heads.keys())})."
+    if inference_head is not None:
+        # Accept comma-separated lists (merged-heads inference); each name must exist.
+        inference_head_names = (
+            [h.strip() for h in inference_head.split(",") if h.strip()]
+            if isinstance(inference_head, str) and "," in inference_head
+            else [inference_head]
         )
+        missing = [h for h in inference_head_names if h not in model_heads]
+        if missing:
+            raise ValueError(
+                f"inference.head={inference_head_names} references unknown heads {missing}; "
+                f"available: {sorted(model_heads.keys())}."
+            )
     if (
         visualization_head is not None
         and visualization_head != "all"
```
```diff
@@ -365,6 +373,33 @@ def validate_config(cfg: Config) -> None:
     if cfg.data.dataloader.batch_size <= 0:
         raise ValueError("data.dataloader.batch_size must be positive")
 
+    strategy = str(getattr(cfg.inference, "strategy", "whole_volume")).lower()
+    if strategy not in {"whole_volume", "chunked"}:
+        raise ValueError("inference.strategy must be 'whole_volume' or 'chunked'")
+    chunking_cfg = getattr(cfg.inference, "chunking", None)
+    chunking_enabled = bool(getattr(chunking_cfg, "enabled", False)) or strategy == "chunked"
+    if chunking_enabled:
+        if len(cfg.data.dataloader.patch_size) != 3:
+            raise ValueError("inference.chunking requires 3D data.dataloader.patch_size")
+        axes = str(getattr(chunking_cfg, "axes", "all")).lower()
+        if axes not in {"all", "z"}:
+            raise ValueError("inference.chunking.axes must be 'all' or 'z'")
+        chunk_size = getattr(chunking_cfg, "chunk_size", None)
+        if not chunk_size or len(chunk_size) != 3:
+            raise ValueError("inference.chunking.chunk_size must be a length-3 ZYX list")
+        if any(int(v) <= 0 for v in chunk_size):
+            raise ValueError("inference.chunking.chunk_size values must be positive")
+        halo = getattr(chunking_cfg, "halo", None)
+        if halo is None or len(halo) != 3:
+            raise ValueError("inference.chunking.halo must be a length-3 ZYX list")
+        if any(int(v) < 0 for v in halo):
+            raise ValueError("inference.chunking.halo values must be non-negative")
+        stitching = getattr(chunking_cfg, "stitching", None)
+        if stitching is not None:
+            min_contact = int(getattr(stitching, "min_contact", 1))
+            if min_contact <= 0:
+                raise ValueError("inference.chunking.stitching.min_contact must be positive")
+
     # Optimizer validation
     if cfg.optimization.optimizer.lr <= 0:
         raise ValueError("optimization.optimizer.lr must be positive")
```
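The checks above imply an `inference.chunking` config shaped roughly like the following sketch. The values are illustrative, and any field or default not exercised by these validations is an assumption:

```yaml
inference:
  strategy: chunked            # 'whole_volume' or 'chunked'
  chunking:
    enabled: true              # implied when strategy is 'chunked'
    axes: all                  # 'all' or 'z'
    chunk_size: [64, 512, 512] # length-3 ZYX list, all positive
    halo: [8, 64, 64]          # length-3 ZYX list, non-negative overlap
    stitching:
      min_contact: 1           # must be positive
```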
```diff
@@ -583,22 +618,26 @@ def _validate_label_channel_capacity(selector_value: Any, *, path: str) -> None:
             break
 
     if model_heads and decode_has_channel_selection:
-        decode_output_head = resolve_configured_output_head(
-            cfg,
-            purpose="decode channel selection",
-            allow_none=True,
-        )
-        if len(model_heads) > 1 and decode_output_head is None:
+        decode_heads = resolve_output_heads(cfg, purpose="decode channel selection")
+        if len(model_heads) > 1 and not decode_heads:
             raise ValueError(
                 "Cross-section validation failed: decode channel selectors require "
                 "inference.head or model.primary_head when model.heads has multiple "
                 f"entries ({sorted(model_heads.keys())})."
             )
-        if decode_output_head in model_heads:
-            decode_available_channels = int(
-                getattr(model_heads[decode_output_head], "out_channels", out_channels)
+        if len(decode_heads) > 1:
+            decode_available_channels = sum(
+                int(getattr(model_heads[h], "out_channels", 0)) for h in decode_heads
             )
-            decode_channel_scope = f"head '{decode_output_head}'"
+            decode_channel_scope = f"merged heads {decode_heads}"
+            decode_output_head = decode_heads[0]
+        elif decode_heads:
+            decode_output_head = decode_heads[0]
+            if decode_output_head in model_heads:
+                decode_available_channels = int(
+                    getattr(model_heads[decode_output_head], "out_channels", out_channels)
+                )
+                decode_channel_scope = f"head '{decode_output_head}'"
 
     for i, decode_step in enumerate(decoding_cfg):
         kwargs = getattr(decode_step, "kwargs", None)
```
```diff
@@ -624,39 +663,35 @@ def _validate_label_channel_capacity(selector_value: Any, *, path: str) -> None:
             continue
         required_output_channels.append((path, min_channels))
 
-    # 2e) TTA channel selectors
+    # 2e) Inference channel selectors
     tta_cfg = getattr(cfg.inference, "test_time_augmentation", None)
     channel_activations = getattr(tta_cfg, "channel_activations", None) if tta_cfg else None
-    select_channel = getattr(tta_cfg, "select_channel", None) if tta_cfg else None
-    tta_has_channel_selection = bool(channel_activations) or select_channel is not None
-    tta_output_head = (
-        resolve_configured_output_head(
-            cfg,
-            purpose="TTA channel selection",
-            allow_none=True,
-        )
-        if model_heads
-        else None
+    select_channel = getattr(cfg.inference, "select_channel", None)
+    inference_has_channel_selection = bool(channel_activations) or select_channel is not None
+    tta_heads = (
+        resolve_output_heads(cfg, purpose="inference channel selection") if model_heads else []
     )
-    if (
-        model_heads
-        and len(model_heads) > 1
-        and tta_has_channel_selection
-        and tta_output_head is None
-    ):
+    tta_output_head = tta_heads[0] if tta_heads else None
+    if model_heads and len(model_heads) > 1 and inference_has_channel_selection and not tta_heads:
         raise ValueError(
-            "Cross-section validation failed: TTA channel selectors require inference.head "
+            "Cross-section validation failed: inference channel selectors require inference.head "
             "or model.primary_head when model.heads has multiple entries "
             f"({sorted(model_heads.keys())})."
         )
-    tta_available_channels = (
-        int(getattr(model_heads[tta_output_head], "out_channels", out_channels))
-        if tta_output_head in model_heads
-        else out_channels
-    )
-    tta_channel_scope = (
-        f"head '{tta_output_head}'" if tta_output_head in model_heads else "model output"
-    )
+    if len(tta_heads) > 1:
+        tta_available_channels = sum(
+            int(getattr(model_heads[h], "out_channels", 0)) for h in tta_heads
+        )
+        tta_channel_scope = f"merged heads {tta_heads}"
+    else:
+        tta_available_channels = (
+            int(getattr(model_heads[tta_output_head], "out_channels", out_channels))
+            if tta_output_head in model_heads
+            else out_channels
+        )
+        tta_channel_scope = (
+            f"head '{tta_output_head}'" if tta_output_head in model_heads else "model output"
+        )
 
     def _validate_tta_channel_capacity(selector_value: Any, *, path: str) -> None:
         min_selector_channels = infer_min_required_channels(
```
```diff
@@ -692,7 +727,7 @@ def _validate_tta_channel_capacity(selector_value: Any, *, path: str) -> None:
         )
         _validate_tta_channel_capacity(
             select_channel,
-            path="inference.test_time_augmentation.select_channel",
+            path="inference.select_channel",
         )
 
     if required_output_channels:
```
