Skip to content

Commit 799afd5

Browse files
authored
Merge pull request #211 from PytorchConnectomics/feat/malis-gt-passthrough
MalisLoss: pass eroded gt_seg through data pipeline (skip per-step CC, fix crop topology)
2 parents aea1feb + 99fc8f8 commit 799afd5

21 files changed

Lines changed: 656 additions & 39 deletions

connectomics/config/pipeline/config_io.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,7 @@ def _raise_unconsumed_keys(yaml_conf: DictConfig) -> None:
153153
"(sibling of `monitor`, `inference`, `decoding`, `tune`)."
154154
),
155155
"use_timestamp": (
156-
"field removed. Train mode is always timestamped; "
157-
"test/tune modes are never timestamped."
156+
"field removed. Train mode is always timestamped; " "test/tune modes are never timestamped."
158157
),
159158
}
160159
_MONITOR_CHECKPOINT_ROOTS = (
@@ -201,8 +200,7 @@ def _reject_inference_runtime_alias_paths(explicit_field_paths: set[str]) -> Non
201200
alias_path = f"{root}.{alias}"
202201
if any(_path_is_or_descendant(path, alias_path) for path in explicit_field_paths):
203202
raise ValueError(
204-
f"`{alias_path}` was renamed. "
205-
f"Use `{root}.{canonical_tail}` instead."
203+
f"`{alias_path}` was renamed. " f"Use `{root}.{canonical_tail}` instead."
206204
)
207205

208206
for root in _MONITOR_CHECKPOINT_ROOTS:
@@ -212,8 +210,7 @@ def _reject_inference_runtime_alias_paths(explicit_field_paths: set[str]) -> Non
212210
if replacement.startswith("field "):
213211
raise ValueError(f"`{alias_path}` {replacement}")
214212
raise ValueError(
215-
f"`{alias_path}` was renamed. "
216-
f"Use `{root}.{replacement}` instead."
213+
f"`{alias_path}` was renamed. " f"Use `{root}.{replacement}` instead."
217214
)
218215

219216
# tune.output:* sub-block hoisted to tune.save_*
@@ -506,7 +503,6 @@ def validate_config(cfg: Config) -> None:
506503
if cfg.model.out_channels <= 0:
507504
raise ValueError("model.out_channels must be positive")
508505
model_heads = getattr(cfg.model, "heads", None) or {}
509-
inference_cfg = getattr(cfg, "inference", None)
510506
inference_head = get_inference_model_value(cfg, "head", None)
511507
images_cfg = getattr(getattr(getattr(cfg, "monitor", None), "logging", None), "images", None)
512508
visualization_head = getattr(images_cfg, "head", None) if images_cfg is not None else None
@@ -556,8 +552,8 @@ def validate_config(cfg: Config) -> None:
556552
missing = [h for h in inference_head_names if h not in model_heads]
557553
if missing:
558554
raise ValueError(
559-
f"inference.model.head={inference_head_names} references unknown heads {missing}; "
560-
f"available: {sorted(model_heads.keys())}."
555+
f"inference.model.head={inference_head_names} references unknown heads "
556+
f"{missing}; available: {sorted(model_heads.keys())}."
561557
)
562558
if (
563559
visualization_head is not None
@@ -909,6 +905,11 @@ def _resolve_split_paths(split_cfg):
909905
split_cfg.image = _combine_path(split_base, split_cfg.image)
910906
split_cfg.label = _combine_path(split_base, split_cfg.label)
911907
split_cfg.mask = _combine_path(split_base, split_cfg.mask)
908+
split_json_resolved = _combine_path(split_base, split_cfg.json)
909+
if isinstance(split_json_resolved, list):
910+
split_cfg.json = split_json_resolved[0] if split_json_resolved else None
911+
else:
912+
split_cfg.json = split_json_resolved
912913

913914
# Resolve inference/test paths from merged runtime cfg.data.
914915
if getattr(cfg.data, "test", None) is not None:

connectomics/config/schema/data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class LabelTransformConfig:
6666

6767
normalize: bool = True # Convert labels to 0-1 range
6868
erosion: int = 0 # Border erosion kernel half-size (0 = disabled, uses seg_widen_border)
69+
emit_gt_seg: bool = False # Emit post-augmentation/post-erosion segmentation for MALIS.
6970
skeleton_distance: SkeletonDistanceConfig = field(default_factory=SkeletonDistanceConfig)
7071
edge_mode: EdgeModeConfig = field(
7172
default_factory=EdgeModeConfig

connectomics/data/augmentation/build.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
BorderPadd,
1616
CenterSpatialCropd,
1717
Compose,
18+
CopyItemsd,
1819
Lambdad,
1920
OneOf,
2021
RandAdjustContrastd,
@@ -182,7 +183,7 @@ def _build_nnunet_preprocess_transform(keys, nnunet_pre_cfg, source_spacing):
182183

183184

184185
def build_train_transforms(
185-
cfg: Config, keys: list[str] = None, skip_loading: bool = False
186+
cfg: Config, keys: list[str] | None = None, skip_loading: bool = False
186187
) -> Compose:
187188
"""
188189
Build training transforms from Hydra config.
@@ -320,6 +321,8 @@ def build_train_transforms(
320321
# Apply instance erosion first if specified
321322
if hasattr(label_cfg, "erosion") and label_cfg.erosion > 0:
322323
transforms.append(SegErosionInstanced(keys=["label"], tsz_h=label_cfg.erosion))
324+
if label_cfg.emit_gt_seg:
325+
transforms.append(CopyItemsd(keys="label", names="gt_seg"))
323326

324327
# Build label transform pipeline directly from label_transform config
325328
label_transform = create_label_transform_pipeline(label_cfg)
@@ -348,7 +351,7 @@ def build_train_transforms(
348351

349352

350353
def _build_eval_transforms_impl(
351-
cfg: Config, mode: str = "val", keys: list[str] = None, skip_loading: bool = False
354+
cfg: Config, mode: str = "val", keys: list[str] | None = None, skip_loading: bool = False
352355
) -> Compose:
353356
"""
354357
Internal implementation for building evaluation transforms (validation or test).
@@ -669,6 +672,8 @@ def _resolve_eval_split():
669672
# Apply instance erosion first if specified
670673
if hasattr(label_cfg, "erosion") and label_cfg.erosion > 0:
671674
transforms.append(SegErosionInstanced(keys=["label"], tsz_h=label_cfg.erosion))
675+
if label_cfg.emit_gt_seg:
676+
transforms.append(CopyItemsd(keys="label", names="gt_seg"))
672677

673678
# Build label transform pipeline directly from label_transform config
674679
label_transform = create_label_transform_pipeline(label_cfg)
@@ -695,7 +700,7 @@ def _resolve_eval_split():
695700

696701

697702
def build_val_transforms(
698-
cfg: Config, keys: list[str] = None, skip_loading: bool = False
703+
cfg: Config, keys: list[str] | None = None, skip_loading: bool = False
699704
) -> Compose:
700705
"""
701706
Build validation transforms from Hydra config.
@@ -711,7 +716,9 @@ def build_val_transforms(
711716
return _build_eval_transforms_impl(cfg, mode="val", keys=keys, skip_loading=skip_loading)
712717

713718

714-
def build_test_transforms(cfg: Config, keys: list[str] = None, mode: str = "test") -> Compose:
719+
def build_test_transforms(
720+
cfg: Config, keys: list[str] | None = None, mode: str = "test"
721+
) -> Compose:
715722
"""
716723
Build test/tune inference transforms from Hydra config.
717724

connectomics/models/losses/malis.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ class MalisLoss(nn.Module):
2626
2D tensors are rejected explicitly because the vendored MALIS helpers operate
2727
on 3D affinity graphs by default. See ``lib/malis/INVESTIGATION.md`` for
2828
GPU MALIS candidates and algorithm-level speedup follow-ups.
29+
30+
Performance knobs (see ``docs/source/notes/malis.rst``):
31+
32+
- ``malis_crop_size`` — random sub-volume crop on each forward call.
33+
``64`` on a ``128^3`` patch gives ~4.6x measured step speedup vs
34+
the full-volume baseline (slurm 2505814 vs 2487040).
35+
- ``label_transform.emit_gt_seg: true`` (YAML, paired with this
36+
loss) — passes the eroded GT segmentation in via ``gt_seg=...``,
37+
skipping the per-step ``connected_components_affgraph`` call and
38+
preserving global instance IDs when ``malis_crop_size`` is set.
2939
"""
3040

3141
def __init__(
@@ -68,6 +78,7 @@ def forward(
6878
pred: torch.Tensor,
6979
target: torch.Tensor,
7080
mask: torch.Tensor | None = None,
81+
gt_seg: torch.Tensor | np.ndarray | None = None,
7182
) -> torch.Tensor:
7283
"""Compute MALIS-weighted squared affinity error.
7384
@@ -83,21 +94,31 @@ def forward(
8394
Masked-out edges are excluded from MALIS pass constraints and
8495
zeroed before per-pass normalization, but the mask does not
8596
change GT connected-component reconstruction.
97+
gt_seg: Optional ground-truth segmentation with shape ``[B, Z, Y, X]``
98+
or ``[B, 1, Z, Y, X]``. When supplied, MALIS uses these instance
99+
labels directly instead of reconstructing components from
100+
``target`` affinities.
86101
"""
87102
self._validate_inputs(pred, target)
88103

89104
pred_aff = torch.sigmoid(pred) if self.sigmoid else pred
90105
target_aff = target.to(device=pred.device, dtype=pred_aff.dtype)
91106
mask_aff = None if mask is None else self._prepare_mask(mask, pred_aff)
92-
pred_aff, target_aff, mask_aff = self._apply_crop_if_configured(
107+
gt_seg_tensor = self._prepare_gt_seg(gt_seg, pred_aff)
108+
pred_aff, target_aff, mask_aff, gt_seg_tensor = self._apply_crop_if_configured(
93109
pred_aff,
94110
target_aff,
95111
mask_aff,
112+
gt_seg_tensor,
96113
)
114+
weight_kwargs = {}
115+
if gt_seg_tensor is not None:
116+
weight_kwargs["gt_seg"] = gt_seg_tensor.detach()
97117
weights = self._compute_malis_weights(
98118
pred_aff.detach(),
99119
target_aff.detach(),
100120
None if mask_aff is None else mask_aff.detach(),
121+
**weight_kwargs,
101122
)
102123

103124
edge_loss = (pred_aff - target_aff) ** 2
@@ -211,12 +232,36 @@ def _prepare_mask(self, mask: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
211232
f"mask={tuple(mask.shape)}, pred={tuple(pred.shape)}."
212233
) from e
213234

235+
def _prepare_gt_seg(
236+
self,
237+
gt_seg: torch.Tensor | np.ndarray | None,
238+
pred: torch.Tensor,
239+
) -> torch.Tensor | None:
240+
if gt_seg is None:
241+
return None
242+
243+
gt_seg_tensor = torch.as_tensor(gt_seg, device=pred.device).detach()
244+
if gt_seg_tensor.ndim == pred.ndim and gt_seg_tensor.shape[1] == 1:
245+
gt_seg_tensor = gt_seg_tensor.squeeze(1)
246+
elif gt_seg_tensor.ndim == pred.ndim - 2 and pred.shape[0] == 1:
247+
gt_seg_tensor = gt_seg_tensor.unsqueeze(0)
248+
249+
expected_shape = (pred.shape[0],) + tuple(pred.shape[-3:])
250+
if tuple(gt_seg_tensor.shape) != expected_shape:
251+
raise ValueError(
252+
"MalisLoss gt_seg must have shape [B, Z, Y, X] or [B, 1, Z, Y, X] "
253+
f"matching pred spatial dims; got gt_seg={tuple(gt_seg_tensor.shape)}, "
254+
f"expected={expected_shape}."
255+
)
256+
return gt_seg_tensor.contiguous()
257+
214258
def _apply_crop_if_configured(
215259
self,
216260
pred_aff: torch.Tensor,
217261
target_aff: torch.Tensor,
218262
mask_aff: torch.Tensor | None,
219-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
263+
gt_seg: torch.Tensor | None,
264+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
220265
"""Apply the configured random sub-volume crop, if any.
221266
222267
Offset sampling stays on CPU. The returned tensors are contiguous copies
@@ -226,11 +271,11 @@ def _apply_crop_if_configured(
226271
crop=64, and fp16, pred + target + mask copies are about 9 MiB before
227272
overhead.
228273
229-
Returns ``(pred_cropped, target_cropped, mask_cropped)``. If no crop is
230-
configured the inputs are returned unchanged.
274+
Returns ``(pred_cropped, target_cropped, mask_cropped, gt_seg_cropped)``.
275+
If no crop is configured the inputs are returned unchanged.
231276
"""
232277
if self.malis_crop_size is None:
233-
return pred_aff, target_aff, mask_aff
278+
return pred_aff, target_aff, mask_aff, gt_seg
234279

235280
k_z, k_y, k_x = self.malis_crop_size
236281
z_dim, y_dim, x_dim = pred_aff.shape[-3:]
@@ -253,29 +298,40 @@ def _apply_crop_if_configured(
253298
if mask_aff is None
254299
else mask_aff.narrow(-3, z0, k_z).narrow(-2, y0, k_y).narrow(-1, x0, k_x).contiguous()
255300
)
256-
return pred_c, target_c, mask_c
301+
gt_seg_c = (
302+
None
303+
if gt_seg is None
304+
else gt_seg.narrow(-3, z0, k_z).narrow(-2, y0, k_y).narrow(-1, x0, k_x).contiguous()
305+
)
306+
return pred_c, target_c, mask_c, gt_seg_c
257307

258308
def _compute_malis_weights(
259309
self,
260310
pred_aff: torch.Tensor,
261311
target_aff: torch.Tensor,
262312
mask: torch.Tensor | None = None,
313+
*,
314+
gt_seg: torch.Tensor | None = None,
263315
) -> torch.Tensor:
264316
pred_np = pred_aff.to(dtype=torch.float32).cpu().numpy()
265317
target_np = target_aff.to(dtype=torch.float32).cpu().numpy()
266318
mask_np = None if mask is None else mask.to(dtype=torch.float32).cpu().numpy()
319+
gt_seg_np = None if gt_seg is None else gt_seg.cpu().numpy()
267320
weights = np.empty_like(pred_np, dtype=np.float32)
268321
for batch_idx in range(pred_np.shape[0]):
269322
gt_affs = np.ascontiguousarray(target_np[batch_idx] > 0.5, dtype=np.int32)
270323
pred_sample = np.ascontiguousarray(pred_np[batch_idx], dtype=np.float32)
271324
mask_sample = None
272325
if mask_np is not None:
273326
mask_sample = np.ascontiguousarray(mask_np[batch_idx] == 1, dtype=bool)
274-
gt_seg, _ = _malis_lib.connected_components_affgraph(gt_affs, self.nhood)
327+
if gt_seg_np is None:
328+
gt_seg_sample, _ = _malis_lib.connected_components_affgraph(gt_affs, self.nhood)
329+
else:
330+
gt_seg_sample = np.ascontiguousarray(gt_seg_np[batch_idx], dtype=np.uint64)
275331
weights[batch_idx] = self._compute_sample_weights(
276332
pred_sample,
277333
gt_affs,
278-
gt_seg,
334+
gt_seg_sample,
279335
mask_sample,
280336
)
281337

connectomics/models/losses/metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class LossMetadata:
1919
call_kind: LossCallKind = "pred_target" # pred_target | pred_only | pred_pred | unsupported
2020
target_kind: TargetKind = "dense" # dense | class_index | none
2121
spatial_weight_arg: Optional[str] = None # weight | mask | None
22+
gt_seg_arg: Optional[str] = None # gt_seg | None
2223

2324

2425
_LOSS_METADATA_BY_NAME = {
@@ -44,7 +45,7 @@ class LossMetadata:
4445
"SoftClDiceLoss": LossMetadata("SoftClDiceLoss", spatial_weight_arg="weight"),
4546
"WeightedMSELoss": LossMetadata("WeightedMSELoss", spatial_weight_arg="weight"),
4647
"WeightedMAELoss": LossMetadata("WeightedMAELoss", spatial_weight_arg="weight"),
47-
"MalisLoss": LossMetadata("MalisLoss", spatial_weight_arg="mask"),
48+
"MalisLoss": LossMetadata("MalisLoss", spatial_weight_arg="mask", gt_seg_arg="gt_seg"),
4849
# GAN is not compatible with the generic supervised orchestrator path
4950
"GANLoss": LossMetadata("GANLoss", call_kind="unsupported", target_kind="none"),
5051
# Regularization losses

connectomics/training/lightning/model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,7 @@ def _compute_loss(
845845
stage: str,
846846
mask: Optional[torch.Tensor] = None,
847847
target_mask: Optional[torch.Tensor] = None,
848+
gt_seg: Optional[torch.Tensor] = None,
848849
):
849850
"""Compute loss handling both standard and deep supervision outputs."""
850851
loss_orchestrator = self._require_loss_orchestrator()
@@ -853,10 +854,10 @@ def _compute_loss(
853854
)
854855
if is_deep_supervision:
855856
return loss_orchestrator.compute_deep_supervision_loss(
856-
outputs, labels, stage=stage, mask=mask, target_mask=target_mask
857+
outputs, labels, stage=stage, mask=mask, target_mask=target_mask, gt_seg=gt_seg
857858
)
858859
return loss_orchestrator.compute_standard_loss(
859-
outputs, labels, stage=stage, mask=mask, target_mask=target_mask
860+
outputs, labels, stage=stage, mask=mask, target_mask=target_mask, gt_seg=gt_seg
860861
)
861862

862863
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
@@ -868,13 +869,14 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_
868869
# Binarize mask: (B, 1, D, H, W) float, 1 = valid, 0 = ignore
869870
mask = (raw_mask > 0).float() if raw_mask is not None else None
870871
target_mask = batch.get("label_mask", None)
872+
gt_seg = batch.get("gt_seg", None)
871873

872874
# Forward pass
873875
outputs = self(images)
874876

875877
# Compute loss using the loss orchestrator
876878
total_loss, loss_dict = self._compute_loss(
877-
outputs, labels, stage="train", mask=mask, target_mask=target_mask
879+
outputs, labels, stage="train", mask=mask, target_mask=target_mask, gt_seg=gt_seg
878880
)
879881

880882
# Keep full training curves in TensorBoard while avoiding console spam.
@@ -896,13 +898,14 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
896898
raw_mask = batch.get("mask", None)
897899
mask = (raw_mask > 0).float() if raw_mask is not None else None
898900
target_mask = batch.get("label_mask", None)
901+
gt_seg = batch.get("gt_seg", None)
899902

900903
# Forward pass
901904
outputs = self(images)
902905

903906
# Compute loss using the loss orchestrator
904907
total_loss, loss_dict = self._compute_loss(
905-
outputs, labels, stage="val", mask=mask, target_mask=target_mask
908+
outputs, labels, stage="val", mask=mask, target_mask=target_mask, gt_seg=gt_seg
906909
)
907910

908911
# Compute evaluation metrics if enabled

0 commit comments

Comments
 (0)