We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4413807 commit b54aa84Copy full SHA for b54aa84
1 file changed
src/thunder/utils/dice_loss.py
@@ -29,13 +29,15 @@ def multiclass_dice_loss(
29
target = target.permute((0, 3, 1, 2))
30
mask = mask.unsqueeze(1)
31
32
- intersection = (pred * target * mask).sum(dim=(2, 3)) # Element-wise multiplication
33
- union = (pred * mask).sum(dim=(2, 3)) + (target * mask).sum(
34
- dim=(2, 3)
+ intersection = (pred * target * mask).sum(
+ dim=(0, 2, 3)
+ ) # Element-wise multiplication
35
+ union = (pred * mask).sum(dim=(0, 2, 3)) + (target * mask).sum(
36
37
) # Sum of all pixels
38
39
dice = (2.0 * intersection + smooth) / (
40
union + smooth
41
) # Per-class and per-image Dice score
- dice = dice.mean(dim=0) # Averaging Dice loss across images
42
+
43
return 1 - dice.mean() # Averaging Dice Loss across classes
0 commit comments