Skip to content

Commit 610756d

Browse files
committed
refactor: centralize ignore_index masking into create_ignore_mask helper
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent 3bd76e7 commit 610756d

12 files changed

Lines changed: 103 additions & 183 deletions

monai/losses/dice.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from monai.losses.utils import compute_tp_fp_fn
2626
from monai.networks import one_hot
2727
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option
28+
from monai.metrics.utils import create_ignore_mask
2829

2930

3031
class DiceLoss(_Loss):
@@ -166,16 +167,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
166167
if self.other_act is not None:
167168
input = self.other_act(input)
168169

169-
mask: torch.Tensor | None = None
170-
if self.ignore_index is not None:
171-
mask = (target != self.ignore_index).float()
170+
original_target = target if self.ignore_index is not None else None
172171

173172
if self.to_onehot_y:
174173
if n_pred_ch == 1:
175174
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
176175
else:
177176
target = one_hot(target, num_classes=n_pred_ch)
178177

178+
mask = create_ignore_mask(original_target if original_target is not None else target, self.ignore_index)
179179
if not self.include_background:
180180
if n_pred_ch == 1:
181181
warnings.warn("single channel prediction, `include_background=False` ignored.")

monai/losses/focal_loss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from monai.networks import one_hot
2222
from monai.utils import LossReduction
23+
from monai.metrics.utils import create_ignore_mask
2324

2425

2526
class FocalLoss(_Loss):
@@ -164,8 +165,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
164165
if target.shape != input.shape:
165166
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
166167

167-
if self.ignore_index is not None:
168-
mask = (target != self.ignore_index).float()
168+
mask = create_ignore_mask(target, self.ignore_index)
169+
if mask is not None:
169170
input = input * mask
170171
target = target * mask
171172

monai/losses/tversky.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from monai.losses.utils import compute_tp_fp_fn
2121
from monai.networks import one_hot
2222
from monai.utils import LossReduction
23+
from monai.metrics.utils import create_ignore_mask
2324

2425

2526
class TverskyLoss(_Loss):
@@ -137,15 +138,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
137138

138139
if self.ignore_index is not None:
139140
mask_src = original_target if self.to_onehot_y and n_pred_ch > 1 else target
140-
141-
if mask_src.shape[1] == 1:
142-
mask = (mask_src != self.ignore_index).to(input.dtype)
143-
else:
144-
# Fallback for cases where target is already one-hot
145-
mask = (1.0 - mask_src[:, self.ignore_index : self.ignore_index + 1]).to(input.dtype)
146-
147-
input = input * mask
148-
target = target * mask
141+
mask = create_ignore_mask(mask_src, self.ignore_index)
142+
if mask is not None:
143+
input = input * mask
144+
target = target * mask
149145

150146
if not self.include_background:
151147
if n_pred_ch == 1:

monai/losses/unified_focal_loss.py

Lines changed: 19 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from monai.networks import one_hot
2020
from monai.utils import LossReduction
21+
from monai.metrics.utils import create_ignore_mask
2122

2223

2324
class AsymmetricFocalTverskyLoss(_Loss):
@@ -74,19 +75,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7475
if y_true.shape != y_pred.shape:
7576
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
7677

77-
# Build mask after one_hot conversion
78-
mask = torch.ones_like(y_true)
79-
if self.ignore_index is not None:
80-
if original_y_true is not None and self.to_onehot_y:
81-
# Use original labels to build spatial mask
82-
spatial_mask = (original_y_true != self.ignore_index).float()
83-
elif self.ignore_index < y_true.shape[1]:
84-
# For already one-hot: use ignored class channel
85-
spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
86-
else:
87-
# For sentinel values: any valid channel
88-
spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()
89-
mask = spatial_mask.expand_as(y_true)
78+
mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index)
79+
if mask is not None:
80+
mask = mask.expand_as(y_true)
9081
y_pred = y_pred * mask
9182
y_true = y_true * mask
9283

@@ -169,18 +160,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
169160
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
170161
cross_entropy = -y_true * torch.log(y_pred)
171162

172-
# Build mask from original labels if available
173-
spatial_mask: torch.Tensor | None = None
174-
if self.ignore_index is not None:
175-
if original_y_true is not None and self.to_onehot_y:
176-
spatial_mask = (original_y_true != self.ignore_index).float()
177-
elif self.ignore_index < y_true.shape[1]:
178-
spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1]
179-
else:
180-
spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float()
181-
182-
if spatial_mask is not None:
183-
cross_entropy = cross_entropy * spatial_mask.expand_as(cross_entropy)
163+
# Apply mask from original labels if available
164+
mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index)
165+
if mask is not None:
166+
cross_entropy = cross_entropy * mask.expand_as(cross_entropy)
167+
# Rename for compatibility with reduction block below
168+
spatial_mask = mask
184169

