Skip to content

Commit 583d5ca

Browse files
ytl0623ericspod
andauthored
Adjust execution order of activation and masking in MaskedDiceLoss (#8704)
Fixes #8655 ### Description 1. Modified `MaskedDiceLoss.forward` to ensure masked regions result in 0.0 probability instead of 0.5. 2. Updated unit tests with new expected values, as the fix mathematically changes the loss result by removing the erroneous background contribution. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 <david89062388@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent d53eb00 commit 583d5ca

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

monai/losses/dice.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import warnings
1515
from collections.abc import Callable, Sequence
16-
from typing import Any
1716

1817
import numpy as np
1918
import torch
@@ -239,11 +238,52 @@ class MaskedDiceLoss(DiceLoss):
239238
240239
"""
241240

242-
def __init__(self, *args: Any, **kwargs: Any) -> None:
241+
def __init__(
242+
self,
243+
include_background: bool = True,
244+
to_onehot_y: bool = False,
245+
sigmoid: bool = False,
246+
softmax: bool = False,
247+
other_act: Callable | None = None,
248+
squared_pred: bool = False,
249+
jaccard: bool = False,
250+
reduction: LossReduction | str = LossReduction.MEAN,
251+
smooth_nr: float = 1e-5,
252+
smooth_dr: float = 1e-5,
253+
batch: bool = False,
254+
weight: Sequence[float] | float | int | torch.Tensor | None = None,
255+
soft_label: bool = False,
256+
) -> None:
243257
"""
244258
Args follow :py:class:`monai.losses.DiceLoss`.
245259
"""
246-
super().__init__(*args, **kwargs)
260+
if other_act is not None and not callable(other_act):
261+
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
262+
if sigmoid and softmax:
263+
raise ValueError("Incompatible values: sigmoid=True and softmax=True.")
264+
if other_act is not None and (sigmoid or softmax):
265+
raise ValueError("Incompatible values: other_act is not None and sigmoid=True or softmax=True.")
266+
267+
self.pre_sigmoid = sigmoid
268+
self.pre_softmax = softmax
269+
self.pre_other_act = other_act
270+
271+
super().__init__(
272+
include_background=include_background,
273+
to_onehot_y=to_onehot_y,
274+
sigmoid=False,
275+
softmax=False,
276+
other_act=None,
277+
squared_pred=squared_pred,
278+
jaccard=jaccard,
279+
reduction=reduction,
280+
smooth_nr=smooth_nr,
281+
smooth_dr=smooth_dr,
282+
batch=batch,
283+
weight=weight,
284+
soft_label=soft_label,
285+
)
286+
247287
self.spatial_weighted = MaskedLoss(loss=super().forward)
248288

249289
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
@@ -253,6 +293,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
253293
target: the shape should be BNH[WD].
254294
mask: the shape should B1H[WD] or 11H[WD].
255295
"""
296+
297+
if self.pre_sigmoid:
298+
input = torch.sigmoid(input)
299+
300+
n_pred_ch = input.shape[1]
301+
if self.pre_softmax:
302+
if n_pred_ch == 1:
303+
warnings.warn("single channel prediction, `softmax=True` ignored.", stacklevel=2)
304+
else:
305+
input = torch.softmax(input, 1)
306+
307+
if self.pre_other_act is not None:
308+
input = self.pre_other_act(input)
256309
return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]
257310

258311

tests/losses/test_masked_dice_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
2828
"mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]),
2929
},
30-
0.500,
30+
0.333333,
3131
],
3232
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
3333
{"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -36,7 +36,7 @@
3636
"target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
3737
"mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]),
3838
},
39-
0.422969,
39+
0.301128,
4040
],
4141
[ # shape: (2, 2, 3), (2, 1, 3)
4242
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
@@ -54,7 +54,7 @@
5454
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
5555
"mask": torch.tensor([[[1.0, 1.0, 0.0]]]),
5656
},
57-
0.47033,
57+
0.579184,
5858
],
5959
[ # shape: (2, 2, 3), (2, 1, 3)
6060
{

0 commit comments

Comments
 (0)