Skip to content

Commit 1b356aa

Browse files
Donglai Weiclaude
andcommitted
tutorials/neuron_nisb: SDT/clDice/MALIS-ft/x2xy experiment configs + schedule cleanup
New configs for the thin-neurite training-side line: - base_banis+_cldice.yaml: single-head 8ch (aff+fg+SDT) with SoftClDice on fg - base_banis+_sdt_1head.yaml: single-head 7ch variant of base_banis+_sdt - base_banis+_malis_ft.yaml: +100k MALIS finetune of the v3_erosion2 200k ckpt - base_banis+_x2xy.yaml: 2x-XY inference diagnostic (no retrain) Edits: - base_banis+_sdt.yaml: erosion=0 (erosion>=2 erases thin connectors from both affinity and SDT targets, defeating the SDT line) - base_banis_opt.yaml: bake in the 200k schedule (max_steps + cosine t_max) so the whole banis+ line inherits it - base_banis+.yaml: drop the redundant mednext_l arch block (now inherited) Also move dev/nisb/v3_erosion2_oracle_lut.py into dev/nisb/scripts/. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 429d983 commit 1b356aa

8 files changed

Lines changed: 223 additions & 6 deletions

File tree

File renamed without changes.

tutorials/neuron_nisb/base_banis+.yaml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@ description: >-
99
save_path: outputs/nisb_base_banis_v3_erosion2
1010

