@@ -68,6 +68,7 @@ def __init__(
6868 batch : bool = False ,
6969 weight : Sequence [float ] | float | int | torch .Tensor | None = None ,
7070 soft_label : bool = False ,
71+ ignore_index : int | None = None ,
7172 ) -> None :
7273 """
7374 Args:
@@ -101,7 +102,8 @@ def __init__(
101102 The value/values should be no less than 0. Defaults to None.
102103 soft_label: whether the target contains non-binary values (soft labels) or not.
103104 If True a soft label formulation of the loss will be used.
104-
105+ ignore_index: if not None, specifies a target index that is ignored and does not contribute to
106+ the input gradient. Defaults to None.
105107 Raises:
106108 TypeError: When ``other_act`` is not an ``Optional[Callable]``.
107109 ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
@@ -123,6 +125,7 @@ def __init__(
123125 self .smooth_nr = float (smooth_nr )
124126 self .smooth_dr = float (smooth_dr )
125127 self .batch = batch
128+ self .ignore_index = ignore_index
126129 weight = torch .as_tensor (weight ) if weight is not None else None
127130 self .register_buffer ("class_weight" , weight )
128131 self .class_weight : None | torch .Tensor
@@ -140,7 +143,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
140143 ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
141144
142145 Example:
143- >>> from monai.losses.dice import * # NOQA
146+ >>> from monai.losses.dice import * # NOQA
144147 >>> import torch
145148 >>> from monai.losses.dice import DiceLoss
146149 >>> B, C, H, W = 7, 5, 3, 2
@@ -164,6 +167,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
164167 if self .other_act is not None :
165168 input = self .other_act (input )
166169
170+ # mask the ignore_index if specified, must be done before one_hot
171+ mask : torch .Tensor | None = None
172+ if self .ignore_index is not None :
173+ mask = (target != self .ignore_index ).float ()
174+
167175 if self .to_onehot_y :
168176 if n_pred_ch == 1 :
169177 warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
@@ -181,6 +189,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
181189 if target .shape != input .shape :
182190 raise AssertionError (f"ground truth has different shape ({ target .shape } ) from input ({ input .shape } )" )
183191
192+ if mask is not None :
193+ input = input * mask
194+ target = target * mask
195+
184196 # reducing only spatial dimensions (not batch nor channels)
185197 reduce_axis : list [int ] = torch .arange (2 , len (input .shape )).tolist ()
186198 if self .batch :
@@ -204,11 +216,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
204216 self .class_weight = torch .as_tensor ([self .class_weight ] * num_of_classes )
205217 else :
206218 if self .class_weight .shape [0 ] != num_of_classes :
207- raise ValueError (
208- """the length of the `weight` sequence should be the same as the number of classes.
219+ raise ValueError ("""the length of the `weight` sequence should be the same as the number of classes.
209220 If `include_background=False`, the weight should not include
210- the background category class 0."""
211- )
221+ the background category class 0.""" )
212222 if self .class_weight .min () < 0 :
213223 raise ValueError ("the value/values of the `weight` should be no less than 0." )
214224 # apply class_weight to loss
@@ -280,6 +290,7 @@ def __init__(
280290 smooth_dr : float = 1e-5 ,
281291 batch : bool = False ,
282292 soft_label : bool = False ,
293+ ignore_index : int | None = None ,
283294 ) -> None :
284295 """
285296 Args:
@@ -305,6 +316,8 @@ def __init__(
305316 If True, the class-weighted intersection and union areas are first summed across the batches.
306317 soft_label: whether the target contains non-binary values (soft labels) or not.
307318 If True a soft label formulation of the loss will be used.
319+ ignore_index: if not None, specifies a target index that is ignored and does not contribute to
320+ the input gradient. Defaults to None.
308321
309322 Raises:
310323 TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -330,6 +343,7 @@ def __init__(
330343 self .smooth_dr = float (smooth_dr )
331344 self .batch = batch
332345 self .soft_label = soft_label
346+ self .ignore_index = ignore_index
333347
334348 def w_func (self , grnd ):
335349 if self .w_type == str (Weight .SIMPLE ):
@@ -360,6 +374,11 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
360374 if self .other_act is not None :
361375 input = self .other_act (input )
362376
377+ # Prepare mask before potential one-hot conversion
378+ mask : torch .Tensor | None = None
379+ if self .ignore_index is not None :
380+ mask = (target != self .ignore_index ).float ()
381+
363382 if self .to_onehot_y :
364383 if n_pred_ch == 1 :
365384 warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
@@ -370,14 +389,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
370389 if n_pred_ch == 1 :
371390 warnings .warn ("single channel prediction, `include_background=False` ignored." )
372391 else :
373- # if skipping background, removing first channel
374392 target = target [:, 1 :]
375393 input = input [:, 1 :]
376394
377395 if target .shape != input .shape :
378396 raise AssertionError (f"ground truth has differing shape ({ target .shape } ) from input ({ input .shape } )" )
379397
380- # reducing only spatial dimensions (not batch nor channels)
398+ # Exclude ignored regions from calculations
399+ if mask is not None :
400+ input = input * mask
401+ target = target * mask
402+
381403 reduce_axis : list [int ] = torch .arange (2 , len (input .shape )).tolist ()
382404 if self .batch :
383405 reduce_axis = [0 ] + reduce_axis
@@ -404,12 +426,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
404426 f : torch .Tensor = 1.0 - (numer / denom )
405427
406428 if self .reduction == LossReduction .MEAN .value :
407- f = torch .mean (f ) # the batch and channel average
429+ f = torch .mean (f )
408430 elif self .reduction == LossReduction .SUM .value :
409- f = torch .sum (f ) # sum over the batch and channel dims
431+ f = torch .sum (f )
410432 elif self .reduction == LossReduction .NONE .value :
411- # If we are not computing voxelwise loss components at least
412- # make sure a none reduction maintains a broadcastable shape
413433 broadcast_shape = list (f .shape [0 :2 ]) + [1 ] * (len (input .shape ) - 2 )
414434 f = f .view (broadcast_shape )
415435 else :
@@ -442,11 +462,12 @@ def __init__(
442462 reduction : LossReduction | str = LossReduction .MEAN ,
443463 smooth_nr : float = 1e-5 ,
444464 smooth_dr : float = 1e-5 ,
465+ ignore_index : int | None = None ,
445466 ) -> None :
446467 """
447468 Args:
448469 dist_matrix: 2d tensor or 2d numpy array; matrix of distances between the classes.
449- It must have dimension C x C where C is the number of classes.
470+ It must have dimension C x C where C is the number of classes.
450471 weighting_mode: {``"default"``, ``"GDL"``}
451472 Specifies how to weight the class-specific sum of errors.
452473 Default to ``"default"``.
@@ -466,35 +487,19 @@ def __init__(
466487 - ``"sum"``: the output will be summed.
467488 smooth_nr: a small constant added to the numerator to avoid zero.
468489 smooth_dr: a small constant added to the denominator to avoid nan.
490+ ignore_index: if not None, specifies a target index that is ignored and does not contribute to
491+ the input gradient. Defaults to None.
469492
470493 Raises:
471494 ValueError: When ``dist_matrix`` is not a square matrix.
472-
473- Example:
474- .. code-block:: python
475-
476- import torch
477- import numpy as np
478- from monai.losses import GeneralizedWassersteinDiceLoss
479-
480- # Example with 3 classes (including the background: label 0).
481- # The distance between the background class (label 0) and the other classes is the maximum, equal to 1.
482- # The distance between class 1 and class 2 is 0.5.
483- dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
484- wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
485-
486- pred_score = torch.tensor([[1000, 0, 0], [0, 1000, 0], [0, 0, 1000]], dtype=torch.float32)
487- grnd = torch.tensor([0, 1, 2], dtype=torch.int64)
488- wass_loss(pred_score, grnd) # 0
489-
490495 """
491496 super ().__init__ (reduction = LossReduction (reduction ).value )
492497
493498 if dist_matrix .shape [0 ] != dist_matrix .shape [1 ]:
494499 raise ValueError (f"dist_matrix must be C x C, got { dist_matrix .shape [0 ]} x { dist_matrix .shape [1 ]} ." )
495500
496501 if weighting_mode not in ["default" , "GDL" ]:
497- raise ValueError ("weighting_mode must be either 'default' or 'GDL, got %s." % weighting_mode )
502+ raise ValueError (f "weighting_mode must be either 'default' or 'GDL' , got { weighting_mode } ." )
498503
499504 self .m = dist_matrix
500505 if isinstance (self .m , np .ndarray ):
@@ -505,13 +510,13 @@ def __init__(
505510 self .num_classes = self .m .size (0 )
506511 self .smooth_nr = float (smooth_nr )
507512 self .smooth_dr = float (smooth_dr )
513+ self .ignore_index = ignore_index
508514
509515 def forward (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
510516 """
511517 Args:
512518 input: the shape should be BNH[WD].
513519 target: the shape should be BNH[WD].
514-
515520 """
516521 # Aggregate spatial dimensions
517522 flat_input = input .reshape (input .size (0 ), input .size (1 ), - 1 )
@@ -523,18 +528,20 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
523528 # Compute the Wasserstein distance map
524529 wass_dist_map = self .wasserstein_distance_map (probs , flat_target )
525530
531+ # Apply masking for ignore_index
532+ if self .ignore_index is not None :
533+ mask = (flat_target != self .ignore_index ).float ()
534+ wass_dist_map = wass_dist_map * mask
535+
526536 # Compute the values of alpha to use
527537 alpha = self ._compute_alpha_generalized_true_positives (flat_target )
528538
529539 # Compute the numerator and denominator of the generalized Wasserstein Dice loss
530540 if self .alpha_mode == "GDL" :
531541 # use GDL-style alpha weights (i.e. normalize by the volume of each class)
532- # contrary to the original definition we also use alpha in the "generalized all error".
533542 true_pos = self ._compute_generalized_true_positive (alpha , flat_target , wass_dist_map )
534543 denom = self ._compute_denominator (alpha , flat_target , wass_dist_map )
535544 else : # default: as in the original paper
536- # (i.e. alpha=1 for all foreground classes and 0 for the background).
537- # Compute the generalised number of true positives
538545 true_pos = self ._compute_generalized_true_positive (alpha , flat_target , wass_dist_map )
539546 all_error = torch .sum (wass_dist_map , dim = 1 )
540547 denom = 2 * true_pos + all_error
@@ -544,12 +551,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
544551 wass_dice_loss : torch .Tensor = 1.0 - wass_dice
545552
546553 if self .reduction == LossReduction .MEAN .value :
547- wass_dice_loss = torch .mean (wass_dice_loss ) # the batch and channel average
554+ wass_dice_loss = torch .mean (wass_dice_loss )
548555 elif self .reduction == LossReduction .SUM .value :
549- wass_dice_loss = torch .sum (wass_dice_loss ) # sum over the batch and channel dims
556+ wass_dice_loss = torch .sum (wass_dice_loss )
550557 elif self .reduction == LossReduction .NONE .value :
551- # If we are not computing voxelwise loss components at least
552- # make sure a none reduction maintains a broadcastable shape
553558 broadcast_shape = input .shape [0 :2 ] + (1 ,) * (len (input .shape ) - 2 )
554559 wass_dice_loss = wass_dice_loss .view (broadcast_shape )
555560 else :
@@ -674,6 +679,7 @@ def __init__(
674679 lambda_dice : float = 1.0 ,
675680 lambda_ce : float = 1.0 ,
676681 label_smoothing : float = 0.0 ,
682+ ignore_index : int | None = None ,
677683 ) -> None :
678684 """
679685 Args:
@@ -715,6 +721,8 @@ def __init__(
715721 label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
716722 by the given factor to reduce overfitting.
717723 Defaults to 0.0.
724+ ignore_index: if not None, specifies a target index that is ignored and does not contribute to
725+ the input gradient. Defaults to None.
718726
719727 """
720728 super ().__init__ ()
@@ -737,15 +745,22 @@ def __init__(
737745 smooth_dr = smooth_dr ,
738746 batch = batch ,
739747 weight = dice_weight ,
748+ ignore_index = ignore_index ,
749+ )
750+ self .cross_entropy = nn .CrossEntropyLoss (
751+ weight = weight ,
752+ reduction = reduction ,
753+ label_smoothing = label_smoothing ,
754+ ignore_index = ignore_index if ignore_index is not None else - 100 ,
740755 )
741- self .cross_entropy = nn .CrossEntropyLoss (weight = weight , reduction = reduction , label_smoothing = label_smoothing )
742756 self .binary_cross_entropy = nn .BCEWithLogitsLoss (pos_weight = weight , reduction = reduction )
743757 if lambda_dice < 0.0 :
744758 raise ValueError ("lambda_dice should be no less than 0.0." )
745759 if lambda_ce < 0.0 :
746760 raise ValueError ("lambda_ce should be no less than 0.0." )
747761 self .lambda_dice = lambda_dice
748762 self .lambda_ce = lambda_ce
763+ self .ignore_index = ignore_index
749764
750765 def ce (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
751766 """
@@ -801,7 +816,21 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
801816 )
802817
803818 dice_loss = self .dice (input , target )
804- ce_loss = self .ce (input , target ) if input .shape [1 ] != 1 else self .bce (input , target )
819+
820+ if input .shape [1 ] != 1 :
821+ # CrossEntropyLoss handles ignore_index natively
822+ ce_loss = self .ce (input , target )
823+ else :
824+ # BCEWithLogitsLoss does not support ignore_index, handle manually
825+ ce_loss = self .bce (input , target )
826+ if self .ignore_index is not None :
827+ mask = (target != self .ignore_index ).float ()
828+ ce_loss = ce_loss * mask
829+ if self .dice .reduction == "mean" :
830+ ce_loss = torch .mean (ce_loss )
831+ elif self .dice .reduction == "sum" :
832+ ce_loss = torch .sum (ce_loss )
833+
805834 total_loss : torch .Tensor = self .lambda_dice * dice_loss + self .lambda_ce * ce_loss
806835
807836 return total_loss
0 commit comments