Skip to content

Commit 82424d4

Browse files
author
Donglai Wei
committed
Fix label transform resolution and batched affinity masks
1 parent af62e50 commit 82424d4

4 files changed

Lines changed: 86 additions & 0 deletions

File tree

connectomics/config/schema/data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class LabelTransformConfig:
7878
allow_missing_keys: bool = False
7979
segment_id: Optional[List[int]] = None
8080
boundary_thickness: int = 1
81+
resolution: Optional[List[float]] = None # Forwarded into compatible label targets.
8182
targets: List[Any] = field(default_factory=list)
8283

8384

connectomics/models/loss/losses.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ def _reduce_weighted_tensor(
3131
return loss_tensor.mean()
3232

3333
valid = weight > 0
34+
if valid.shape != loss_tensor.shape:
35+
try:
36+
valid = torch.broadcast_to(valid, loss_tensor.shape)
37+
except RuntimeError as e:
38+
raise ValueError(
39+
"Weight mask shape is not broadcastable to loss tensor shape: "
40+
f"weight={tuple(weight.shape)}, loss={tuple(loss_tensor.shape)}"
41+
) from e
3442
if not torch.any(valid):
3543
return loss_tensor.new_tensor(0.0)
3644
return loss_tensor[valid].mean()

tests/unit/test_hydra_config.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from connectomics.config.schema import TestConfig as HydraTestConfig # noqa: E402
3131
from connectomics.config.schema import TuneConfig # noqa: E402
3232
from connectomics.data.augment.build import build_test_transforms # noqa: E402
33+
from connectomics.data.process import create_label_transform_pipeline # noqa: E402
3334

3435

3536
def test_default_config_creation():
@@ -467,6 +468,35 @@ def test_yaml_shared_profile_selectors(tmp_path):
467468
print("[OK]YAML shared profile selectors work")
468469

469470

471+
def test_label_transform_profile_can_inherit_top_level_resolution(tmp_path):
    """A top-level ``resolution`` on label_transform is forwarded into profile targets.

    Loads the shared profile base, overrides ``resolution`` at the
    label_transform level, and checks that both the merged config and the
    built pipeline see the value (coerced to floats).
    """
    repo_root = Path(__file__).resolve().parents[2]
    profiles_path = repo_root / "tutorials" / "bases" / "all_profiles.yaml"

    cfg_path = tmp_path / "config.yaml"
    cfg_path.write_text(
        f"""
_base_:
  - {profiles_path}

default:
  data:
    label_transform:
      profile: label_affinity_9_sdt
      resolution: [30, 6, 6]
""".strip()
    )

    cfg = load_config(cfg_path)

    label_tf = cfg.default.data.label_transform
    # The profile expands into two targets; resolution survives the merge as floats.
    assert len(label_tf.targets) == 2
    assert label_tf.targets[0]["name"] == "affinity"
    assert label_tf.targets[1]["name"] == "skeleton_aware_edt"
    assert label_tf.resolution == [30.0, 6.0, 6.0]

    # The pipeline must forward the resolution into the SDT target's kwargs.
    pipeline = create_label_transform_pipeline(label_tf)
    assert pipeline.task_specs[1]["kwargs"]["resolution"] == [30.0, 6.0, 6.0]
470500
def test_arch_profile_rejects_non_model_sections(tmp_path):
471501
"""Arch profiles with invalid keys are rejected by OmegaConf schema merge."""
472502
base_yaml = tmp_path / "base.yaml"

tests/unit/test_loss_orchestrator.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,41 @@ def test_standard_loss_supports_affinity_deepem_crop():
325325
assert weight[0, 2, :, :, 1].sum() > 0 # x=1 valid for channel 2
326326

327327

328+
def test_standard_loss_supports_batched_affinity_deepem_crop_with_weighted_bce():
    """Batched (B>1) affinity targets with deepem_crop work under weighted BCE mean.

    With all-zero logits and all-zero labels, every valid element contributes
    BCE(0, 0) = log(2), so the weighted mean must equal log(2) exactly.
    """
    cfg = _cfg(losses=[{"weight": 1.0}])

    affinity_target = {
        "name": "affinity",
        "kwargs": {
            "offsets": ["1-0-0", "0-1-0", "0-0-1"],
            "deepem_crop": True,
        },
    }
    cfg.data.label_transform.targets = [affinity_target]

    loss_fns = nn.ModuleList([WeightedBCEWithLogitsLoss(reduction="mean")])
    orchestrator = LossOrchestrator(
        cfg=cfg,
        loss_functions=loss_fns,
        loss_weights=[1.0],
        enable_nan_detection=False,
        debug_on_nan=False,
    )

    # Batch of 4 volumes, 3 affinity channels, 5^3 voxels, all zeros.
    shape = (4, 3, 5, 5, 5)
    outputs = torch.zeros(*shape)
    labels = torch.zeros(*shape)

    total_loss, loss_dict = orchestrator.compute_standard_loss(outputs, labels, stage="train")

    assert torch.isfinite(total_loss)
    assert "train_loss_total" in loss_dict
    assert torch.isclose(total_loss, torch.log(torch.tensor(2.0)))
328363
def test_standard_loss_rejects_invalid_runtime_negative_channel_slice():
329364
weighted_loss = WeightAwareSpyLoss()
330365
orchestrator = LossOrchestrator(
@@ -630,6 +665,18 @@ def test_weighted_bce_returns_zero_when_no_valid_weights():
630665
assert torch.isclose(loss, torch.tensor(0.0))
631666

632667

668+
def test_weighted_bce_mean_supports_broadcast_batchless_weight_mask():
    """A weight mask without a batch dimension broadcasts against batched inputs.

    Zero logits give BCE(0, y) = log(2) per element regardless of the target;
    only the weight-2.0 positions are valid, so the weighted mean is 2*log(2).
    """
    loss_fn = WeightedBCEWithLogitsLoss(reduction="mean")

    logits = torch.zeros(4, 1, 1, 1, 2)
    # Target pattern [1, 0] replicated across the batch via expand (no copy).
    targets = torch.tensor([[[[[1.0, 0.0]]]]]).expand_as(logits)
    # Batch-less mask: first position weighted 2.0, second masked out.
    mask = torch.tensor([[[[[2.0, 0.0]]]]])

    result = loss_fn(logits, targets, weight=mask)

    assert torch.isclose(result, 2.0 * torch.log(torch.tensor(2.0)))
679+
633680
def test_multitask_single_scale_routes_class_index_and_dense_targets():
634681
ce_loss = CrossEntropyLossWrapperSpy()
635682
reg_loss = WeightAwareSpyLoss()

0 commit comments — Comments (0)