1111
default:
12-
model:
13-
arch:
14-
profile: mednext_l
15-
mednext:
16-
size: L
17-
kernel_size: 3
1812
data:
1913
dataloader:
2014
batch_size: 2
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
_base_:
2+
- base_banis+.yaml
3+
4+
experiment_name: nisb_base_banis_plus_cldice
5+
description: >-
6+
banis+ (MedNeXt-L/k3) single-head, 8-channel output: 6 affinity (r1+r10) +
7+
1 foreground + 1 skeleton-aware SDT. Adds a connectivity-aware SoftClDice loss
8+
on the foreground channel to discourage breaks at thin neurites (the per-voxel
9+
affinity BCE under-penalizes a single missing voxel on a 1-voxel-wide
10+
connector, which is exactly what fragments long backbones and caps NERL).
11+
Affinity decode path is unchanged (channels 0:3). erosion=2, bf16-mixed,
12+
cosine 200k inherited.
13+
14+
save_path: outputs/nisb_base_banis_plus_cldice
15+
16+
default:
17+
model:
18+
# Single shared head: backbone emits all 8 channels (no `heads:` dict).
19+
out_channels: 8
20+
loss:
21+
deep_supervision: false
22+
losses:
23+
# Affinity (decode signal) — unchanged from banis+/v1.
24+
- function: PerChannelBCEWithLogitsLoss
25+
weight: 1.0
26+
pred_slice: "0:6"
27+
target_slice: "0:6"
28+
kwargs:
29+
auto_pos_weight: true
30+
max_pos_weight: 10.0
31+
# Foreground head: plain BCE for calibration ...
32+
- function: WeightedBCEWithLogitsLoss
33+
weight: 1.0
34+
pred_slice: "6:7"
35+
target_slice: "6:7"
36+
# ... plus connectivity-aware clDice on the same foreground channel.
37+
- function: SoftClDiceLoss
38+
weight: 0.5
39+
pred_slice: "6:7"
40+
target_slice: "6:7"
41+
kwargs:
42+
sigmoid: true
43+
num_iters: 5
44+
mode: binary
45+
# Skeleton-aware SDT (auxiliary structure signal).
46+
- function: SmoothL1Loss
47+
weight: 10.0
48+
pred_slice: "7:8"
49+
target_slice: "7:8"
50+
kwargs:
51+
tanh: true
52+
data:
53+
label_transform:
54+
erosion: 0 # do NOT erode GT (erosion>=2 erases thin neurites from all targets)
55+
targets:
56+
- name: affinity
57+
kwargs:
58+
offsets: ["1-0-0", "0-1-0", "0-0-1", "10-0-0", "0-10-0", "0-0-10"]
59+
affinity_mode: banis
60+
- name: binary
61+
- name: skeleton_aware_edt
62+
kwargs:
63+
alpha: 0.8
64+
bg_value: -1.0
65+
66+
train:
67+
optimization:
68+
precision: "bf16-mixed"
69+
monitor:
70+
logging:
71+
scalar:
72+
loss:
73+
- train_loss_total_epoch
74+
- val_loss_total
75+
- train_loss_term_0_weighted
76+
- train_loss_term_1_weighted
77+
- train_loss_term_2_weighted
78+
- train_loss_term_3_weighted
79+
loss_every_n_steps: 100
80+
images:
81+
channel_mode: all
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
_base_:
2+
- base_banis+_malis.yaml
3+
4+
experiment_name: nisb_base_banis_v3_erosion2_malis_ft
5+
description: >-
6+
Finetune the v3_erosion2 200k MedNeXt-L affinity checkpoint (PerChannelBCE
7+
only, no MALIS) for an additional 100k steps with the short-range MALIS loss
8+
added. Weights are loaded init-only (no optimizer/scheduler/EMA state) so the
9+
MALIS term starts from the converged affinity model and a fresh, smaller-LR
10+
cosine schedule anneals over the 100k finetune.
11+
12+
save_path: outputs/nisb_base_banis_v3_erosion2_malis_ft
13+
14+
default:
15+
model:
16+
# Init-only load of the base_banis+ (v3_erosion2) 200k weights. Architecture
17+
# is identical to the MALIS variant (MALIS only adds a loss over existing
18+
# channels). This is an in-framework Lightning checkpoint whose state_dict
19+
# keys are double-nested ("model.model.<mednext>": ConnectomicsModule.model
20+
# = wrapper, wrapper.model = MedNeXt), so strip "model.model." (not the
21+
# default "model.") to land on the bare MedNeXt keys. Verified 0 missing /
22+
# 0 unexpected keys.
23+
external_weights_path: outputs/nisb_base_banis_v3_erosion2/20260508_224029/checkpoints/step=00200000.ckpt
24+
external_weights_key_prefix: "model.model."
25+
26+
train:
27+
optimization:
28+
# 100k finetune with a 10x-smaller peak LR than the 1e-3 base run; cosine
29+
# period matched to the finetune length so LR anneals over the full run.
30+
optimizer:
31+
lr: 0.0001
32+
max_steps: 100000
33+
scheduler:
34+
params:
35+
t_max: 100000

