Skip to content

Commit b54aa84

Browse files
committed
Updated dice loss to compute IoU over batch instead of per-sample
1 parent 4413807 commit b54aa84

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

src/thunder/utils/dice_loss.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ def multiclass_dice_loss(
2929
target = target.permute((0, 3, 1, 2))
3030
mask = mask.unsqueeze(1)
3131

32-
intersection = (pred * target * mask).sum(dim=(2, 3)) # Element-wise multiplication
33-
union = (pred * mask).sum(dim=(2, 3)) + (target * mask).sum(
34-
dim=(2, 3)
32+
intersection = (pred * target * mask).sum(
33+
dim=(0, 2, 3)
34+
) # Element-wise multiplication
35+
union = (pred * mask).sum(dim=(0, 2, 3)) + (target * mask).sum(
36+
dim=(0, 2, 3)
3537
) # Sum of all pixels
3638

3739
dice = (2.0 * intersection + smooth) / (
3840
union + smooth
3941
) # Per-class and per-image Dice score
40-
dice = dice.mean(dim=0) # Averaging Dice loss across images
42+
4143
return 1 - dice.mean() # Averaging Dice Loss across classes

0 commit comments

Comments
 (0)