
Commit ee82c46

Fix loss balancing wiring and SNEMI SDT config
Donglai Wei committed
1 parent 69c315c · commit ee82c46

3 files changed

Lines changed: 217 additions & 2 deletions

File tree

connectomics/training/loss/balancing.py
tests/unit/test_loss_balancing.py
tutorials/neuron_snemi_sdt.yaml

connectomics/training/loss/balancing.py

Lines changed: 9 additions & 2 deletions
@@ -188,9 +188,16 @@ def build_loss_weighter(
     if not hasattr(cfg, "model") or not hasattr(cfg.model, "loss"):
         return None

-    lb_cfg = getattr(cfg.model, "loss", None)
-    if lb_cfg is None:
+    loss_cfg = getattr(cfg.model, "loss", None)
+    if loss_cfg is None:
         return None
+
+    # Prefer the schema-defined nested loss_balancing block, but keep support
+    # for older flat configs that placed strategy fields directly under model.loss.
+    lb_cfg = getattr(loss_cfg, "loss_balancing", None)
+    if lb_cfg is None or getattr(lb_cfg, "strategy", None) is None:
+        lb_cfg = loss_cfg
+
     strategy = getattr(lb_cfg, "strategy", None)
     if strategy is None:
         return None
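
With this change, build_loss_weighter first looks for the nested block at model.loss.loss_balancing and only falls back to the old flat layout when that block (or its strategy field) is missing. A minimal sketch of the nested form in the tutorials' YAML style, assuming only the field names exercised by the new unit tests (strategy, gradnorm_alpha, gradnorm_lambda, gradnorm_parameter_strategy); the surrounding keys are illustrative, not the full schema:

model:
  loss:
    profile: loss_bd                       # base loss profile, as in the tutorial below
    loss_balancing:
      strategy: gradnorm                   # "gradnorm" or "uncertainty" (both covered by the tests)
      gradnorm_alpha: 0.25
      gradnorm_lambda: 2.5
      gradnorm_parameter_strategy: first   # "first" or "last" shared-parameter selection

The older flat layout, with the same fields placed directly under model.loss, still resolves through the fallback branch above.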

tests/unit/test_loss_balancing.py

Lines changed: 69 additions & 0 deletions
from types import SimpleNamespace

import torch.nn as nn

from connectomics.training.loss import (
    GradNormLossWeighter,
    UncertaintyLossWeighter,
    build_loss_weighter,
)


def _cfg(
    strategy=None,
    *,
    gradnorm_alpha=0.5,
    gradnorm_lambda=1.0,
    gradnorm_parameter_strategy="last",
    legacy_flat=False,
):
    if legacy_flat:
        loss = SimpleNamespace(
            strategy=strategy,
            gradnorm_alpha=gradnorm_alpha,
            gradnorm_lambda=gradnorm_lambda,
            gradnorm_parameter_strategy=gradnorm_parameter_strategy,
        )
    else:
        loss = SimpleNamespace(
            loss_balancing=SimpleNamespace(
                strategy=strategy,
                gradnorm_alpha=gradnorm_alpha,
                gradnorm_lambda=gradnorm_lambda,
                gradnorm_parameter_strategy=gradnorm_parameter_strategy,
            )
        )
    return SimpleNamespace(model=SimpleNamespace(loss=loss))


def test_build_loss_weighter_uses_nested_uncertainty_strategy():
    weighter = build_loss_weighter(_cfg(strategy="uncertainty"), num_tasks=3)

    assert isinstance(weighter, UncertaintyLossWeighter)


def test_build_loss_weighter_uses_nested_gradnorm_settings():
    model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

    weighter = build_loss_weighter(
        _cfg(
            strategy="gradnorm",
            gradnorm_alpha=0.25,
            gradnorm_lambda=2.5,
            gradnorm_parameter_strategy="first",
        ),
        num_tasks=3,
        model=model,
    )

    assert isinstance(weighter, GradNormLossWeighter)
    assert weighter.alpha == 0.25
    assert weighter.gradnorm_lambda == 2.5
    assert len(weighter.shared_parameters) == 1
    assert weighter.shared_parameters[0] is next(model.parameters())


def test_build_loss_weighter_keeps_legacy_flat_strategy_support():
    weighter = build_loss_weighter(_cfg(strategy="uncertainty", legacy_flat=True), num_tasks=2)

    assert isinstance(weighter, UncertaintyLossWeighter)

tutorials/neuron_snemi_sdt.yaml

Lines changed: 139 additions & 0 deletions
experiment_name: rsunet_snemi_lee2017_modern_sdt
description: SNEMI3D neuron affinity + SDT learning (9-channel affinity + 1-channel SDT)

_base_:
  - bases/all_profiles.yaml

default:
  system:
    profile: all-gpu-cpu
  model:
    arch:
      profile: rsunet
      input_size: [16, 224, 224]
      output_size: [16, 224, 224]
      out_channels: 10
    loss:
      profile: loss_bd
  data:
    label_transform:
      profile: label_affinity_9_sdt
    resolution: [30, 6, 6]
    dataloader:
      profile: cached
      patch_size: [16, 224, 224]
    data_transform:
      # Keep symmetric full-volume context padding on the inference input.
      pad_size: [17, 128, 128]
    augmentation:
      profile: aug_em_neuron
  inference:
    sliding_window:
      window_size: [16, 224, 224]
      sw_batch_size: 1
      keep_input_on_cpu: false
    test_time_augmentation:
      enabled: false
      #enabled: true
      patch_first_local: true
      flip_axes: all
      rotation90_axes: [[1, 2]]
      activation_profile: act_bd
      #select_channel: [0, 1, 2, 9]
      ensemble_mode: [["0:9", min], ["9:", mean]]
    postprocessing:
      enabled: true
      # crop_pad + affinity_crop[(17,0),(17,0),(17,0)] = pad_size [17,128,128]:
      # Z: 0+17=17, 17+0=17 | Y/X: 111+17=128, 128+0=128
      crop_pad: [0, 17, 111, 128, 111, 128]
    save_prediction:
      enabled: true
    decoding_profile: decoding_waterz
    evaluation:
      enabled: true
      metrics: [adapted_rand]


train:
  data:
    train:
      image: datasets/SNEMI/train-input.tif
      label: datasets/SNEMI/train-labels.tif

  optimization:
    profile: warmup_cosine_lr
    max_epochs: 200
    n_steps_per_epoch: 1000
  monitor:
    logging:
      scalar:
        loss: [train_loss_total_epoch, train_loss_affinity_total, train_loss_sdt_total]
        loss_every_n_steps: 50
      images:
        log_every_n_epochs: 10
        max_images: 8
        num_slices: 2
        channel_mode: all
    checkpoint:
      save_top_k: 3
      monitor: train_loss_total_epoch
      mode: min

test:
  data:
    test:
      path: datasets/SNEMI/
      # image: [train-input.tif, test-input_z29.h5]
      # label: [train-labels.tif, test-labels.h5]
      #image: train-input.tif
      #label: train-labels.tif
      #image: test-input_z29.h5
      image: test-input.tif
      label: test-labels.h5
      resolution: [30, 6, 6]
  inference:
    decoding:
      - profile: decoding_waterz
        kwargs:
          thresholds: 0.4
          merge_function: aff85_his256
          aff_threshold: [0.001, 0.999]

# ============================================================================
# Parameter tuning for waterz agglomeration thresholds (--mode tune)
# ============================================================================
tune:
  profile: tune_waterz
  n_trials: 25
  study_name: snemi_waterz_tuning
  data:
    val:
      image: datasets/SNEMI/test-input.tif
      label: datasets/SNEMI/test-labels.h5
      #image: datasets/SNEMI/train-input.tif
      #label: datasets/SNEMI/train-labels.tif
  # Override profile defaults for SNEMI-specific search ranges
  parameter_space:
    decoding:
      defaults:
        thresholds: 0.4
        merge_function: aff85_his256
        aff_threshold: [0.001, 0.999]
      parameters:
        merge_function:
          type: categorical
          choices: [aff85_his256, aff75_his256, aff50_his256]
          description: "Agglomeration scoring function (quantile via histogram)"
        thresholds:
          range: [0.1, 0.9]
          step: 0.05
        aff_threshold_low:
          range: [0.001, 0.3]
          step: 0.01
          param_group: aff_threshold
          tuple_index: 0
        aff_threshold_high:
          range: [0.8, 0.9999]
          step: 0.01
          param_group: aff_threshold
          tuple_index: 1