@@ -26,6 +26,16 @@ class MalisLoss(nn.Module):
2626 2D tensors are rejected explicitly because the vendored MALIS helpers operate
2727 on 3D affinity graphs by default. See ``lib/malis/INVESTIGATION.md`` for
2828 GPU MALIS candidates and algorithm-level speedup follow-ups.
29+
30+ Performance knobs (see ``docs/source/notes/malis.rst``):
31+
32+ - ``malis_crop_size`` — random sub-volume crop on each forward call.
33+ ``64`` on a ``128^3`` patch gives ~4.6x measured step speedup vs
34+ the full-volume baseline (slurm 2505814 vs 2487040).
35+ - ``label_transform.emit_gt_seg: true`` (YAML, paired with this
36+ loss) — passes the eroded GT segmentation in via ``gt_seg=...``,
37+ skipping the per-step ``connected_components_affgraph`` call and
38+ preserving global instance IDs when ``malis_crop_size`` is set.
2939 """
3040
3141 def __init__ (
@@ -68,6 +78,7 @@ def forward(
6878 pred : torch .Tensor ,
6979 target : torch .Tensor ,
7080 mask : torch .Tensor | None = None ,
81+ gt_seg : torch .Tensor | np .ndarray | None = None ,
7182 ) -> torch .Tensor :
7283 """Compute MALIS-weighted squared affinity error.
7384
@@ -83,21 +94,31 @@ def forward(
8394 Masked-out edges are excluded from MALIS pass constraints and
8495 zeroed before per-pass normalization, but the mask does not
8596 change GT connected-component reconstruction.
97+ gt_seg: Optional ground-truth segmentation with shape ``[B, Z, Y, X]``
98+ or ``[B, 1, Z, Y, X]``. When supplied, MALIS uses these instance
99+ labels directly instead of reconstructing components from
100+ ``target`` affinities.
86101 """
87102 self ._validate_inputs (pred , target )
88103
89104 pred_aff = torch .sigmoid (pred ) if self .sigmoid else pred
90105 target_aff = target .to (device = pred .device , dtype = pred_aff .dtype )
91106 mask_aff = None if mask is None else self ._prepare_mask (mask , pred_aff )
92- pred_aff , target_aff , mask_aff = self ._apply_crop_if_configured (
107+ gt_seg_tensor = self ._prepare_gt_seg (gt_seg , pred_aff )
108+ pred_aff , target_aff , mask_aff , gt_seg_tensor = self ._apply_crop_if_configured (
93109 pred_aff ,
94110 target_aff ,
95111 mask_aff ,
112+ gt_seg_tensor ,
96113 )
114+ weight_kwargs = {}
115+ if gt_seg_tensor is not None :
116+ weight_kwargs ["gt_seg" ] = gt_seg_tensor .detach ()
97117 weights = self ._compute_malis_weights (
98118 pred_aff .detach (),
99119 target_aff .detach (),
100120 None if mask_aff is None else mask_aff .detach (),
121+ ** weight_kwargs ,
101122 )
102123
103124 edge_loss = (pred_aff - target_aff ) ** 2
@@ -211,12 +232,36 @@ def _prepare_mask(self, mask: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
211232 f"mask={ tuple (mask .shape )} , pred={ tuple (pred .shape )} ."
212233 ) from e
213234
235+ def _prepare_gt_seg (
236+ self ,
237+ gt_seg : torch .Tensor | np .ndarray | None ,
238+ pred : torch .Tensor ,
239+ ) -> torch .Tensor | None :
240+ if gt_seg is None :
241+ return None
242+
243+ gt_seg_tensor = torch .as_tensor (gt_seg , device = pred .device ).detach ()
244+ if gt_seg_tensor .ndim == pred .ndim and gt_seg_tensor .shape [1 ] == 1 :
245+ gt_seg_tensor = gt_seg_tensor .squeeze (1 )
246+ elif gt_seg_tensor .ndim == pred .ndim - 2 and pred .shape [0 ] == 1 :
247+ gt_seg_tensor = gt_seg_tensor .unsqueeze (0 )
248+
249+ expected_shape = (pred .shape [0 ],) + tuple (pred .shape [- 3 :])
250+ if tuple (gt_seg_tensor .shape ) != expected_shape :
251+ raise ValueError (
252+ "MalisLoss gt_seg must have shape [B, Z, Y, X] or [B, 1, Z, Y, X] "
253+ f"matching pred spatial dims; got gt_seg={ tuple (gt_seg_tensor .shape )} , "
254+ f"expected={ expected_shape } ."
255+ )
256+ return gt_seg_tensor .contiguous ()
257+
214258 def _apply_crop_if_configured (
215259 self ,
216260 pred_aff : torch .Tensor ,
217261 target_aff : torch .Tensor ,
218262 mask_aff : torch .Tensor | None ,
219- ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor | None ]:
263+ gt_seg : torch .Tensor | None ,
264+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor | None , torch .Tensor | None ]:
220265 """Apply the configured random sub-volume crop, if any.
221266
222267 Offset sampling stays on CPU. The returned tensors are contiguous copies
@@ -226,11 +271,11 @@ def _apply_crop_if_configured(
226271 crop=64, and fp16, pred + target + mask copies are about 9 MiB before
227272 overhead.
228273
229- Returns ``(pred_cropped, target_cropped, mask_cropped)``. If no crop is
230- configured the inputs are returned unchanged.
274+ Returns ``(pred_cropped, target_cropped, mask_cropped, gt_seg_cropped )``.
275+ If no crop is configured the inputs are returned unchanged.
231276 """
232277 if self .malis_crop_size is None :
233- return pred_aff , target_aff , mask_aff
278+ return pred_aff , target_aff , mask_aff , gt_seg
234279
235280 k_z , k_y , k_x = self .malis_crop_size
236281 z_dim , y_dim , x_dim = pred_aff .shape [- 3 :]
@@ -253,29 +298,40 @@ def _apply_crop_if_configured(
253298 if mask_aff is None
254299 else mask_aff .narrow (- 3 , z0 , k_z ).narrow (- 2 , y0 , k_y ).narrow (- 1 , x0 , k_x ).contiguous ()
255300 )
256- return pred_c , target_c , mask_c
301+ gt_seg_c = (
302+ None
303+ if gt_seg is None
304+ else gt_seg .narrow (- 3 , z0 , k_z ).narrow (- 2 , y0 , k_y ).narrow (- 1 , x0 , k_x ).contiguous ()
305+ )
306+ return pred_c , target_c , mask_c , gt_seg_c
257307
258308 def _compute_malis_weights (
259309 self ,
260310 pred_aff : torch .Tensor ,
261311 target_aff : torch .Tensor ,
262312 mask : torch .Tensor | None = None ,
313+ * ,
314+ gt_seg : torch .Tensor | None = None ,
263315 ) -> torch .Tensor :
264316 pred_np = pred_aff .to (dtype = torch .float32 ).cpu ().numpy ()
265317 target_np = target_aff .to (dtype = torch .float32 ).cpu ().numpy ()
266318 mask_np = None if mask is None else mask .to (dtype = torch .float32 ).cpu ().numpy ()
319+ gt_seg_np = None if gt_seg is None else gt_seg .cpu ().numpy ()
267320 weights = np .empty_like (pred_np , dtype = np .float32 )
268321 for batch_idx in range (pred_np .shape [0 ]):
269322 gt_affs = np .ascontiguousarray (target_np [batch_idx ] > 0.5 , dtype = np .int32 )
270323 pred_sample = np .ascontiguousarray (pred_np [batch_idx ], dtype = np .float32 )
271324 mask_sample = None
272325 if mask_np is not None :
273326 mask_sample = np .ascontiguousarray (mask_np [batch_idx ] == 1 , dtype = bool )
274- gt_seg , _ = _malis_lib .connected_components_affgraph (gt_affs , self .nhood )
327+ if gt_seg_np is None :
328+ gt_seg_sample , _ = _malis_lib .connected_components_affgraph (gt_affs , self .nhood )
329+ else :
330+ gt_seg_sample = np .ascontiguousarray (gt_seg_np [batch_idx ], dtype = np .uint64 )
275331 weights [batch_idx ] = self ._compute_sample_weights (
276332 pred_sample ,
277333 gt_affs ,
278- gt_seg ,
334+ gt_seg_sample ,
279335 mask_sample ,
280336 )
281337
0 commit comments