Skip to content

Commit 9c1fa24

Browse files
committed
feat: implement ignore_index support in metrics and losses with dedicated unit tests
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent d18565f commit 9c1fa24

14 files changed

Lines changed: 686 additions & 192 deletions

monai/losses/dice.py

Lines changed: 71 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
batch: bool = False,
6969
weight: Sequence[float] | float | int | torch.Tensor | None = None,
7070
soft_label: bool = False,
71+
ignore_index: int | None = None,
7172
) -> None:
7273
"""
7374
Args:
@@ -101,7 +102,8 @@ def __init__(
101102
The value/values should be no less than 0. Defaults to None.
102103
soft_label: whether the target contains non-binary values (soft labels) or not.
103104
If True a soft label formulation of the loss will be used.
104-
105+
ignore_index: if not None, specifies a target index that is ignored and does not contribute to
106+
the input gradient. Defaults to None.
105107
Raises:
106108
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
107109
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
@@ -123,6 +125,7 @@ def __init__(
123125
self.smooth_nr = float(smooth_nr)
124126
self.smooth_dr = float(smooth_dr)
125127
self.batch = batch
128+
self.ignore_index = ignore_index
126129
weight = torch.as_tensor(weight) if weight is not None else None
127130
self.register_buffer("class_weight", weight)
128131
self.class_weight: None | torch.Tensor
@@ -140,7 +143,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
140143
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
141144
142145
Example:
143-
>>> from monai.losses.dice import * # NOQA
146+
>>> from monai.losses.dice import * # NOQA
144147
>>> import torch
145148
>>> from monai.losses.dice import DiceLoss
146149
>>> B, C, H, W = 7, 5, 3, 2
@@ -164,6 +167,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
164167
if self.other_act is not None:
165168
input = self.other_act(input)
166169

170+
# mask the ignore_index if specified, must be done before one_hot
171+
mask: torch.Tensor | None = None
172+
if self.ignore_index is not None:
173+
mask = (target != self.ignore_index).float()
174+
167175
if self.to_onehot_y:
168176
if n_pred_ch == 1:
169177
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
@@ -181,6 +189,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
181189
if target.shape != input.shape:
182190
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
183191

192+
if mask is not None:
193+
input = input * mask
194+
target = target * mask
195+
184196
# reducing only spatial dimensions (not batch nor channels)
185197
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
186198
if self.batch:
@@ -204,11 +216,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
204216
self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
205217
else:
206218
if self.class_weight.shape[0] != num_of_classes:
207-
raise ValueError(
208-
"""the length of the `weight` sequence should be the same as the number of classes.
219+
raise ValueError("""the length of the `weight` sequence should be the same as the number of classes.
209220
If `include_background=False`, the weight should not include
210-
the background category class 0."""
211-
)
221+
the background category class 0.""")
212222
if self.class_weight.min() < 0:
213223
raise ValueError("the value/values of the `weight` should be no less than 0.")
214224
# apply class_weight to loss
@@ -280,6 +290,7 @@ def __init__(
280290
smooth_dr: float = 1e-5,
281291
batch: bool = False,
282292
soft_label: bool = False,
293+
ignore_index: int | None = None,
283294
) -> None:
284295
"""
285296
Args:
@@ -305,6 +316,8 @@ def __init__(
305316
If True, the class-weighted intersection and union areas are first summed across the batches.
306317
soft_label: whether the target contains non-binary values (soft labels) or not.
307318
If True a soft label formulation of the loss will be used.
319+
ignore_index: if not None, specifies a target index that is ignored and does not contribute to
320+
the input gradient.
308321
309322
Raises:
310323
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -330,6 +343,7 @@ def __init__(
330343
self.smooth_dr = float(smooth_dr)
331344
self.batch = batch
332345
self.soft_label = soft_label
346+
self.ignore_index = ignore_index
333347

334348
def w_func(self, grnd):
335349
if self.w_type == str(Weight.SIMPLE):
@@ -360,6 +374,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
360374
if self.other_act is not None:
361375
input = self.other_act(input)
362376

377+
# Prepare mask before potential one-hot conversion
378+
mask: torch.Tensor | None = None
379+
if self.ignore_index is not None:
380+
mask = (target != self.ignore_index).float()
381+
363382
if self.to_onehot_y:
364383
if n_pred_ch == 1:
365384
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
@@ -370,14 +389,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
370389
if n_pred_ch == 1:
371390
warnings.warn("single channel prediction, `include_background=False` ignored.")
372391
else:
373-
# if skipping background, removing first channel
374392
target = target[:, 1:]
375393
input = input[:, 1:]
376394

377395
if target.shape != input.shape:
378396
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
379397

380-
# reducing only spatial dimensions (not batch nor channels)
398+
# Exclude ignored regions from calculations
399+
if mask is not None:
400+
input = input * mask
401+
target = target * mask
402+
381403
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
382404
if self.batch:
383405
reduce_axis = [0] + reduce_axis
@@ -404,12 +426,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
404426
f: torch.Tensor = 1.0 - (numer / denom)
405427

406428
if self.reduction == LossReduction.MEAN.value:
407-
f = torch.mean(f) # the batch and channel average
429+
f = torch.mean(f)
408430
elif self.reduction == LossReduction.SUM.value:
409-
f = torch.sum(f) # sum over the batch and channel dims
431+
f = torch.sum(f)
410432
elif self.reduction == LossReduction.NONE.value:
411-
# If we are not computing voxelwise loss components at least
412-
# make sure a none reduction maintains a broadcastable shape
413433
broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
414434
f = f.view(broadcast_shape)
415435
else:
@@ -442,11 +462,12 @@ def __init__(
442462
reduction: LossReduction | str = LossReduction.MEAN,
443463
smooth_nr: float = 1e-5,
444464
smooth_dr: float = 1e-5,
465+
ignore_index: int | None = None,
445466
) -> None:
446467
"""
447468
Args:
448469
dist_matrix: 2d tensor or 2d numpy array; matrix of distances between the classes.
449-
It must have dimension C x C where C is the number of classes.
470+
It must have dimension C x C where C is the number of classes.
450471
weighting_mode: {``"default"``, ``"GDL"``}
451472
Specifies how to weight the class-specific sum of errors.
452473
Default to ``"default"``.
@@ -466,35 +487,19 @@ def __init__(
466487
- ``"sum"``: the output will be summed.
467488
smooth_nr: a small constant added to the numerator to avoid zero.
468489
smooth_dr: a small constant added to the denominator to avoid nan.
490+
ignore_index: if not None, specifies a target index that is ignored and does not contribute to
491+
the input gradient.
469492
470493
Raises:
471494
ValueError: When ``dist_matrix`` is not a square matrix.
472-
473-
Example:
474-
.. code-block:: python
475-
476-
import torch
477-
import numpy as np
478-
from monai.losses import GeneralizedWassersteinDiceLoss
479-
480-
# Example with 3 classes (including the background: label 0).
481-
# The distance between the background class (label 0) and the other classes is the maximum, equal to 1.
482-
# The distance between class 1 and class 2 is 0.5.
483-
dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
484-
wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
485-
486-
pred_score = torch.tensor([[1000, 0, 0], [0, 1000, 0], [0, 0, 1000]], dtype=torch.float32)
487-
grnd = torch.tensor([0, 1, 2], dtype=torch.int64)
488-
wass_loss(pred_score, grnd) # 0
489-
490495
"""
491496
super().__init__(reduction=LossReduction(reduction).value)
492497

493498
if dist_matrix.shape[0] != dist_matrix.shape[1]:
494499
raise ValueError(f"dist_matrix must be C x C, got {dist_matrix.shape[0]} x {dist_matrix.shape[1]}.")
495500

496501
if weighting_mode not in ["default", "GDL"]:
497-
raise ValueError("weighting_mode must be either 'default' or 'GDL, got %s." % weighting_mode)
502+
raise ValueError(f"weighting_mode must be either 'default' or 'GDL', got {weighting_mode}.")
498503

499504
self.m = dist_matrix
500505
if isinstance(self.m, np.ndarray):
@@ -505,13 +510,13 @@ def __init__(
505510
self.num_classes = self.m.size(0)
506511
self.smooth_nr = float(smooth_nr)
507512
self.smooth_dr = float(smooth_dr)
513+
self.ignore_index = ignore_index
508514

509515
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
510516
"""
511517
Args:
512518
input: the shape should be BNH[WD].
513519
target: the shape should be BNH[WD].
514-
515520
"""
516521
# Aggregate spatial dimensions
517522
flat_input = input.reshape(input.size(0), input.size(1), -1)
@@ -523,18 +528,20 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
523528
# Compute the Wasserstein distance map
524529
wass_dist_map = self.wasserstein_distance_map(probs, flat_target)
525530

531+
# Apply masking for ignore_index
532+
if self.ignore_index is not None:
533+
mask = (flat_target != self.ignore_index).float()
534+
wass_dist_map = wass_dist_map * mask
535+
526536
# Compute the values of alpha to use
527537
alpha = self._compute_alpha_generalized_true_positives(flat_target)
528538

529539
# Compute the numerator and denominator of the generalized Wasserstein Dice loss
530540
if self.alpha_mode == "GDL":
531541
# use GDL-style alpha weights (i.e. normalize by the volume of each class)
532-
# contrary to the original definition we also use alpha in the "generalized all error".
533542
true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)
534543
denom = self._compute_denominator(alpha, flat_target, wass_dist_map)
535544
else: # default: as in the original paper
536-
# (i.e. alpha=1 for all foreground classes and 0 for the background).
537-
# Compute the generalised number of true positives
538545
true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)
539546
all_error = torch.sum(wass_dist_map, dim=1)
540547
denom = 2 * true_pos + all_error
@@ -544,12 +551,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
544551
wass_dice_loss: torch.Tensor = 1.0 - wass_dice
545552

546553
if self.reduction == LossReduction.MEAN.value:
547-
wass_dice_loss = torch.mean(wass_dice_loss) # the batch and channel average
554+
wass_dice_loss = torch.mean(wass_dice_loss)
548555
elif self.reduction == LossReduction.SUM.value:
549-
wass_dice_loss = torch.sum(wass_dice_loss) # sum over the batch and channel dims
556+
wass_dice_loss = torch.sum(wass_dice_loss)
550557
elif self.reduction == LossReduction.NONE.value:
551-
# If we are not computing voxelwise loss components at least
552-
# make sure a none reduction maintains a broadcastable shape
553558
broadcast_shape = input.shape[0:2] + (1,) * (len(input.shape) - 2)
554559
wass_dice_loss = wass_dice_loss.view(broadcast_shape)
555560
else:
@@ -674,6 +679,7 @@ def __init__(
674679
lambda_dice: float = 1.0,
675680
lambda_ce: float = 1.0,
676681
label_smoothing: float = 0.0,
682+
ignore_index: int | None = None,
677683
) -> None:
678684
"""
679685
Args:
@@ -715,6 +721,8 @@ def __init__(
715721
label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
716722
by the given factor to reduce overfitting.
717723
Defaults to 0.0.
724+
ignore_index: if not None, specifies a target index that is ignored and does not contribute to
725+
the input gradient.
718726
719727
"""
720728
super().__init__()
@@ -737,15 +745,22 @@ def __init__(
737745
smooth_dr=smooth_dr,
738746
batch=batch,
739747
weight=dice_weight,
748+
ignore_index=ignore_index,
749+
)
750+
self.cross_entropy = nn.CrossEntropyLoss(
751+
weight=weight,
752+
reduction=reduction,
753+
label_smoothing=label_smoothing,
754+
ignore_index=ignore_index if ignore_index is not None else -100,
740755
)
741-
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing)
742756
self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
743757
if lambda_dice < 0.0:
744758
raise ValueError("lambda_dice should be no less than 0.0.")
745759
if lambda_ce < 0.0:
746760
raise ValueError("lambda_ce should be no less than 0.0.")
747761
self.lambda_dice = lambda_dice
748762
self.lambda_ce = lambda_ce
763+
self.ignore_index = ignore_index
749764

750765
def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
751766
"""
@@ -801,7 +816,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
801816
)
802817

803818
dice_loss = self.dice(input, target)
804-
ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)
819+
820+
if input.shape[1] != 1:
821+
# CrossEntropyLoss handles ignore_index natively
822+
ce_loss = self.ce(input, target)
823+
else:
824+
# BCEWithLogitsLoss does not support ignore_index, handle manually
825+
ce_loss = self.bce(input, target)
826+
if self.ignore_index is not None:
827+
mask = (target != self.ignore_index).float()
828+
ce_loss = ce_loss * mask
829+
if self.dice.reduction == "mean":
830+
ce_loss = torch.mean(ce_loss)
831+
elif self.dice.reduction == "sum":
832+
ce_loss = torch.sum(ce_loss)
833+
805834
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss
806835

807836
return total_loss

0 commit comments

Comments (0)