Skip to content

Commit eeda3c7

Browse files
committed
chore: format and lint code
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent cfc54ec commit eeda3c7

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

monai/losses/unified_focal_loss.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,16 +262,19 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
262262
if y_pred.shape[1] == 1:
263263
y_pred = torch.cat([1 - y_pred, y_pred], dim=1)
264264

265+
original_y_true = y_true if self.ignore_index is not None else None
266+
265267
if self.to_onehot_y:
268+
if self.ignore_index is not None:
269+
# Replace ignore_index with valid class before one_hot
270+
y_true = torch.where(y_true == self.ignore_index, torch.tensor(0, device=y_true.device), y_true)
266271
y_true = one_hot(y_true, num_classes=self.num_classes)
267272
elif y_true.shape[1] == 1 and y_pred.shape[1] == 2:
268273
y_true = torch.cat([1 - y_true, y_true], dim=1)
269274

270-
original_y_true_unified = y_true # Use transformed y_true as baseline if original unavailable
271-
mask = create_ignore_mask(original_y_true_unified if self.ignore_index is not None else None, self.ignore_index)
272-
273-
if mask is not None:
274-
y_true = y_true * mask
275+
if self.ignore_index is not None:
276+
mask = original_y_true == self.ignore_index
277+
y_true[mask.expand_as(y_true)] = self.ignore_index
275278

276279
if y_true.shape != y_pred.shape:
277280
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")

0 commit comments

Comments
 (0)