Skip to content

Commit 5ebcade

Browse files
aymuos15ericspod
andauthored
Fix incomplete activation validation in HausdorffDTLoss (Project-MONAI#8841)
The validation check for mutually exclusive activation options was incomplete - it only checked sigmoid and softmax but not other_act, despite the error message explicitly mentioning other_act. Without this check, passing e.g. `sigmoid=True, other_act=relu` silently stacks both activations in `forward()` (`other_act(sigmoid(x))`) instead of applying only one, producing an incorrect loss with no warning. Before: ```python if int(sigmoid) + int(softmax) > 1: raise ValueError("... [sigmoid=True, softmax=True, other_act is not None].") ``` After: ```python if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: raise ValueError("... [sigmoid=True, softmax=True, other_act is not None].") ``` This is consistent with the validation in dice.py and tversky.py which correctly include all three options in the check. Added tests for: - sigmoid=True with other_act - softmax=True with other_act - All three options combined ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent c95d9a9 commit 5ebcade

2 files changed

Lines changed: 13 additions & 1 deletion

File tree

monai/losses/hausdorff_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
super().__init__(reduction=LossReduction(reduction).value)
8484
if other_act is not None and not callable(other_act):
8585
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
86-
if int(sigmoid) + int(softmax) > 1:
86+
if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
8787
raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
8888

8989
self.alpha = alpha

tests/losses/test_hausdorff_loss.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,12 @@ def test_ill_shape(self):
212212
def test_ill_opts(self):
213213
with self.assertRaisesRegex(ValueError, ""):
214214
HausdorffDTLoss(sigmoid=True, softmax=True)
215+
with self.assertRaisesRegex(ValueError, ""):
216+
HausdorffDTLoss(sigmoid=True, other_act=torch.tanh)
217+
with self.assertRaisesRegex(ValueError, ""):
218+
HausdorffDTLoss(softmax=True, other_act=torch.tanh)
219+
with self.assertRaisesRegex(ValueError, ""):
220+
HausdorffDTLoss(sigmoid=True, softmax=True, other_act=torch.tanh)
215221
chn_input = torch.ones((1, 1, 3))
216222
chn_target = torch.ones((1, 1, 3))
217223
with self.assertRaisesRegex(ValueError, ""):
@@ -244,6 +250,12 @@ def test_ill_shape(self):
244250
def test_ill_opts(self):
245251
with self.assertRaisesRegex(ValueError, ""):
246252
LogHausdorffDTLoss(sigmoid=True, softmax=True)
253+
with self.assertRaisesRegex(ValueError, ""):
254+
LogHausdorffDTLoss(sigmoid=True, other_act=torch.tanh)
255+
with self.assertRaisesRegex(ValueError, ""):
256+
LogHausdorffDTLoss(softmax=True, other_act=torch.tanh)
257+
with self.assertRaisesRegex(ValueError, ""):
258+
LogHausdorffDTLoss(sigmoid=True, softmax=True, other_act=torch.tanh)
247259
chn_input = torch.ones((1, 1, 3))
248260
chn_target = torch.ones((1, 1, 3))
249261
with self.assertRaisesRegex(ValueError, ""):

0 commit comments

Comments
 (0)