Skip to content

Commit 5f6d6e7

Browse files
committed
Fixed handling of ignored pixels in dice loss
1 parent f4dc382 commit 5f6d6e7

2 files changed

Lines changed: 4 additions & 3 deletions

File tree

src/thunder/tasks/train_eval_probe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def train_eval(
661661
)
662662

663663
# Finding background-only masks
664-
bg_only = np.array([l.sum().item() == 0 for l in all_label])
664+
bg_only = np.array([l.sum().item() == 0 for l in all_label if len(l) > 0])
665665
freq_bg_only = bg_only.sum().item() / len(bg_only)
666666
no_bg_only_weight = max(
667667
1.0, freq_bg_only * cfg.task.no_bg_only_weight_test

src/thunder/utils/dice_loss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
# Adapted from https://medium.com/data-scientists-diary/implementation-of-dice-loss-vision-pytorch-7eef1e438f68
66
def multiclass_dice_loss(
7-
pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, smooth: float = 1
7+
pred: torch.Tensor, label: torch.Tensor, mask: torch.Tensor, smooth: float = 1
88
) -> torch.Tensor:
99
"""
1010
Computes Dice Loss for multi-class segmentation.
1111
:param pred: Tensor of predictions (B, C, H, W).
12-
:param target: Ground truth labels (B, H, W).
12+
:param label: Ground truth labels (B, H, W).
1313
:param mask: Mask to apply to pred and target.
1414
:param smooth: Smoothing factor.
1515
@@ -19,6 +19,7 @@ def multiclass_dice_loss(
1919
pred = F.softmax(pred, dim=1) # Converting logits to probabilities
2020
num_classes = pred.shape[1] # Number of classes (C)
2121

22+
target = label.clone()
2223
target[~mask] = (
2324
num_classes # Adding a dummy class to account for masked pixels (-1 label values)
2425
)

0 commit comments

Comments
 (0)