Skip to content

Commit a094c1f

Browse files
committed
Improved efficiency of segmentation loss computation
1 parent 46bb3e2 commit a094c1f

2 files changed

Lines changed: 42 additions & 27 deletions

File tree

src/thunder/tasks/train_eval_probe.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -581,27 +581,28 @@ def train_eval(
581581
outputs[i], (label.shape[1], label.shape[2]), mode="bilinear"
582582
)
583583

584-
# Applying masking (removing pixels where gt == -1)
585-
if task_type == "segmentation":
586-
unmasked_label = [l != -1 for l in label]
587-
label = [l[u] for l, u in zip(label, unmasked_label)]
588-
else:
584+
if task_type == "linear_probing":
589585
label = label.view(-1)
590586
loss = 0
591587
for i in range(len(outputs)):
592588
output = outputs[i]
593-
out = []
594589
if task_type == "segmentation":
595-
for o, m in zip(output, unmasked_label):
596-
out.append(
597-
torch.cat(
598-
[o[c][m].unsqueeze(-1) for c in range(o.shape[0])], dim=-1
590+
# Applying masking (removing pixels where gt == -1)
591+
curr_loss = criterion(output, label, label != -1)
592+
593+
if comp_metrics:
594+
unmasked_label = [l != -1 for l in label]
595+
label = [l[u] for l, u in zip(label, unmasked_label)]
596+
out = []
597+
for o, m in zip(output, unmasked_label):
598+
out.append(
599+
torch.cat(
600+
[o[c][m].unsqueeze(-1) for c in range(o.shape[0])],
601+
dim=-1,
602+
)
599603
)
600-
)
601-
curr_loss = sum([criterion(o, l) for o, l in zip(out, label)]) / len(
602-
out
603-
)
604604
else:
605+
out = []
605606
for c in range(output.shape[1]):
606607
out.append(output[:, c].unsqueeze(-1))
607608
out = torch.cat(out, dim=-1)

src/thunder/utils/dice_loss.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,41 @@
1+
import torch
12
import torch.nn.functional as F
23

34

45
# Adapted from https://medium.com/data-scientists-diary/implementation-of-dice-loss-vision-pytorch-7eef1e438f68
5-
def multiclass_dice_loss(pred, target, smooth=1):
6+
def multiclass_dice_loss(
7+
pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, smooth: float = 1
8+
) -> torch.Tensor:
69
"""
710
Computes Dice Loss for multi-class segmentation.
8-
:param pred: Tensor of predictions (batch_size, C, H, W).
9-
:param target: One-hot encoded ground truth (batch_size, C, H, W).
11+
:param pred: Tensor of predictions (B, C, H, W).
12+
:param target: Ground truth labels (B, H, W).
13+
:param mask: Mask to apply to pred and target.
1014
:param smooth: Smoothing factor.
1115
1216
:return: Scalar Dice Loss.
1317
"""
14-
pred = F.softmax(pred, dim=1) # Convert logits to probabilities
15-
num_classes = pred.shape[1] # Number of classes (C)
16-
dice = 0 # Initialize Dice loss accumulator
1718

18-
for c in range(num_classes): # Loop through each class
19-
pred_c = pred[:, c] # Predictions for class c
20-
target_c = (target == c).long() # Ground truth for class c
19+
pred = F.softmax(pred, dim=1) # Converting logits to probabilities
20+
num_classes = pred.shape[1] # Number of classes (C)
2121

22-
intersection = (pred_c * target_c).sum() # Element-wise multiplication
23-
union = pred_c.sum() + target_c.sum() # Sum of all pixels
22+
target[~mask] = (
23+
num_classes # Adding a dummy class to account for masked pixels (-1 label values)
24+
)
25+
target = F.one_hot(
26+
target, num_classes=num_classes + 1
27+
) # Creating a tensor of one-hot target vectors
28+
target = target[..., :-1] # Removing dummy class channel
29+
target = target.permute((0, 3, 1, 2))
30+
mask = mask.unsqueeze(1)
2431

25-
dice += (2.0 * intersection + smooth) / (union + smooth) # Per-class Dice score
32+
intersection = (pred * target * mask).sum(dim=(2, 3)) # Element-wise multiplication
33+
union = (pred * mask).sum(dim=(2, 3)) + (target * mask).sum(
34+
dim=(2, 3)
35+
) # Sum of all pixels
2636

27-
return 1 - dice.mean() / num_classes # Average Dice Loss across classes
37+
dice = (2.0 * intersection + smooth) / (
38+
union + smooth
39+
) # Per-class and per-image Dice score
40+
dice = dice.mean(dim=0) # Averaging Dice loss across images
41+
return 1 - dice.mean() # Averaging Dice Loss across classes

0 commit comments

Comments
 (0)