@@ -160,15 +160,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
160160 if y_true .shape != y_pred .shape :
161161 raise ValueError (f"ground truth has different shape ({ y_true .shape } ) from input ({ y_pred .shape } )" )
162162
163+ mask = create_ignore_mask (original_y_true if original_y_true is not None else y_true , self .ignore_index )
164+ if mask is not None :
165+ mask = mask .expand_as (y_true )
166+
163167 y_pred = torch .clamp (y_pred , self .epsilon , 1.0 - self .epsilon )
164168 cross_entropy = - y_true * torch .log (y_pred )
165169
166- # Apply mask from original labels if available
167- mask = create_ignore_mask (original_y_true if original_y_true is not None else y_true , self .ignore_index )
168170 if mask is not None :
169- cross_entropy = cross_entropy * mask .expand_as (cross_entropy )
170- # Rename for compatibility with reduction block below
171- spatial_mask = mask
171+ cross_entropy = cross_entropy * mask
172172
173173 back_ce = torch .pow (1 - y_pred [:, 0 ], self .gamma ) * cross_entropy [:, 0 ]
174174 back_ce = (1 - self .delta ) * back_ce
@@ -179,11 +179,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
179179 loss = torch .stack ([back_ce , fore_ce ], dim = 1 ) # [B, 2, H, W]
180180
181181 if self .reduction == LossReduction .MEAN .value :
182- if self .ignore_index is not None and spatial_mask is not None :
183- # Apply mask to loss, then average over valid elements only
184- # loss has shape [B, 2, H, W], spatial_mask has shape [B, 1, H, W]
185- masked_loss = loss * spatial_mask .expand_as (loss )
186- return masked_loss .sum () / (spatial_mask .expand_as (loss ).sum ().clamp (min = 1e-5 ))
182+ if mask is not None :
183+ masked_loss = loss * mask
184+ return masked_loss .sum () / mask .expand_as (loss ).sum ().clamp (min = 1e-5 )
187185 return loss .mean ()
188186 if self .reduction == LossReduction .SUM .value :
189187 return loss .sum ()
0 commit comments