tutorials/neuron_nisb/base_banis+_sdt.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ default:
4545
tanh: true
4646
data:
4747
label_transform:
48+
# erosion=0: do NOT erode GT. erosion>=2 erases thin neurites (<=~4px) from BOTH the
49+
# affinity and SDT targets -> the model can never learn thin connectors. Preserving them is
50+
# the whole point of the SDT/thin-neurite line (overrides banis+ erosion=2).
51+
erosion: 0
4852
targets:
4953
- name: affinity
5054
kwargs:
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
_base_:
2+
- base_banis+.yaml
3+
4+
experiment_name: nisb_base_banis_plus_sdt_1head
5+
description: >-
6+
Single-head variant of base_banis+_sdt: banis+ (MedNeXt-L/k3) backbone emits
7+
one 7-channel tensor (6 affinity r1+r10 + 1 skeleton-aware SDT) directly off
8+
the shared trunk, instead of two separate task-head branches. Losses are
9+
routed by channel via pred_slice/target_slice: PerChannelBCE on 0:6, SmoothL1
10+
on 6:7. Efficiency: removes the per-head conv branches, uses bf16-mixed (no
11+
fp16 grad-scaler overhead). erosion=2, label erosion + cosine 200k inherited.
12+
13+
save_path: outputs/nisb_base_banis_plus_sdt_1head
14+
15+
default:
16+
model:
17+
# Single shared head: backbone outputs all 7 channels (no `heads:` dict).
18+
out_channels: 7
19+
loss:
20+
deep_supervision: false
21+
losses:
22+
- function: PerChannelBCEWithLogitsLoss
23+
weight: 1.0
24+
pred_slice: "0:6"
25+
target_slice: "0:6"
26+
kwargs:
27+
auto_pos_weight: true
28+
max_pos_weight: 10.0
29+
- function: SmoothL1Loss
30+
weight: 10.0
31+
pred_slice: "6:7"
32+
target_slice: "6:7"
33+
kwargs:
34+
tanh: true
35+
data:
36+
label_transform:
37+
erosion: 0 # do NOT erode GT (erosion>=2 erases thin neurites from both targets)
38+
targets:
39+
- name: affinity
40+
kwargs:
41+
offsets: ["1-0-0", "0-1-0", "0-0-1", "10-0-0", "0-10-0", "0-0-10"]
42+
affinity_mode: banis
43+
- name: skeleton_aware_edt
44+
kwargs:
45+
alpha: 0.8
46+
bg_value: -1.0
47+
48+
train:
49+
optimization:
50+
precision: "bf16-mixed"
51+
monitor:
52+
logging:
53+
scalar:
54+
loss:
55+
- train_loss_total_epoch
56+
- val_loss_total
57+
- train_loss_term_0_weighted
58+
- train_loss_term_1_weighted
59+
loss_every_n_steps: 100
60+
images:
61+
channel_mode: all
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
_base_:
2+
- base_banis+.yaml
3+
4+
experiment_name: nisb_base_banis_v3_erosion2_x2xy
5+
description: >-
6+
2x-XY inference of the trained base_banis+ (MedNeXt-L, 200k) for the thin-neurite
7+
diagnostic (dev/nisb/deepresearch_thin_neurite_findings.md exp-2 step-1): does finer in-plane
8+
resolution recover the thin short-range XY affinity that shatters into crumbs? NO retraining --
9+
same 200k checkpoint, fed 2x-XY-upsampled tiles. Lazy upsample: read native [64,64,128] (X,Y,Z)
10+
patches, upsample XY 2x -> model sees [128,128,128]; Z stays native (anisotropic, 20nm).
11+
Chunked raw-prediction output so the 6000x6000x1350 affinity never materializes whole.
12+
13+
save_path: outputs/nisb_base_banis_v3_erosion2_x2xy
14+
15+
default:
16+
data:
17+
dataloader:
18+
# native read; the two 9nm in-plane axes (0,1) are halved then upsampled 2x.
19+
patch_size: [64, 64, 128]
20+
data_transform:
21+
# model input size after the 2x-XY upsample (matches its 128^3 training window).
22+
resize: [128, 128, 128]
23+
inference:
24+
execution:
25+
strategy: chunked
26+
chunking:
27+
enabled: true
28+
output_mode: raw_prediction # we want the raw 2x affinity, not decoded
29+
axes: all
30+
# ZYX (post val_transpose) in the UPSAMPLED grid; small enough to fit memory.
31+
chunk_size: [256, 512, 512]
32+
halo: [16, 64, 64]
33+
save_intermediate: true
34+
window:
35+
window_size: [128, 128, 128]
36+
keep_input_on_cpu: true

tutorials/neuron_nisb/base_banis_opt.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ default:
2525

2626
train:
2727
optimization:
28+
# Longer schedule than the v0 base (50k): 200k steps with the cosine
29+
# period matched so LR anneals over the full run.
30+
max_steps: 200000
31+
scheduler:
32+
params:
33+
t_max: 200000
2834
ema:
2935
enabled: true
3036
decay: 0.999

0 commit comments

Comments
 (0)