|
| 1 | +import torch |
1 | 2 | import torch.nn.functional as F |
2 | 3 |
|
3 | 4 |
|
4 | 5 | # Adapted from https://medium.com/data-scientists-diary/implementation-of-dice-loss-vision-pytorch-7eef1e438f68 |
5 | | -def multiclass_dice_loss(pred, target, smooth=1): |
| 6 | +def multiclass_dice_loss( |
| 7 | + pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, smooth: float = 1 |
| 8 | +) -> torch.Tensor: |
6 | 9 | """ |
7 | 10 | Computes Dice Loss for multi-class segmentation. |
8 | | - :param pred: Tensor of predictions (batch_size, C, H, W). |
9 | | - :param target: One-hot encoded ground truth (batch_size, C, H, W). |
| 11 | + :param pred: Tensor of predictions (B, C, H, W). |
| 12 | + :param target: Ground truth labels (B, H, W). |
| 13 | + :param mask: Mask to apply to pred and target. |
10 | 14 | :param smooth: Smoothing factor. |
11 | 15 |
|
12 | 16 | :return: Scalar Dice Loss. |
13 | 17 | """ |
14 | | - pred = F.softmax(pred, dim=1) # Convert logits to probabilities |
15 | | - num_classes = pred.shape[1] # Number of classes (C) |
16 | | - dice = 0 # Initialize Dice loss accumulator |
17 | 18 |
|
18 | | - for c in range(num_classes): # Loop through each class |
19 | | - pred_c = pred[:, c] # Predictions for class c |
20 | | - target_c = (target == c).long() # Ground truth for class c |
| 19 | + pred = F.softmax(pred, dim=1) # Converting logits to probabilities |
| 20 | + num_classes = pred.shape[1] # Number of classes (C) |
21 | 21 |
|
22 | | - intersection = (pred_c * target_c).sum() # Element-wise multiplication |
23 | | - union = pred_c.sum() + target_c.sum() # Sum of all pixels |
| 22 | + target[~mask] = ( |
| 23 | + num_classes # Adding a dummy class to account for masked pixels (-1 label values) |
| 24 | + ) |
| 25 | + target = F.one_hot( |
| 26 | + target, num_classes=num_classes + 1 |
| 27 | + ) # Creating a tensor of one-hot target vectors |
| 28 | + target = target[..., :-1] # Removing dummy class channel |
| 29 | + target = target.permute((0, 3, 1, 2)) |
| 30 | + mask = mask.unsqueeze(1) |
24 | 31 |
|
25 | | - dice += (2.0 * intersection + smooth) / (union + smooth) # Per-class Dice score |
| 32 | + intersection = (pred * target * mask).sum(dim=(2, 3)) # Element-wise multiplication |
| 33 | + union = (pred * mask).sum(dim=(2, 3)) + (target * mask).sum( |
| 34 | + dim=(2, 3) |
| 35 | + ) # Sum of all pixels |
26 | 36 |
|
27 | | - return 1 - dice.mean() / num_classes # Average Dice Loss across classes |
| 37 | + dice = (2.0 * intersection + smooth) / ( |
| 38 | + union + smooth |
| 39 | + ) # Per-class and per-image Dice score |
| 40 | + dice = dice.mean(dim=0) # Averaging Dice loss across images |
| 41 | + return 1 - dice.mean() # Averaging Dice Loss across classes |
0 commit comments