Skip to content

Commit 4b1777f

Browse files
Fix batch size broadcasting bug in GeneralizedWassersteinDiceLoss (#8744)
**Fixes #4650** ### Description When `batch_size > 1`, `GeneralizedWassersteinDiceLoss` produces incorrect loss values because of a tensor broadcasting issue in `_compute_generalized_true_positive` and `_compute_denominator`. After `torch.gather`, `alpha_extended` has shape `(B, 1, S)` while `wasserstein_distance_map` has shape `(B, S)`. The element-wise multiply silently broadcasts to `(B, B, S)`, which mixes values across batch samples. This means the loss has always been wrong for any training run with `batch_size > 1`. The fix follows the [reference implementation](https://github.com/LucasFidon/GeneralizedWassersteinDiceLoss) by the original paper's author — squeeze `dim=1` after the gather so both tensors are `(B, S)`, and reduce with `dim=1` instead of `dim=[1, 2]`. I also noticed that `reduction="none"` was broken (never had test coverage) — it tried to reshape the per-sample loss `(B,)` into `(B, C, 1, ...)`, but GWDL aggregates over classes internally so the class dimension doesn't exist in the output. Fixed that as well. ### Changes - `monai/losses/dice.py`: squeeze + dim fix in `_compute_generalized_true_positive` and `_compute_denominator`; fixed `reduction="none"` path - `tests/losses/test_generalized_wasserstein_dice_loss.py`: two new regression tests for batch consistency ### Tests All existing tests pass. The new regression tests fail on unpatched code and pass with the fix. --------- Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com> Signed-off-by: Jeffrey Qiu <hongjie.qiu@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 0b0a840 commit 4b1777f

File tree

2 files changed

+91
-6
lines changed

2 files changed

+91
-6
lines changed

monai/losses/dice.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -548,10 +548,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
548548
elif self.reduction == LossReduction.SUM.value:
549549
wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims
550550
elif self.reduction == LossReduction.NONE.value:
551-
# If we are not computing voxelwise loss components at least
552-
# make sure a none reduction maintains a broadcastable shape
553-
broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2)
554-
wass_dice_loss = wass_dice_loss.view(broadcast_shape)
551+
# GWDL aggregates over classes internally, so wass_dice_loss has shape (B,)
552+
pass
555553
else:
556554
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
557555

@@ -609,8 +607,9 @@ def _compute_generalized_true_positive(
609607
alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
610608
flat_target_extended = torch.unsqueeze(flat_target, dim=1)
611609
alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
610+
alpha_extended = torch.squeeze(alpha_extended, dim=1)
612611

613-
return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=[1, 2])
612+
return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)
614613

615614
def _compute_denominator(
616615
self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
@@ -626,8 +625,9 @@ def _compute_denominator(
626625
alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
627626
flat_target_extended = torch.unsqueeze(flat_target, dim=1)
628627
alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
628+
alpha_extended = torch.squeeze(alpha_extended, dim=1)
629629

630-
return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=[1, 2])
630+
return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1)
631631

632632
def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor:
633633
"""

tests/losses/test_generalized_wasserstein_dice_loss.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,91 @@ def forward(self, x):
218218
# check that the predicted segmentation has improved
219219
self.assertGreater(diff_start, diff_end)
220220

221+
def test_batch_size_greater_than_one(self):
222+
"""
223+
Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
224+
With M=identity and batch_size > 1, the GWDL should produce the same
225+
per-sample loss values as with batch_size=1.
226+
"""
227+
target_single = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
228+
target_single = target_single.unsqueeze(0) # shape (1, H, W)
229+
pred_single = 1000 * F.one_hot(target_single, num_classes=2).permute(0, 3, 1, 2).float()
230+
231+
# Create a batch of size 2 by repeating the same sample
232+
target_batch = target_single.repeat(2, 1, 1) # shape (2, H, W)
233+
pred_batch = pred_single.repeat(2, 1, 1, 1) # shape (2, C, H, W)
234+
235+
for w_mode in ["default", "GDL"]:
236+
loss_fn = GeneralizedWassersteinDiceLoss(
237+
dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none"
238+
)
239+
240+
loss_single = loss_fn(pred_single, target_single)
241+
loss_batch = loss_fn(pred_batch, target_batch)
242+
243+
# Each sample in the batch should produce the same loss as the single sample
244+
for i in range(2):
245+
self.assertAlmostEqual(
246+
float(loss_batch[i]),
247+
float(loss_single[0]),
248+
places=5,
249+
msg=f"Batch loss[{i}] != single loss for weighting_mode={w_mode}",
250+
)
251+
252+
# Also test with mean reduction using a non-trivial (poor) prediction
253+
# so the expected loss is not near zero
254+
pred_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float()
255+
pred_poor_batch = pred_poor.repeat(2, 1, 1, 1)
256+
257+
for w_mode in ["default", "GDL"]:
258+
loss_fn = GeneralizedWassersteinDiceLoss(
259+
dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="mean"
260+
)
261+
262+
loss_single = float(loss_fn(pred_poor, target_single))
263+
loss_batch = float(loss_fn(pred_poor_batch, target_batch))
264+
265+
# Verify the loss is non-trivial (close to 1 for poor predictions)
266+
self.assertGreater(loss_single, 0.5, msg=f"Expected non-trivial loss for weighting_mode={w_mode}")
267+
self.assertAlmostEqual(
268+
loss_batch,
269+
loss_single,
270+
places=5,
271+
msg=f"Batch mean loss != single mean loss for weighting_mode={w_mode}",
272+
)
273+
274+
def test_batch_size_different_samples(self):
275+
"""
276+
Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
277+
Verify loss is computed correctly when batch contains different samples.
278+
"""
279+
target_a = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0)
280+
target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0)
281+
282+
pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float()
283+
# Use a poor prediction for sample b so its loss is non-trivial (~1.0)
284+
pred_b = 1000 * F.one_hot(1 - target_b, num_classes=2).permute(0, 3, 1, 2).float()
285+
286+
# Combine into a batch
287+
target_batch = torch.cat([target_a, target_b], dim=0)
288+
pred_batch = torch.cat([pred_a, pred_b], dim=0)
289+
290+
for w_mode in ["default", "GDL"]:
291+
loss_fn = GeneralizedWassersteinDiceLoss(
292+
dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none"
293+
)
294+
295+
loss_a = float(loss_fn(pred_a, target_a))
296+
loss_b = float(loss_fn(pred_b, target_b))
297+
loss_batch = loss_fn(pred_batch, target_batch)
298+
299+
self.assertAlmostEqual(
300+
float(loss_batch[0]), loss_a, places=5, msg=f"Batch loss[0] != loss_a for weighting_mode={w_mode}"
301+
)
302+
self.assertAlmostEqual(
303+
float(loss_batch[1]), loss_b, places=5, msg=f"Batch loss[1] != loss_b for weighting_mode={w_mode}"
304+
)
305+
221306
def test_script(self):
222307
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
223308

0 commit comments

Comments (0)