Skip to content

Commit 8755ad5

Browse files
committed
Address ericspod's review comments: use class member for smooth parameter
- Replace hardcoded 1e-7 with configurable smooth parameter in SoftclDiceLoss - Add smooth parameter to SoftDiceclDiceLoss with docstring note about differing from DiceLoss defaults (1e-5 vs 1e-4) - Add validation for smooth parameter in __init__ - Update Raises sections in docstrings Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 781616b commit 8755ad5

1 file changed

Lines changed: 11 additions & 1 deletion

File tree

monai/losses/cldice.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(
122122
iter_: int = 3,
123123
smooth_nr: float = 1.0,
124124
smooth_dr: float = 1.0,
125+
smooth: float = 1e-4,
125126
include_background: bool = True,
126127
to_onehot_y: bool = False,
127128
sigmoid: bool = False,
@@ -134,6 +135,7 @@ def __init__(
134135
iter_: Number of iterations for skeletonization. Must be a non-negative integer. Defaults to 3.
135136
smooth_nr: a small constant added to the numerator to avoid zero. Defaults to 1.0.
136137
smooth_dr: a small constant added to the denominator to avoid nan. Defaults to 1.0.
138+
smooth: a small constant added to the denominator of the harmonic mean to avoid nan. Defaults to 1e-4.
137139
include_background: if False, channel index 0 (background category) is excluded from the calculation.
138140
if the non-background segmentations are small compared to the total image size they can get overwhelmed
139141
by the signal from the background so excluding it in such cases helps convergence.
@@ -154,6 +156,7 @@ def __init__(
154156
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
155157
TypeError: When ``iter_`` is not an ``int``.
156158
ValueError: When ``iter_`` is a negative integer.
159+
ValueError: When ``smooth`` is not a positive value.
157160
ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
158161
Incompatible values.
159162
@@ -167,9 +170,12 @@ def __init__(
167170
raise TypeError(f"iter_ must be an integer but got {type(iter_).__name__}.")
168171
if iter_ < 0:
169172
raise ValueError(f"iter_ must be a non-negative integer but got {iter_}.")
173+
if smooth <= 0:
174+
raise ValueError(f"smooth must be a positive value but got {smooth}.")
170175
self.iter = iter_
171176
self.smooth_nr = float(smooth_nr)
172177
self.smooth_dr = float(smooth_dr)
178+
self.smooth = float(smooth)
173179
self.include_background = include_background
174180
self.to_onehot_y = to_onehot_y
175181
self.sigmoid = sigmoid
@@ -233,7 +239,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
233239
torch.sum(skel_true, dim=reduce_axis) + self.smooth_dr
234240
)
235241
# Add small epsilon for numerical stability in harmonic mean
236-
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + 1e-7)
242+
cl_dice: torch.Tensor = 1.0 - 2.0 * (tprec * tsens) / (tprec + tsens + self.smooth)
237243

238244
# Apply reduction
239245
if self.reduction == LossReduction.MEAN.value:
@@ -266,6 +272,7 @@ def __init__(
266272
alpha: float = 0.5,
267273
smooth_nr: float = 1.0,
268274
smooth_dr: float = 1.0,
275+
smooth: float = 1e-4,
269276
include_background: bool = True,
270277
to_onehot_y: bool = False,
271278
sigmoid: bool = False,
@@ -280,6 +287,8 @@ def __init__(
280287
Defaults to 0.5.
281288
smooth_nr: a small constant added to the numerator to avoid zero, used by both Dice and clDice. Defaults to 1.0.
282289
smooth_dr: a small constant added to the denominator to avoid nan, used by both Dice and clDice. Defaults to 1.0.
290+
smooth: a small constant added to the denominator of the harmonic mean in clDice to avoid nan.
291+
Defaults to 1e-4. Note: This differs from standalone DiceLoss defaults (1e-5) to follow clDice convention.
283292
include_background: if False, channel index 0 (background category) is excluded from the calculation.
284293
if the non-background segmentations are small compared to the total image size they can get overwhelmed
285294
by the signal from the background so excluding it in such cases helps convergence.
@@ -320,6 +329,7 @@ def __init__(
320329
iter_=iter_,
321330
smooth_nr=smooth_nr,
322331
smooth_dr=smooth_dr,
332+
smooth=smooth,
323333
include_background=include_background,
324334
to_onehot_y=False,
325335
sigmoid=sigmoid,

0 commit comments

Comments
 (0)