Commit b7ba58c — Donglai Wei committed: add liconn mit yaml
1 parent 674de83

1 file changed: tutorials/neuron_liconn_mit.yaml (394 additions & 0 deletions)
experiment_name: liconn_affinity_instance_seg
description: >
  LiConn neuron instance segmentation using a long-range affinity model (distances 1, 3, 9)
  with ABISS-based decoding for large-volume EM data.

  Data: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled.zarr/
    img: shape (795, 4870, 3825), dtype uint8, chunks (64, 64, 64), ZYX order
    label: shape (795, 4870, 3825), dtype uint64, chunks (64, 64, 64), ZYX order
    resolution: 25 nm (z) × 18 nm (y) × 18 nm (x)

  Split strategy (spatial Z-axis, 80/20):
    train: Z slices 0–635 (636 slices, ~15.9 µm z-depth)
    val: Z slices 636–794 (159 slices, ~4.0 µm z-depth)
    test: full volume (run inference on the whole zarr after training)

  Label erosion (erosion: 1):
    Applies Kisuk Lee's instance erosion — any voxel in a 3×3 XY window
    touching >1 segment ID is set to background before affinity computation.
    Essential for tightly packed axons/neurons to create clean boundary gaps.

  Imperfect-label robustness strategy:
    1. erosion: 1 — removes boundary ambiguity from imperfect proofreading
    2. FocalLoss (gamma=2) — down-weights easy background voxels; handles class
       imbalance without needing pos_weight tuning
    3. TverskyLoss (beta=0.7) — penalises false negatives (missed boundaries) more
       than false positives; reduces false merges in packed axons
    4. aug_em_neuron — heavy elastic + misalignment + missing sections to
       prevent overfitting to label artifacts
    5. EMA (decay=0.999) — smooths weight updates, reduces sensitivity to noisy labels
    6. foreground_threshold: 0.10 — skips near-empty patches (background-heavy
       patches amplify label noise)
    7. accumulate_grad_batches: 4 — larger effective batch reduces gradient noise
       from individual mislabeled patches
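The EMA update in item 5 is just a running average of the weights; a minimal scalar sketch (the values besides `decay` are illustrative, and a smaller decay is used so the toy converges quickly):

```python
# Exponential moving average of model weights, as used for noisy-label
# robustness: ema_w = decay * ema_w + (1 - decay) * w after every step.
def ema_update(ema_w, w, decay=0.999):
    return decay * ema_w + (1.0 - decay) * w

# A noisy weight trajectory: the EMA tracks the underlying mean, not the jitter.
weights = [1.0, 3.0, 1.0, 3.0, 1.0, 3.0]
ema = weights[0]
for w in weights[1:]:
    ema = ema_update(ema, w, decay=0.9)  # toy decay for a short trajectory
print(round(ema, 3))  # 1.493 — far less jumpy than the raw weights
```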

  Loss design (FocalLoss + TverskyLoss, per channel group):
    Short-range (ch 0-2, offset 1): Focal weight=1.0 + Tversky weight=1.0
      — primary connectivity signal; full weight on both losses
    Mid-range (ch 3-5, offset 3): Focal weight=0.8 + Tversky weight=0.8
      — slightly downweighted; bridges gaps but noisier than short-range
    Long-range (ch 6-8, offset 9): Focal weight=0.5 + Tversky weight=0.5
      — sparser and noisier; half weight to avoid dominating gradients

    FocalLoss kwargs: gamma=2.0 (standard), alpha=0.25 (foreground weight)
    TverskyLoss kwargs: alpha=0.3 (FP penalty), beta=0.7 (FN penalty), sigmoid=true
      — beta > alpha means we penalise missed boundaries (false merges) more than
      spurious boundaries (false splits); appropriate for tightly packed axons
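The beta > alpha asymmetry can be seen directly from the Tversky index, TI = TP / (TP + α·FP + β·FN), with loss = 1 − TI; a sketch on hypothetical hard counts:

```python
# Tversky loss on hard counts: with alpha=0.3 (FP) and beta=0.7 (FN),
# a missed boundary voxel (FN) costs more than a spurious one (FP).
def tversky_loss(tp, fp, fn, alpha=0.3, beta=0.7):
    return 1.0 - tp / (tp + alpha * fp + beta * fn)

miss = tversky_loss(tp=80, fp=0, fn=20)   # 20 missed boundary voxels
spur = tversky_loss(tp=80, fp=20, fn=0)   # 20 spurious boundary voxels
print(miss > spur)  # True — false negatives are penalised harder
```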

  Affinity channels (9 total):
    - Ch 0-2: short-range (offset 1) for z, y, x
    - Ch 3-5: mid-range (offset 3) for z, y, x
    - Ch 6-8: long-range (offset 9) for z, y, x

  Decoding strategy:
    - Primary: ABISS chunked watershed+agglomeration pipeline (run_abiss_single.py)
    - Fallback: decode_affinity_cc (connected components on short-range affinities)
    - For the full 795×4870×3825 volume, run ABISS externally on saved predictions.

  See .claude/repos_other/ABISS_USAGE_SUMMARY.md for ABISS build/run instructions.
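An affinity channel is 1 where two voxels separated by the channel's offset belong to the same non-background segment; a 1-D sketch with a hypothetical helper (the real target builder works on 3-D ZYX offsets like "1-0-0" and stacks one channel per offset):

```python
# 1-D affinity target for one offset d: affinity is 1 where voxel i and
# voxel i+d share the same nonzero segment ID, else 0 (boundary/background).
def affinity_1d(seg, d):
    return [1 if seg[i] != 0 and seg[i] == seg[i + d] else 0
            for i in range(len(seg) - d)]

seg = [1, 1, 1, 0, 2, 2]          # two segments separated by background
print(affinity_1d(seg, 1))         # [1, 1, 0, 0, 1] — zeros mark the gap
print(affinity_1d(seg, 3))         # [0, 0, 0] — no same-segment pairs 3 apart
```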

# ── Inherit all shared profile libraries ──────────────────────────────────────
_base_:
  - bases/all_profiles.yaml

# ── Profile selectors ─────────────────────────────────────────────────────────
default:
  # Use the 'binary' pipeline profile (not 'affinity-12') because:
  # - affinity-12 injects model.loss.overrides, which is not in the LossConfig schema
  # - We define our own losses list inline (FocalLoss + TverskyLoss per channel group)
  # - We override out_channels, label_transform, and augmentation ourselves below
  pipeline_profile: binary
  system:
    profile: all-gpu-cpu
  model:
    arch:
      profile: mednext_m  # MedNeXt-M (17.6M params); better capacity for 9-ch affinity
      # Override out_channels: 9 affinities (3 distances × 3 axes)
      out_channels: 9

    loss:
      profile: null
      # ── Custom loss: FocalLoss + TverskyLoss per channel group ────────────────
      # We bypass the loss_binary profile and define losses inline.
      # Each entry maps 1:1 to a loss function (by index in the losses list).
      # pred_slice / target_slice select which affinity channels each term applies to.
      #
      # FocalLoss (MONAI):
      #   gamma=2.0 — standard focusing parameter; down-weights easy negatives
      #   alpha=0.25 — foreground class weight (boundary voxels are the minority)
      #   applies sigmoid by default (use_softmax=False) to raw logits — there is
      #   no 'sigmoid' kwarg
      #
      # TverskyLoss (MONAI):
      #   alpha=0.3 — false positive penalty (spurious boundaries)
      #   beta=0.7 — false negative penalty (missed boundaries / false merges)
      #   sigmoid=true — applies sigmoid before computing loss
      #   smooth_nr/smooth_dr=1e-5 — numerical stability
      #
      # NOTE: FocalLoss does NOT support pos_weight (no spatial_weight_arg).
      # Class imbalance is handled by alpha + gamma instead.
      losses:
        # ── Short-range channels (ch 0-2, offset 1) ──────────────────────────
        # FocalLoss: no 'sigmoid' param — applies sigmoid by default (use_softmax=False)
        # TverskyLoss: 'sigmoid=true' is valid
        - function: FocalLoss
          weight: 1.0
          pred_slice: [0, 3]
          target_slice: [0, 3]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 1.0
          pred_slice: [0, 3]
          target_slice: [0, 3]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
        # ── Mid-range channels (ch 3-5, offset 3) ────────────────────────────
        - function: FocalLoss
          weight: 0.8
          pred_slice: [3, 6]
          target_slice: [3, 6]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 0.8
          pred_slice: [3, 6]
          target_slice: [3, 6]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
        # ── Long-range channels (ch 6-8, offset 9) ───────────────────────────
        - function: FocalLoss
          weight: 0.5
          pred_slice: [6, 9]
          target_slice: [6, 9]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 0.5
          pred_slice: [6, 9]
          target_slice: [6, 9]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
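The effect of gamma=2 can be checked numerically: the focal term is the plain (alpha-weighted) cross-entropy scaled by (1 − p)^γ, so confident predictions contribute almost nothing. A sketch with hypothetical probabilities:

```python
import math

# Focal term for a foreground voxel with predicted probability p:
# FL = -alpha * (1 - p)**gamma * log(p). The ratio FL / CE is (1 - p)**gamma,
# so gamma=2 shrinks easy examples far more than hard ones.
def focal(p, gamma=2.0, alpha=0.25):
    return -alpha * (1.0 - p) ** gamma * math.log(p)

def ce(p, alpha=0.25):
    return -alpha * math.log(p)

easy, hard = 0.95, 0.30
print(round(focal(easy) / ce(easy), 4))  # 0.0025 — easy voxel almost ignored
print(round(focal(hard) / ce(hard), 2))  # 0.49   — hard voxel keeps ~half its weight
```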
    mednext:
      size: M  # MedNeXt-M (17.6M params); was B (10.5M) — more capacity for 9-ch affinity
      kernel_size: 3
      # Preset `mednext` already uses GroupNorm internally in MedNeXt blocks.
      # `model.mednext.norm` is only configurable when using `arch.type: mednext_custom`.
      dim: 3d
      checkpoint_style: outside_block  # gradient checkpointing — required for large patches (≥192³)

  data:
    # ── Image normalization ───────────────────────────────────────────────────
    # Normalize uint8 [0,255] → float [0,1] before augmentation.
    # clip_percentile_low/high=0.0/1.0 means no clipping (use the full range).
    # Learned from neuron_nisb_9nm_liconn.yaml — missing this causes training instability.
    image_transform:
      normalize: "0-1"
      clip_percentile_low: 0.0
      clip_percentile_high: 1.0
    data_transform:
      # Keep symmetric full-volume context padding on the inference input.
      pad_size: [32, 128, 128]
    dataloader:
      profile: cached
      use_cache: false     # disable in-memory cache — volume is 795×4870×3825
      use_lazy_zarr: true  # stream patches directly from Zarr (no RAM preload)
      use_preloaded_cache_train: false
      use_preloaded_cache_val: false
      # Skip patches with <10% foreground — avoids amplifying label noise in
      # background-heavy crops (common with tightly packed axons at edges).
      # Raised from 0.05 → 0.10 after cropping out unlabeled black regions,
      # since the remaining volume should have denser label coverage.
      cached_sampling_foreground_threshold: 0.10
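The foreground threshold amounts to a simple rejection test on each candidate crop; an illustrative stand-in for the cached sampler's filter (helper name is hypothetical):

```python
# Reject a candidate patch whose labelled-foreground fraction falls below
# the threshold (0.10 here), so background-heavy crops never enter a batch.
def keep_patch(label_patch, threshold=0.10):
    fg = sum(1 for v in label_patch if v != 0)
    return fg / len(label_patch) >= threshold

print(keep_patch([0] * 95 + [7] * 5))   # False — 5% foreground, skipped
print(keep_patch([0] * 80 + [7] * 20))  # True  — 20% foreground, kept
```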
    augmentation:
      # aug_em_neuron: heavy elastic deformation, wide contrast range, aggressive
      # EM artifacts (misalignment, missing sections, motion blur, missing parts).
      # This is the most important robustness tool for imperfect labels — the model
      # learns to be invariant to the kinds of artifacts that cause proofreading errors.
      profile: aug_em_neuron

  inference:
    system:
      num_gpus: 1
      num_workers: 4
      batch_size: 4
    sliding_window:
      # Large overlap for smooth predictions on big volumes
      overlap: 0.5
      blending: gaussian
      keep_input_on_cpu: true  # keep the raw volume on CPU to save GPU memory
      sw_device: cuda
      output_device: cpu
    test_time_augmentation:
      enabled: false
      # Only use short-range affinities for the TTA ensemble (channels 0-2)
      select_channel: [0, 1, 2]
      ensemble_mode: min  # min-pooling is conservative for affinities
      # activation_profile is not in TestTimeAugmentationConfig — removed
    postprocessing:
      enabled: true
      crop_pad: [32, 32, 128, 128, 128, 128]
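The crop_pad step simply undoes the symmetric context padding (pad_size [32, 128, 128]); a sketch assuming the list is ordered [z_lo, z_hi, y_lo, y_hi, x_lo, x_hi] (the helper name is illustrative):

```python
# Build per-axis slices that strip the padding from the padded prediction.
def crop_slices(crop_pad):
    pairs = zip(crop_pad[0::2], crop_pad[1::2])  # (lo, hi) per axis in ZYX order
    return tuple(slice(lo, -hi if hi else None) for lo, hi in pairs)

sl = crop_slices([32, 32, 128, 128, 128, 128])
padded_z = list(range(64 + 32 + 32))  # a padded z-axis: 64 valid + 2×32 pad
print(len(padded_z[sl[0]]))           # 64 — padding removed along z
```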

# ── Training stage ─────────────────────────────────────────────────────────────
train:
  data:
    label_transform:
      # ── Label erosion ────────────────────────────────────────────────────────
      # erosion: N applies SegErosionInstanced (Kisuk Lee's SNEMI3D preprocessing):
      # any voxel in a (2N+1)×(2N+1) XY window that touches >1 segment ID is
      # set to background BEFORE affinity computation.
      #
      # For tightly packed axons/neurons at 18 nm, erosion=1 (3×3 window) is
      # the standard choice — it creates a 1-voxel gap at every instance boundary,
      # making the affinity signal cleaner and reducing false merges.
      # Increase to erosion=2 if boundaries are still ambiguous after training.
      erosion: 1
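The 3×3 XY erosion rule above can be sketched in a few lines; a minimal 2-D toy (not the actual SegErosionInstanced implementation, which operates slice-wise on 3-D volumes):

```python
# Any voxel whose 3×3 neighbourhood touches more than one nonzero segment ID
# becomes background (0), carving a gap along every instance boundary.
def erode_instances(seg):  # seg: 2-D list of segment IDs
    h, w = len(seg), len(seg[0])
    out = [row[:] for row in seg]
    for y in range(h):
        for x in range(w):
            ids = {seg[ny][nx]
                   for ny in range(max(0, y - 1), min(h, y + 2))
                   for nx in range(max(0, x - 1), min(w, x + 2))
                   if seg[ny][nx] != 0}
            if len(ids) > 1:
                out[y][x] = 0
    return out

seg = [[1, 1, 2, 2],
       [1, 1, 2, 2]]
print(erode_instances(seg))  # [[1, 0, 0, 2], [1, 0, 0, 2]] — gap at the 1|2 boundary
```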

      # 9-channel affinity: distances 1, 3, 9 along each axis (z, y, x)
      targets:
        - name: affinity
          kwargs:
            offsets:
              # Distance 1 (short-range) — primary connectivity signal
              - "1-0-0"  # ch 0: z
              - "0-1-0"  # ch 1: y
              - "0-0-1"  # ch 2: x
              # Distance 3 (mid-range) — bridges small gaps
              - "3-0-0"  # ch 3: z
              - "0-3-0"  # ch 4: y
              - "0-0-3"  # ch 5: x
              # Distance 9 (long-range) — long-range context for agglomeration
              - "9-0-0"  # ch 6: z
              - "0-9-0"  # ch 7: y
              - "0-0-9"  # ch 8: x

    # ── Data paths ───────────────────────────────────────────────────────────
    # NOTE: split_enabled only works with HDF5/TIFF, not Zarr (raises ValueError).
    # LazyZarrVolumeDataset samples random patches from the full volume extent.
    # Both train and val use the same Zarr — val patches are drawn with a different
    # random seed (set_epoch is called per epoch). For a strict spatial split,
    # create separate Zarr sub-stores (e.g. zarr.open(...)[0:636]) offline.
    train:
      image: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/img
      label: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/label
      resolution: [25, 18, 18]  # [z_nm, y_nm, x_nm] — anisotropic: 25 nm z, 18 nm xy

    val:
      image: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/img
      label: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/label
      resolution: [25, 18, 18]

    # Data is already ZYX — no transpose needed
    # data_transform:
    #   train_transpose: [2, 1, 0]  # only needed if the Zarr were stored XYZ

    dataloader:
      # Patch must be ≥ 9 in all axes (the largest affinity offset).
      # 64×256×256 = 4,194,304 voxels — good axon context (10–50 axons per crop at 18 nm).
      # Reducing crop size hurts long-range affinity training more than reducing batch size.
      patch_size: [64, 256, 256]
      # batch_size=1 halves GPU memory vs batch_size=2; compensate with accumulate_grad_batches=4
      # to keep the effective batch per GPU at 4 (same as batch_size=2 × accumulate=2).
      batch_size: 1

  model:
    input_size: [64, 256, 256]
    output_size: [64, 256, 256]

  optimization:
    profile: warmup_cosine_lr
    max_epochs: 500
    n_steps_per_epoch: 200
    # Gradient accumulation: effective batch = batch_size × accumulate_grad_batches × num_gpus.
    # batch_size=1, accumulate=4, num_gpus=2 → effective batch = 4 per GPU (8 globally).
    # A larger effective batch reduces gradient noise from individual mislabeled patches.
    accumulate_grad_batches: 4
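Gradient accumulation just sums scaled micro-batch gradients before a single optimizer step; a dependency-free sketch on a toy quadratic model (all values illustrative):

```python
# Minimise mean of (w - x)^2 over micro-batch targets x; d/dw = 2*(w - x).
# Accumulate over 4 micro-batches of size 1, then step once — the same update
# as one batch of all four targets.
w, lr = 0.0, 0.1
micro_batches = [1.0, 2.0, 3.0, 2.0]            # accumulate_grad_batches = 4
grad = 0.0
for x in micro_batches:
    grad += 2.0 * (w - x) / len(micro_batches)   # scale so the sum is a mean
w -= lr * grad                                   # one optimizer step
print(w)  # 0.4 — identical to stepping on the full batch at once
```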
    # EMA smooths weight updates over time — reduces sensitivity to noisy/imperfect
    # labels by averaging model weights across recent steps (decay=0.999 from profile).
    # Already enabled in the warmup_cosine_lr profile (ema.enabled: true, decay: 0.999).
    # Validation uses EMA weights (validate_with_ema: true) for best generalization.

  system:
    num_gpus: 2
    num_workers: 16  # num_cpus is not in SystemConfig; use num_workers for dataloader workers
    seed: 42

  monitor:
    logging:
      scalar:
        loss: [train_loss_total_epoch, val_loss_total]
        loss_every_n_steps: 100
        val_check_interval: 10
      images:
        # dtype cast fixed: _log_multi_channel_viz and _log_single_channel_viz now
        # explicitly cast all tensors to float32 before make_grid (uint8 label channels
        # and bfloat16 model outputs both caused a RuntimeError previously).
        enabled: true
        max_images: 8
        num_slices: 4
        log_every_n_epochs: 5
        channel_mode: all
    checkpoint:
      dirpath: /orcd/scratch/bcs/002/mansour/train_liconn/outputs/checkpoints/
      monitor: val_loss_total
      mode: min
      save_top_k: 3

# ── Test / inference stage ─────────────────────────────────────────────────────
test:
  data:
    # The comments below describe the (now commented-out) BA_5AA_proteintest_1
    # FFN-sharpened volume: uint8, shape (59, 1024, 1024) ZYX — its Z=59 is smaller
    # than the training patch (64), hence window_size: [32, 256, 256] below.
    # The active image has been switched to a DL288B crop.
    # No label available — inference only (evaluation.enabled: false).
    test:
      # image: /orcd/data/edboyden/002/mansour8/deb_data/BA_5AA_proteintest_1_2026-02-18_17.25.37_channel1_ffn_sharp.zarr
      image: /orcd/data/edboyden/002/dleible/DL288B/DL288B_251222S_cond5_40x_12tiles_round1_fused_488_crop512x1024x1024.tif
      # label: omitted — no ground truth available for this volume
      resolution: [25, 18, 18]  # same voxel size as the training data

  inference:
    sliding_window:
      # window_size Z=32 (< volume Z=59) — the model was trained on Z=64 but handles
      # smaller windows; use Z=32 to avoid padding artifacts at volume boundaries.
      # Y/X=256 matches the training receptive field for best quality.
      window_size: [32, 256, 256]
      overlap: 0.5
      blending: gaussian
      keep_input_on_cpu: true
      sw_device: cuda
      output_device: cpu
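With overlap 0.5 the window stride is half the window size, so every voxel is covered by at least two Gaussian-weighted windows; a 1-D sketch of the start positions (illustrative — the actual sliding-window inferer computes its intervals internally):

```python
# Start positions of sliding windows along one axis: stride = window * (1 - overlap),
# with a final window snapped flush to the edge so nothing is left uncovered.
def window_starts(length, window, overlap=0.5):
    stride = max(1, int(window * (1.0 - overlap)))
    starts = list(range(0, max(length - window, 0) + 1, stride))
    if starts[-1] != length - window:
        starts.append(length - window)   # flush final window at the boundary
    return starts

print(window_starts(59, 32))  # [0, 16, 27] — window Z=32 tiling a Z=59 volume
```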

    save_prediction:
      enabled: true
      output_formats: [h5]
      output_path: /orcd/scratch/bcs/002/mansour/train_liconn/outputs/results_axons/
      # float32 for full-precision affinities; use float16 to halve storage
      intensity_scale: -1.0
      intensity_dtype: float32
      compression: gzip

    # ── Decoding ──────────────────────────────────────────────────────────────
    # Option A: ABISS external pipeline — best quality, requires ABISS to be built.
    # Falls back to connected components automatically if ABISS fails.
    #
    # ABISS binary: abiss/build/ws (built from abiss/ source via abiss/build_ws.sh)
    # Intel TBB runtime: /orcd/software/community/001/rocky8/intel/2024.2.1/tbb/2021.13/lib
    # must be on LD_LIBRARY_PATH at runtime (the ws binary links against libtbb.so.12).
    #
    # Thresholds for axon segmentation at 18 nm:
    #   ws_high_threshold=0.9 — only seed watershed at very high affinity (confidently connected voxels)
    #   ws_low_threshold=0.3 — extend seeds down to moderate affinity
    #   ws_size_threshold=100 — discard fragments smaller than 100 voxels
    decoding:
      - name: decode_abiss
        kwargs:
          command: >
            env LD_LIBRARY_PATH=/orcd/software/community/001/rocky8/intel/2024.2.1/tbb/2021.13/lib:$LD_LIBRARY_PATH
            {python_exe} scripts/run_abiss_single.py
            --input {input_h5}
            --output {output_h5}
            --abiss-home /home/mansour8/pytc_dev/pytorch_connectomics/abiss
            --ws-high-threshold 0.9
            --ws-low-threshold 0.3
            --ws-size-threshold 100
          input_dataset: main
          output_dataset: main
          timeout_sec: 3600

    # Option B: Lightweight connected components (no ABISS required).
    # Only uses short-range affinities (ch 0-2); good for quick inspection.
    # Uncomment and comment out Option A above to use.
    # decoding:
    #   - name: decode_affinity_cc
    #     kwargs:
    #       threshold: 0.5
    #       backend: auto

  evaluation:
    enabled: false  # no ground truth — cannot compute metrics
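The Option B fallback amounts to thresholding the short-range affinities and running connected components on the voxels they link; a 1-D union-find sketch (the real decoder does this in 3-D on channels 0-2, and this toy is not its actual implementation):

```python
# affinity[i] links voxel i to voxel i+1; merge the pair when it exceeds the
# threshold, then read out one component label per voxel.
def decode_cc_1d(affinity, threshold=0.5):
    parent = list(range(len(affinity) + 1))
    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i
    for i, a in enumerate(affinity):
        if a > threshold:
            parent[find(i + 1)] = find(i)
    return [find(i) for i in range(len(parent))]

print(decode_cc_1d([0.9, 0.8, 0.1, 0.7]))  # [0, 0, 0, 3, 3] — two segments
```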
