Skip to content

Commit cfc54ec

Browse files
committed
fix: CI errors
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent c2612ea commit cfc54ec

1 file changed

Lines changed: 8 additions & 10 deletions

File tree

monai/losses/unified_focal_loss.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
160160
if y_true.shape != y_pred.shape:
161161
raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})")
162162

163+
mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index)
164+
if mask is not None:
165+
mask = mask.expand_as(y_true)
166+
163167
y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon)
164168
cross_entropy = -y_true * torch.log(y_pred)
165169

166-
# Apply mask from original labels if available
167-
mask = create_ignore_mask(original_y_true if original_y_true is not None else y_true, self.ignore_index)
168170
if mask is not None:
169-
cross_entropy = cross_entropy * mask.expand_as(cross_entropy)
170-
# Rename for compatibility with reduction block below
171-
spatial_mask = mask
171+
cross_entropy = cross_entropy * mask
172172

173173
back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0]
174174
back_ce = (1 - self.delta) * back_ce
@@ -179,11 +179,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
179179
loss = torch.stack([back_ce, fore_ce], dim=1) # [B, 2, H, W]
180180

181181
if self.reduction == LossReduction.MEAN.value:
182-
if self.ignore_index is not None and spatial_mask is not None:
183-
# Apply mask to loss, then average over valid elements only
184-
# loss has shape [B, 2, H, W], spatial_mask has shape [B, 1, H, W]
185-
masked_loss = loss * spatial_mask.expand_as(loss)
186-
return masked_loss.sum() / (spatial_mask.expand_as(loss).sum().clamp(min=1e-5))
182+
if mask is not None:
183+
masked_loss = loss * mask
184+
return masked_loss.sum() / mask.expand_as(loss).sum().clamp(min=1e-5)
187185
return loss.mean()
188186
if self.reduction == LossReduction.SUM.value:
189187
return loss.sum()

0 commit comments

Comments
 (0)