185170
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
186171
back_ce = (1 - self.delta) * back_ce
@@ -276,30 +261,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
276261
if y_pred.shape[1] == 1:
277262
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
278263

279-
# Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block
280264
if self.to_onehot_y:
281-
if self.ignore_index is not None:
282-
mask = (y_true != self.ignore_index).float()
283-
y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true)
284-
y_true = one_hot(y_true_clean, num_classes=self.num_classes)
285-
# Keep the channel-wise mask
286-
y_true = y_true * mask
287-
else:
288-
y_true = one_hot(y_true, num_classes=self.num_classes)
289-
290-
# Check if shapes match
291-
if y_true.shape[1] == 1 and y_pred.shape[1] == 2:
292-
if self.ignore_index is not None:
293-
# Create mask for valid pixels
294-
mask = (y_true != self.ignore_index).float()
295-
# Set ignore_index values to 0 before conversion
296-
y_true_clean = y_true * mask
297-
# Convert to 2-channel
298-
y_true = torch.cat([1 - y_true_clean, y_true_clean], dim=1)
299-
# Apply mask to both channels so ignored pixels are all zeros
300-
y_true = y_true * mask
301-
else:
302-
y_true = torch.cat([1 - y_true, y_true], dim=1)
265+
y_true = one_hot(y_true, num_classes=self.num_classes)
266+
elif y_true.shape[1] == 1 and y_pred.shape[1] == 2:
267+
y_true = torch.cat([1 - y_true, y_true], dim=1)
268+
269+
original_y_true_unified = y_true # Use transformed y_true as baseline if original unavailable
270+
mask = create_ignore_mask(original_y_true_unified if self.ignore_index is not None else None, self.ignore_index)
271+
272+
if mask is not None:
273+
y_true = y_true * mask
303274

304275
if y_true.shape != y_pred.shape:
305276
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

monai/metrics/confusion_matrix.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import torch
1818

19-
from monai.metrics.utils import do_metric_reduction, ignore_background
19+
from monai.metrics.utils import create_ignore_mask, do_metric_reduction, ignore_background
2020
from monai.utils import MetricReduction, ensure_tuple
2121

