experiment_name: liconn_affinity_instance_seg
description: >
  LiConn neuron instance segmentation using a long-range affinity model (distances 1, 3, 9)
  with ABISS-based decoding for large-volume EM data.

  Data: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled.zarr/
    img: shape (795, 4870, 3825), dtype uint8, chunks (64,64,64), ZYX order
    label: shape (795, 4870, 3825), dtype uint64, chunks (64,64,64), ZYX order
    resolution: 25 nm (z) × 18 nm (y) × 18 nm (x)

  Split strategy (spatial Z-axis, 80/20):
    train: Z slices 0–635 (636 slices, ~15.9 µm z-depth)
    val: Z slices 636–794 (159 slices, ~4.0 µm z-depth)
    test: full volume (run inference on the whole zarr after training)

  Label erosion (erosion: 1):
    Applies Kisuk Lee's instance erosion — any voxel in a 3×3 XY window
    touching >1 segment ID is set to background before affinity computation.
    Essential for tightly-packed axons/neurons to create clean boundary gaps.

  Imperfect-label robustness strategy:
    1. erosion: 1 — removes boundary ambiguity from imperfect proofreading
    2. FocalLoss (gamma=2) — down-weights easy background voxels; handles class
       imbalance without needing pos_weight tuning
    3. TverskyLoss (beta=0.7) — penalises false negatives (missed boundaries) more
       than false positives; reduces false merges in packed axons
    4. aug_em_neuron — heavy elastic + misalignment + missing sections to
       prevent overfitting to label artifacts
    5. EMA (decay=0.999) — smooths weight updates, reduces sensitivity to noisy labels
    6. foreground_threshold: 0.10 — skips near-empty patches (background-heavy
       patches amplify label noise)
    7. accumulate_grad_batches: 4 — larger effective batch reduces gradient noise
       from individual mislabeled patches

  Loss design (FocalLoss + TverskyLoss, per channel group):
    Short-range (ch 0-2, offset 1): Focal weight=1.0 + Tversky weight=1.0
      — primary connectivity signal; full weight on both losses
    Mid-range (ch 3-5, offset 3): Focal weight=0.8 + Tversky weight=0.8
      — slightly downweighted; bridges gaps but noisier than short-range
    Long-range (ch 6-8, offset 9): Focal weight=0.5 + Tversky weight=0.5
      — sparser and noisier; half weight to avoid dominating gradients

  FocalLoss kwargs: gamma=2.0 (standard), alpha=0.25 (foreground weight)
  TverskyLoss kwargs: alpha=0.3 (FP penalty), beta=0.7 (FN penalty), sigmoid=true
    — beta > alpha means we penalise missed boundaries (false merges) more than
      spurious boundaries (false splits); appropriate for tightly-packed axons

  Affinity channels (9 total):
    - Ch 0-2: short-range (offset 1) for z, y, x
    - Ch 3-5: mid-range (offset 3) for z, y, x
    - Ch 6-8: long-range (offset 9) for z, y, x

  Decoding strategy:
    - Primary: ABISS chunked watershed+agglomeration pipeline (run_abiss_single.py)
    - Fallback: decode_affinity_cc (connected components on short-range affinities)
    - For the full 795×4870×3825 volume, run ABISS externally on saved predictions.

  See .claude/repos_other/ABISS_USAGE_SUMMARY.md for ABISS build/run instructions.

# ── Inherit all shared profile libraries ──────────────────────────────────────
_base_:
  - bases/all_profiles.yaml

# ── Profile selectors ─────────────────────────────────────────────────────────
default:
  # Use the 'binary' pipeline profile (not 'affinity-12') because:
  # - affinity-12 injects model.loss.overrides, which is not in the LossConfig schema
  # - We define our own losses list inline (FocalLoss + TverskyLoss per channel group)
  # - We override out_channels, label_transform, and augmentation ourselves below
  pipeline_profile: binary
  system:
    profile: all-gpu-cpu
  model:
    arch:
      profile: mednext_m  # MedNeXt-M (17.6M params); better capacity for 9-ch affinity
      # Override out_channels: 9 affinities (3 distances × 3 axes)
      out_channels: 9

    loss:
      profile: null
      # ── Custom loss: FocalLoss + TverskyLoss per channel group ──────────────
      # We bypass the loss_binary profile and define losses inline.
      # Each entry maps 1:1 to a loss function (by index in the losses list).
      # pred_slice / target_slice select which affinity channels each term applies to.
      #
      # FocalLoss (MONAI):
      #   gamma=2.0 — standard focusing parameter; down-weights easy negatives
      #   alpha=0.25 — foreground class weight (boundary voxels are the minority)
      #   expects raw logits; sigmoid is applied internally (use_softmax=False),
      #   so there is no 'sigmoid' kwarg
      #
      # TverskyLoss (MONAI):
      #   alpha=0.3 — false positive penalty (spurious boundaries)
      #   beta=0.7 — false negative penalty (missed boundaries / false merges)
      #   sigmoid=true — applies sigmoid before computing loss
      #   smooth_nr/smooth_dr=1e-5 — numerical stability
      #
      # NOTE: FocalLoss does NOT support pos_weight (no spatial_weight_arg).
      # Class imbalance is handled by alpha + gamma instead.
      losses:
        # ── Short-range channels (ch 0-2, offset 1) ─────────────────────────
        # FocalLoss: no 'sigmoid' param — applies sigmoid by default (use_softmax=False)
        # TverskyLoss: 'sigmoid: true' is valid
        - function: FocalLoss
          weight: 1.0
          pred_slice: [0, 3]
          target_slice: [0, 3]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 1.0
          pred_slice: [0, 3]
          target_slice: [0, 3]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
        # ── Mid-range channels (ch 3-5, offset 3) ───────────────────────────
        - function: FocalLoss
          weight: 0.8
          pred_slice: [3, 6]
          target_slice: [3, 6]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 0.8
          pred_slice: [3, 6]
          target_slice: [3, 6]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
        # ── Long-range channels (ch 6-8, offset 9) ──────────────────────────
        - function: FocalLoss
          weight: 0.5
          pred_slice: [6, 9]
          target_slice: [6, 9]
          kwargs:
            gamma: 2.0
            alpha: 0.25
        - function: TverskyLoss
          weight: 0.5
          pred_slice: [6, 9]
          target_slice: [6, 9]
          kwargs:
            alpha: 0.3
            beta: 0.7
            sigmoid: true
            smooth_nr: 1.0e-5
            smooth_dr: 1.0e-5
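      # Illustrative sketch of how this losses list is presumably applied
      # (assumed semantics, not the trainer's actual code): each term slices
      # its channel group from pred/target and adds weight * loss to the total.
      #   import torch
      #   from monai.losses import FocalLoss, TverskyLoss
      #   focal   = FocalLoss(gamma=2.0, alpha=0.25)        # sigmoid applied internally
      #   tversky = TverskyLoss(alpha=0.3, beta=0.7, sigmoid=True,
      #                         smooth_nr=1e-5, smooth_dr=1e-5)
      #   terms = [(focal, 1.0, slice(0, 3)), (tversky, 1.0, slice(0, 3)),
      #            (focal, 0.8, slice(3, 6)), (tversky, 0.8, slice(3, 6)),
      #            (focal, 0.5, slice(6, 9)), (tversky, 0.5, slice(6, 9))]
      #   pred   = torch.randn(1, 9, 64, 128, 128)          # raw logits (B,C,Z,Y,X)
      #   target = torch.randint(0, 2, (1, 9, 64, 128, 128)).float()
      #   total  = sum(w * fn(pred[:, s], target[:, s]) for fn, w, s in terms)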
    mednext:
      size: M  # MedNeXt-M (17.6M params); was B (10.5M) — more capacity for 9-ch affinity
      kernel_size: 3
      # The `mednext` preset already uses GroupNorm internally in MedNeXt blocks.
      # `model.mednext.norm` is only configurable when using `arch.type: mednext_custom`.
      dim: 3d
      checkpoint_style: outside_block  # gradient checkpointing — required for large patches (≥192³)

  data:
    # ── Image normalization ──────────────────────────────────────────────────
    # Normalize uint8 [0,255] → float [0,1] before augmentation.
    # clip_percentile_low/high=0.0/1.0 means no clipping (use the full range).
    # Learned from neuron_nisb_9nm_liconn.yaml — omitting this causes training instability.
    image_transform:
      normalize: "0-1"
      clip_percentile_low: 0.0
      clip_percentile_high: 1.0
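    # Presumed semantics of "0-1" with these clip settings, as a numpy sketch
    # (an assumption — check the repo's transform for the exact formula):
    #   lo, hi = np.percentile(img, [0.0 * 100, 1.0 * 100])  # 0th/100th pct -> no clip
    #   img = np.clip(img, lo, hi).astype(np.float32)
    #   img = (img - lo) / max(hi - lo, 1e-8)                # uint8 [0,255] -> [0,1]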
    data_transform:
      # Keep symmetric full-volume context padding on the inference input.
      pad_size: [32, 128, 128]
    dataloader:
      profile: cached
      use_cache: false     # disable in-memory cache — volume is 795×4870×3825
      use_lazy_zarr: true  # stream patches directly from Zarr (no RAM preload)
      use_preloaded_cache_train: false
      use_preloaded_cache_val: false
      # Skip patches with <10% foreground — avoids amplifying label noise in
      # background-heavy crops (common with tightly-packed axons at edges).
      # Raised from 0.05 → 0.10 after cropping out unlabeled black regions,
      # since the remaining volume should have denser label coverage.
      cached_sampling_foreground_threshold: 0.10
    augmentation:
      # aug_em_neuron: heavy elastic deformation, wide contrast range, aggressive
      # EM artifacts (misalignment, missing sections, motion blur, missing parts).
      # This is the most important robustness tool for imperfect labels — the model
      # learns to be invariant to the kinds of artifacts that cause proofreading errors.
      profile: aug_em_neuron

  inference:
    system:
      num_gpus: 1
      num_workers: 4
      batch_size: 4
    sliding_window:
      # Large overlap for smooth predictions on big volumes
      overlap: 0.5
      blending: gaussian
      keep_input_on_cpu: true  # keep the raw volume on CPU to save GPU memory
      sw_device: cuda
      output_device: cpu
    test_time_augmentation:
      enabled: false
      # Only use short-range affinities for the TTA ensemble (channels 0-2)
      select_channel: [0, 1, 2]
      ensemble_mode: min  # min-pooling is conservative for affinities
      # activation_profile is not in TestTimeAugmentationConfig — removed
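    # If enabled, a min-ensemble behaves roughly like this hypothetical
    # flip-TTA sketch (illustrative only, not the library's implementation):
    #   preds = [unflip(model(flip(x))) for flip, unflip in flip_pairs]
    #   aff = torch.stack(preds).min(dim=0).values  # keep the weakest affinity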
    postprocessing:
      enabled: true
      crop_pad: [32, 32, 128, 128, 128, 128]
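      # Presumed order [z_lo, z_hi, y_lo, y_hi, x_lo, x_hi], i.e. this undoes
      # pad_size [32, 128, 128]: pred = pred[..., 32:-32, 128:-128, 128:-128]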

# ── Training stage ─────────────────────────────────────────────────────────────
train:
  data:
    label_transform:
      # ── Label erosion ──────────────────────────────────────────────────────
      # erosion: N applies SegErosionInstanced (Kisuk Lee's SNEMI3D preprocessing):
      # any voxel in a (2N+1)×(2N+1) XY window that touches >1 segment ID is
      # set to background BEFORE affinity computation.
      #
      # For tightly-packed axons/neurons at 18 nm, erosion=1 (3×3 window) is
      # the standard choice — it creates a 1-voxel gap at every instance boundary,
      # making the affinity signal cleaner and reducing false merges.
      # Increase to erosion=2 if boundaries are still ambiguous after training.
      erosion: 1
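      # Rough numpy approximation of that rule (a sketch, not SegErosionInstanced
      # itself; this max/min variant also erodes segment/background borders,
      # which the real transform may treat differently):
      #   from scipy.ndimage import maximum_filter, minimum_filter
      #   win = (1, 3, 3)  # 3×3 XY window per z-slice (N=1)
      #   boundary = maximum_filter(seg, win) != minimum_filter(seg, win)
      #   seg_eroded = np.where(boundary, 0, seg)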

      # 9-channel affinity: distances 1, 3, 9 along each axis (z, y, x)
      targets:
        - name: affinity
          kwargs:
            offsets:
              # Distance 1 (short-range) — primary connectivity signal
              - "1-0-0"  # ch 0: z
              - "0-1-0"  # ch 1: y
              - "0-0-1"  # ch 2: x
              # Distance 3 (mid-range) — bridges small gaps
              - "3-0-0"  # ch 3: z
              - "0-3-0"  # ch 4: y
              - "0-0-3"  # ch 5: x
              # Distance 9 (long-range) — long-range context for agglomeration
              - "9-0-0"  # ch 6: z
              - "0-9-0"  # ch 7: y
              - "0-0-9"  # ch 8: x
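      # How one offset becomes an affinity channel, as a hypothetical numpy
      # sketch (assumed convention: 1 where a voxel and its offset neighbour
      # share the same nonzero instance ID, 0 elsewhere and at volume edges):
      #   def affinity_channel(seg, dz, dy, dx):
      #       z, y, x = seg.shape
      #       a, b = seg[dz:, dy:, dx:], seg[:z - dz, :y - dy, :x - dx]
      #       ch = np.zeros(seg.shape, dtype=np.float32)
      #       ch[dz:, dy:, dx:] = (a == b) & (a != 0)
      #       return ch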

    # ── Data paths ───────────────────────────────────────────────────────────
    # NOTE: split_enabled only works with HDF5/TIFF, not Zarr (raises ValueError).
    # LazyZarrVolumeDataset samples random patches from the full volume extent.
    # Both train and val use the same Zarr — val patches are drawn with a different
    # random seed (set_epoch is called per epoch). For a strict spatial split,
    # create separate Zarr sub-stores (e.g. zarr.open(...)[0:636]) offline.
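    # Hypothetical offline split (output path and zarr-python v2 API assumed):
    #   import zarr
    #   src = zarr.open(".../liconn_18nm_labelled.zarr", mode="r")
    #   dst = zarr.open(".../liconn_18nm_train.zarr", mode="w")
    #   for k in ("img", "label"):  # Z slices 0-635 -> train store
    #       dst.create_dataset(k, data=src[k][0:636], chunks=(64, 64, 64))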
    train:
      image: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/img
      label: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/label
      resolution: [25, 18, 18]  # [z_nm, y_nm, x_nm] — anisotropic: 25 nm z, 18 nm xy

    val:
      image: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/img
      label: /orcd/scratch/bcs/002/mansour/train_liconn/liconn_18nm_labelled_cropped.zarr/label
      resolution: [25, 18, 18]

    # Data is already ZYX — no transpose needed
    # data_transform:
    #   train_transpose: [2, 1, 0]  # only needed if the Zarr is stored XYZ

    dataloader:
      # Patch must be ≥ 9 voxels in every axis (the largest affinity offset).
      # 64×256×256 = 4,194,304 voxels — good axon context (10–50 axons per crop at 18 nm).
      # Reducing crop size hurts long-range affinity training more than reducing batch size.
      patch_size: [64, 256, 256]
      # batch_size=1 halves GPU memory vs batch_size=2; compensate with accumulate_grad_batches=4
      # to keep effective batch = 4 (same as batch_size=2 × accumulate=2).
      batch_size: 1

  model:
    input_size: [64, 256, 256]
    output_size: [64, 256, 256]

  optimization:
    profile: warmup_cosine_lr
    max_epochs: 500
    n_steps_per_epoch: 200
    # Gradient accumulation: effective batch = batch_size × accumulate_grad_batches.
    # batch_size=1, accumulate=4 → effective batch = 4 per GPU (8 globally with
    # num_gpus=2 below). A larger effective batch reduces gradient noise from
    # individual mislabeled patches.
    accumulate_grad_batches: 4
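    # Roughly equivalent manual loop, for intuition (sketch only; the trainer
    # handles this internally under automatic optimization):
    #   for i, batch in enumerate(loader):
    #       (loss_fn(model(batch)) / 4).backward()  # scale by accumulate=4
    #       if (i + 1) % 4 == 0:
    #           optimizer.step(); optimizer.zero_grad()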
    # EMA smooths weight updates over time — reduces sensitivity to noisy/imperfect
    # labels by averaging model weights across recent steps (decay=0.999 from profile).
    # Already enabled in the warmup_cosine_lr profile (ema.enabled: true, decay: 0.999).
    # Validation uses EMA weights (validate_with_ema: true) for best generalization.
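    # Schematic EMA update per optimizer step (the standard rule, shown for
    # reference; names are illustrative):
    #   for p_ema, p in zip(ema_model.parameters(), model.parameters()):
    #       p_ema.mul_(0.999).add_(p, alpha=1.0 - 0.999)  # decay = 0.999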

  system:
    num_gpus: 2
    num_workers: 16  # num_cpus is not in SystemConfig; use num_workers for dataloader workers
    seed: 42

  monitor:
    logging:
      scalar:
        loss: [train_loss_total_epoch, val_loss_total]
        loss_every_n_steps: 100
        val_check_interval: 10
      images:
        # dtype cast fixed: _log_multi_channel_viz and _log_single_channel_viz now
        # explicitly cast all tensors to float32 before make_grid (uint8 label channels
        # and bfloat16 model outputs both caused a RuntimeError previously).
        enabled: true
        max_images: 8
        num_slices: 4
        log_every_n_epochs: 5
        channel_mode: all
    checkpoint:
      dirpath: /orcd/scratch/bcs/002/mansour/train_liconn/outputs/checkpoints/
      monitor: val_loss_total
      mode: min
      save_top_k: 3

# ── Test / inference stage ─────────────────────────────────────────────────────
test:
  data:
    # Original test volume: BA_5AA_proteintest_1 FFN-sharpened, uint8,
    # shape (59, 1024, 1024) ZYX — its Z=59 is smaller than the training patch (64),
    # hence window_size: [32, 256, 256] below so the sliding window fits in Z.
    # The active image is the DL288B crop (512×1024×1024 per its filename).
    # No label available — inference only (evaluation.enabled: false).
    test:
      # image: /orcd/data/edboyden/002/mansour8/deb_data/BA_5AA_proteintest_1_2026-02-18_17.25.37_channel1_ffn_sharp.zarr
      image: /orcd/data/edboyden/002/dleible/DL288B/DL288B_251222S_cond5_40x_12tiles_round1_fused_488_crop512x1024x1024.tif
      # label: omitted — no ground truth available for this volume
      resolution: [25, 18, 18]  # same voxel size as the training data

  inference:
    sliding_window:
      # window_size Z=32 (< BA_5AA volume Z=59) — the model was trained on Z=64
      # but handles smaller windows; Z=32 avoids padding artifacts at volume
      # boundaries. Y/X=256 matches the training receptive field for best quality.
      window_size: [32, 256, 256]
      overlap: 0.5
      blending: gaussian
      keep_input_on_cpu: true
      sw_device: cuda
      output_device: cpu
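    # Roughly what this block configures, if the runner wraps MONAI's inferer
    # (illustrative sketch; the actual wiring lives in the pipeline):
    #   from monai.inferers import sliding_window_inference
    #   aff = sliding_window_inference(
    #       vol,  # (1, 1, Z, Y, X) float tensor on CPU
    #       roi_size=(32, 256, 256), sw_batch_size=4, predictor=model,
    #       overlap=0.5, mode="gaussian", sw_device="cuda", device="cpu")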

    save_prediction:
      enabled: true
      output_formats: [h5]
      output_path: /orcd/scratch/bcs/002/mansour/train_liconn/outputs/results_axons/
      # float32 for full-precision affinities; use float16 to halve storage
      intensity_scale: -1.0
      intensity_dtype: float32
      compression: gzip

    # ── Decoding ─────────────────────────────────────────────────────────────
    # Option A: ABISS external pipeline — best quality; requires ABISS to be built.
    # Falls back to connected components automatically if ABISS fails.
    #
    # ABISS binary: abiss/build/ws (built from abiss/ source via abiss/build_ws.sh)
    # Intel TBB runtime: /orcd/software/community/001/rocky8/intel/2024.2.1/tbb/2021.13/lib
    # must be on LD_LIBRARY_PATH at runtime (the ws binary links against libtbb.so.12).
    #
    # Thresholds for axon segmentation at 18 nm:
    #   ws_high_threshold=0.9 — only seed watershed at very high affinity (confident object interiors)
    #   ws_low_threshold=0.3 — extend seeds down to moderate affinity
    #   ws_size_threshold=100 — discard fragments smaller than 100 voxels
    decoding:
      - name: decode_abiss
        kwargs:
          command: >
            env LD_LIBRARY_PATH=/orcd/software/community/001/rocky8/intel/2024.2.1/tbb/2021.13/lib:$LD_LIBRARY_PATH
            {python_exe} scripts/run_abiss_single.py
            --input {input_h5}
            --output {output_h5}
            --abiss-home /home/mansour8/pytc_dev/pytorch_connectomics/abiss
            --ws-high-threshold 0.9
            --ws-low-threshold 0.3
            --ws-size-threshold 100
          input_dataset: main
          output_dataset: main
          timeout_sec: 3600

    # Option B: lightweight connected components (no ABISS required).
    # Only uses short-range affinities (ch 0-2); good for quick inspection.
    # Uncomment this and comment out Option A above to use it.
    # decoding:
    #   - name: decode_affinity_cc
    #     kwargs:
    #       threshold: 0.5
    #       backend: auto
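    # What the fallback presumably amounts to (hypothetical numpy/scipy sketch
    # of connected components on thresholded short-range affinities):
    #   from scipy.ndimage import label
    #   fg = aff[0:3].min(axis=0) > 0.5  # voxel kept only if connected on all 3 axes
    #   seg, n = label(fg)               # default structure = 6-connectivity in 3D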

    evaluation:
      enabled: false  # no ground truth — cannot compute metrics