|
18 | 18 |
|
19 | 19 | from monai.networks import one_hot |
20 | 20 | from monai.utils import LossReduction |
| 21 | +from monai.metrics.utils import create_ignore_mask |
21 | 22 |
|
22 | 23 |
|
23 | 24 | class AsymmetricFocalTverskyLoss(_Loss): |
@@ -74,19 +75,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: |
74 | 75 | if y_true.shape != y_pred.shape: |
75 | 76 | raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") |
76 | 77 |
|
77 | | - # Build mask after one_hot conversion |
78 | | - mask = torch.ones_like(y_true) |
79 | | - if self.ignore_index is not None: |
80 | | - if original_y_true is not None and self.to_onehot_y: |
81 | | - # Use original labels to build spatial mask |
82 | | - spatial_mask = (original_y_true != self.ignore_index).float() |
83 | | - elif self.ignore_index < y_true.shape[1]: |
84 | | - # For already one-hot: use ignored class channel |
85 | | - spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1] |
86 | | - else: |
87 | | - # For sentinel values: any valid channel |
88 | | - spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float() |
89 | | - mask = spatial_mask.expand_as(y_true) |
| 78 | + mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index) |
| 79 | + if mask is not None: |
| 80 | + mask = mask.expand_as(y_true) |
90 | 81 | y_pred = y_pred * mask |
91 | 82 | y_true = y_true * mask |
92 | 83 |
|
@@ -169,18 +160,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: |
169 | 160 | y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) |
170 | 161 | cross_entropy = -y_true * torch.log(y_pred) |
171 | 162 |
|
172 | | - # Build mask from original labels if available |
173 | | - spatial_mask: torch.Tensor | None = None |
174 | | - if self.ignore_index is not None: |
175 | | - if original_y_true is not None and self.to_onehot_y: |
176 | | - spatial_mask = (original_y_true != self.ignore_index).float() |
177 | | - elif self.ignore_index < y_true.shape[1]: |
178 | | - spatial_mask = 1.0 - y_true[:, self.ignore_index : self.ignore_index + 1] |
179 | | - else: |
180 | | - spatial_mask = (y_true.sum(dim=1, keepdim=True) > 0).float() |
181 | | - |
182 | | - if spatial_mask is not None: |
183 | | - cross_entropy = cross_entropy * spatial_mask.expand_as(cross_entropy) |
| 163 | + # Apply mask from original labels if available |
| 164 | + mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index) |
| 165 | + if mask is not None: |
| 166 | + cross_entropy = cross_entropy * mask.expand_as(cross_entropy) |
| 167 | + # Rename for compatibility with reduction block below |
| 168 | + spatial_mask = mask |
184 | 169 |
|
185 | 170 | back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] |
186 | 171 | back_ce = (1 - self.delta) * back_ce |
@@ -276,30 +261,16 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: |
276 | 261 | if y_pred.shape[1] == 1: |
277 | 262 | y_pred = torch.cat([1 - y_pred, y_pred], dim=1) |
278 | 263 |
|
279 | | - # Move one_hot conversion OUTSIDE the if y_pred.shape[1] == 1 block |
280 | 264 | if self.to_onehot_y: |
281 | | - if self.ignore_index is not None: |
282 | | - mask = (y_true != self.ignore_index).float() |
283 | | - y_true_clean = torch.where(y_true == self.ignore_index, 0, y_true) |
284 | | - y_true = one_hot(y_true_clean, num_classes=self.num_classes) |
285 | | - # Keep the channel-wise mask |
286 | | - y_true = y_true * mask |
287 | | - else: |
288 | | - y_true = one_hot(y_true, num_classes=self.num_classes) |
289 | | - |
290 | | - # Check if shapes match |
291 | | - if y_true.shape[1] == 1 and y_pred.shape[1] == 2: |
292 | | - if self.ignore_index is not None: |
293 | | - # Create mask for valid pixels |
294 | | - mask = (y_true != self.ignore_index).float() |
295 | | - # Set ignore_index values to 0 before conversion |
296 | | - y_true_clean = y_true * mask |
297 | | - # Convert to 2-channel |
298 | | - y_true = torch.cat([1 - y_true_clean, y_true_clean], dim=1) |
299 | | - # Apply mask to both channels so ignored pixels are all zeros |
300 | | - y_true = y_true * mask |
301 | | - else: |
302 | | - y_true = torch.cat([1 - y_true, y_true], dim=1) |
| 265 | + y_true = one_hot(y_true, num_classes=self.num_classes) |
| 266 | + elif y_true.shape[1] == 1 and y_pred.shape[1] == 2: |
| 267 | + y_true = torch.cat([1 - y_true, y_true], dim=1) |
| 268 | + |
| 269 | + original_y_true_unified = y_true # Use transformed y_true as baseline if original unavailable |
| 270 | + mask = create_ignore_mask(original_y_true_unified if self.ignore_index is not None else None, self.ignore_index) |
| 271 | + |
| 272 | + if mask is not None: |
| 273 | + y_true = y_true * mask |
303 | 274 |
|
304 | 275 | if y_true.shape != y_pred.shape: |
305 | 276 | raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") |
|
0 commit comments