2222
from .metric import CumulativeIterationMetric
@@ -169,14 +169,7 @@ def get_confusion_matrix(
169169
batch_size, n_class = y_pred.shape[:2]
170170

171171
# Create spatial mask if ignore_index is provided
172-
mask = None
173-
if ignore_index is not None:
174-
if ignore_index >= n_class:
175-
# If ignore_index is outside channel range (e.g. 255), we assume it's a spatial mask
176-
mask = y.sum(dim=1, keepdim=True) > 0
177-
else:
178-
# If ignore_index is a valid channel, exclude that specific channel
179-
mask = 1.0 - y[:, ignore_index : ignore_index + 1]
172+
mask = create_ignore_mask(y, ignore_index)
180173

181174
# convert to [BNS], where S is the number of pixels for one sample.
182175
y_pred = y_pred.reshape(batch_size, n_class, -1)

monai/metrics/generalized_dice.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from monai.metrics.utils import do_metric_reduction
16+
from monai.metrics.utils import create_ignore_mask, do_metric_reduction
1717
from monai.utils import MetricReduction, Weight, deprecated_arg, look_up_option
1818

1919
from .metric import CumulativeIterationMetric
@@ -156,13 +156,8 @@ def compute_generalized_dice(
156156
raise ValueError(f"y_pred - {y_pred.shape} - and y - {y.shape} - should have the same shapes.")
157157

158158
# Apply ignore_index masking
159-
if ignore_index is not None:
160-
if 0 <= ignore_index < y.shape[1]:
161-
# For one-hot: use the ignored class channel
162-
mask = 1.0 - y[:, ignore_index : ignore_index + 1]
163-
else:
164-
# For sentinel values (like 255 or -100), check if any channel is valid
165-
mask = (y.sum(dim=1, keepdim=True) > 0).float()
159+
mask = create_ignore_mask(y, ignore_index)
160+
if mask is not None:
166161
y_pred = y_pred * mask
167162
y = y * mask
168163

monai/metrics/hausdorff_distance.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
import numpy as np
1818
import torch
1919

20-
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
20+
from monai.metrics.utils import (
21+
create_ignore_mask,
22+
do_metric_reduction,
23+
get_edge_surface_distance,
24+
ignore_background,
25+
prepare_spacing,
26+
)
2127
from monai.utils import MetricReduction, convert_data_type
2228

2329
from .metric import CumulativeIterationMetric
@@ -51,7 +57,6 @@ class HausdorffDistanceMetric(CumulativeIterationMetric):
5157
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
5258
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
5359
Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
54-
ignore_index: index of the class to ignore during calculation. Defaults to ``None``.
5560
5661
"""
5762

@@ -100,9 +105,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
100105
if dims < 3:
101106
raise ValueError("y_pred should have at least three dimensions.")
102107

103-
mask = None
104-
if self.ignore_index is not None:
105-
mask = (y != self.ignore_index).all(dim=1, keepdim=True).float()
108+
mask = create_ignore_mask(y, self.ignore_index)
109+
if mask is not None:
106110
y_pred = y_pred * mask
107111
y = y * mask
108112

@@ -115,8 +119,6 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
115119
percentile=self.percentile,
116120
directed=self.directed,
117121
spacing=kwargs.get("spacing"),
118-
ignore_index=self.ignore_index,
119-
mask=mask,
120122
)
121123

122124
def aggregate(
@@ -148,8 +150,6 @@ def compute_hausdorff_distance(
148150
percentile: float | None = None,
149151
directed: bool = False,
150152
spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,
151-
mask: torch.Tensor | None = None,
152-
ignore_index: int | None = None,
153153
) -> torch.Tensor:
154154
"""
155155
Compute the Hausdorff distance.
@@ -175,7 +175,6 @@ def compute_hausdorff_distance(
175175
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
176176
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
177177
for all images in batch. Defaults to ``None``.
178-
ignore_index: index of the class to ignore during calculation. Defaults to ``None``.
179178
"""
180179

181180
if not include_background:
@@ -196,23 +195,12 @@ def compute_hausdorff_distance(
196195
yp = y_pred[b, c]
197196
yt = y[b, c]
198197

199-
if ignore_index is not None:
200-
valid_mask = y[b].sum(dim=0) > 0
201-
yp = yp * valid_mask
202-
yt = yt * valid_mask
203-
204-
# if everything is ignored, define distance as 0
205-
if not valid_mask.any():
206-
hd[b, c] = torch.tensor(0.0, device=y_pred.device)
207-
continue
208-
209198
_, distances, _ = get_edge_surface_distance(
210199
yp,
211200
yt,
212201
distance_metric=distance_metric,
213202
spacing=spacing_list[b],
214203
symmetric=not directed,
215-
mask=mask[b, 0] if mask is not None else None,
216204
)
217205

218206
if len(distances) == 0:

monai/metrics/meandice.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from monai.metrics.utils import do_metric_reduction
16+
from monai.metrics.utils import create_ignore_mask, do_metric_reduction
1717
from monai.utils import MetricReduction, deprecated_arg
1818

1919
from .metric import CumulativeIterationMetric
@@ -297,7 +297,7 @@ def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor, mask: torch.Ten
297297
mask: binary mask where 0 indicates voxels to ignore.
298298
"""
299299
if mask is not None:
300-
y_pred = y_pred * mask
300+
y_pred = y_pred & mask.bool()
301301
y = y * mask
302302

303303
y_o = torch.sum(y)
@@ -336,19 +336,7 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
336336
y_pred = y_pred > 0.5
337337

338338
# Create global mask for ignored voxels if ignore_index is set
339-
mask = None
340-
if self.ignore_index is not None:
341-
if y.shape[1] == 1:
342-
# Single channel - values are class indices
343-
mask = y != self.ignore_index
344-
else:
345-
# Multi-channel (one-hot or class probabilities)
346-
if self.ignore_index < n_pred_ch:
347-
# Class-based ignore: ignore specific class channel
348-
mask = y[:, self.ignore_index : self.ignore_index + 1] == 0
349-
else:
350-
# Sentinel-based ignore: ignore where all channels are 0
351-
mask = y.sum(dim=1, keepdim=True) > 0
339+
mask = create_ignore_mask(y, self.ignore_index)
352340

353341
first_ch = 0 if self.include_background else 1
354342
data = []

monai/metrics/meaniou.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from monai.metrics.utils import do_metric_reduction, ignore_background
16+
from monai.metrics.utils import create_ignore_mask, do_metric_reduction, ignore_background
1717
from monai.utils import MetricReduction
1818

1919
from .metric import CumulativeIterationMetric
@@ -143,15 +143,9 @@ def compute_iou(
143143
if y.shape != y_pred.shape:
144144
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
145145

146-
if ignore_index is not None:
147-
if ignore_index < y.shape[1]:
148-
# For one-hot: mask based on the ignored class channel
149-
mask = 1.0 - y[:, ignore_index : ignore_index + 1]
150-
if mask.shape != y_pred.shape:
151-
mask = mask.expand_as(y_pred)
152-
else:
153-
# For sentinel values, check if any channel is valid
154-
mask = (y.sum(dim=1, keepdim=True) > 0).float()
146+
mask = create_ignore_mask(y, ignore_index)
147+
if mask is not None:
148+
if mask.shape != y_pred.shape:
155149
mask = mask.expand_as(y_pred)
156150
y_pred = y_pred * mask
157151
y = y * mask

0 commit comments

Comments
 (0)