Skip to content

Commit 8b2e316

Browse files
committed
perf: Vectorize DiceHelper.__call__()
Replace nested batch/channel loops with vectorized torch operations.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent d53eb00 commit 8b2e316

File tree

1 file changed

+54
-9
lines changed

1 file changed

+54
-9
lines changed

monai/metrics/meandice.py

Lines changed: 54 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -322,16 +322,61 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
322322
y_pred = torch.sigmoid(y_pred)
323323
y_pred = y_pred > 0.5
324324

325+
# Vectorized computation (replaces nested loops for better performance)
326+
batch_size = y_pred.shape[0]
327+
device = y_pred.device
328+
329+
# Convert to boolean for computation
330+
if y_pred.shape[1] == 1 and n_pred_ch > 1:
331+
# Single-channel class indices: convert to one-hot
332+
y_pred_bool = torch.zeros(batch_size, n_pred_ch, *y_pred.shape[2:], dtype=torch.bool, device=device)
333+
y_bool = torch.zeros(batch_size, n_pred_ch, *y.shape[2:], dtype=torch.bool, device=device)
334+
335+
for c in range(n_pred_ch):
336+
y_pred_bool[:, c] = (y_pred[:, 0] == c)
337+
y_bool[:, c] = (y[:, 0] == c)
338+
else:
339+
# One-hot format: cast to bool
340+
y_pred_bool = y_pred.bool()
341+
if y.shape[1] == 1 and y_pred.shape[1] > 1:
342+
# Expand y to match y_pred channels
343+
y_bool = (y == 1).expand(batch_size, n_pred_ch, *y.shape[2:])
344+
else:
345+
y_bool = y.bool()
346+
347+
# Flatten spatial dimensions for vectorized computation: (batch, channels, -1)
348+
y_pred_flat = y_pred_bool.reshape(batch_size, n_pred_ch, -1).float()
349+
y_flat = y_bool.reshape(batch_size, n_pred_ch, -1).float()
350+
351+
# Compute Dice per (batch, channel) vectorized: all reductions at once
352+
intersection = torch.sum(y_pred_flat * y_flat, dim=-1) # (batch, n_pred_ch)
353+
pred_sum = torch.sum(y_pred_flat, dim=-1) # (batch, n_pred_ch)
354+
y_sum = torch.sum(y_flat, dim=-1) # (batch, n_pred_ch)
355+
356+
# Dice formula: 2 * intersection / (pred_sum + y_sum)
357+
union = pred_sum + y_sum
358+
dice = (2.0 * intersection) / union # (batch, n_pred_ch)
359+
360+
# Handle empty ground truth cases
361+
if self.ignore_empty:
362+
# Set NaN where ground truth is empty
363+
dice = torch.where(y_sum > 0, dice, torch.tensor(float("nan"), device=device, dtype=dice.dtype))
364+
else:
365+
# Set 1.0 if both empty, 0.0 if only pred is non-empty
366+
empty_mask = y_sum == 0
367+
dice = torch.where(
368+
empty_mask,
369+
torch.where(pred_sum == 0, torch.tensor(1.0, device=device, dtype=dice.dtype),
370+
torch.tensor(0.0, device=device, dtype=dice.dtype)),
371+
dice
372+
)
373+
374+
# Select channels: exclude background if requested
325375
first_ch = 0 if self.include_background else 1
326-
data = []
327-
for b in range(y_pred.shape[0]):
328-
c_list = []
329-
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
330-
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
331-
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
332-
c_list.append(self.compute_channel(x_pred, x))
333-
data.append(torch.stack(c_list))
334-
data = torch.stack(data, dim=0).contiguous() # type: ignore
376+
if n_pred_ch > 1:
377+
data = dice[:, first_ch:] # (batch, num_classes_selected)
378+
else:
379+
data = dice # (batch, 1)
335380

336381
f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
337382
return (f, not_nans) if self.get_not_nans else f

0 commit comments

Comments
 (0)