@@ -477,7 +477,8 @@ def loss(
477477 all_losses = all_losses * span_mask .float ()
478478
479479 if reduction == "mean" :
480- loss = all_losses .mean ()
480+ num_valid = span_mask .float ().sum ()
481+ loss = all_losses .sum () / num_valid if num_valid > 0 else torch .tensor (0.0 , device = scores .device )
481482 elif reduction == "sum" :
482483 loss = all_losses .sum ()
483484 else :
@@ -679,7 +680,8 @@ def loss(
679680 all_losses = all_losses * mask
680681
681682 if reduction == "mean" :
682- loss = all_losses .mean ()
683+ num_valid = mask .float ().sum ()
684+ loss = all_losses .sum () / num_valid if num_valid > 0 else torch .tensor (0.0 , device = scores .device )
683685 elif reduction == "sum" :
684686 loss = all_losses .sum ()
685687 else :
@@ -976,7 +978,8 @@ def loss(
976978 all_losses = all_losses * mask_label .float ()
977979
978980 if reduction == "mean" :
979- loss = all_losses .mean ()
981+ num_valid = mask_label .float ().sum ()
982+ loss = all_losses .sum () / num_valid if num_valid > 0 else torch .tensor (0.0 , device = scores .device )
980983 elif reduction == "sum" :
981984 loss = all_losses .sum ()
982985 else :
@@ -1588,7 +1591,8 @@ def loss(
15881591 all_losses = all_losses * mask_label .float ()
15891592
15901593 if reduction == "mean" :
1591- loss = all_losses .mean ()
1594+ num_valid = mask_label .float ().sum ()
1595+ loss = all_losses .sum () / num_valid if num_valid > 0 else torch .tensor (0.0 , device = scores .device )
15921596 elif reduction == "sum" :
15931597 loss = all_losses .sum ()
15941598 else :
@@ -2477,19 +2481,20 @@ def loss(
24772481 """
24782482 all_losses = self ._loss (scores , labels , alpha , gamma , prob_margin , label_smoothing , negatives )
24792483
2480- all_losses = all_losses * (word_mask .unsqueeze (- 1 ) * prompts_embedding_mask .unsqueeze (1 )).unsqueeze (- 1 )
2484+ masked_loss = all_losses * (word_mask .unsqueeze (- 1 ) * prompts_embedding_mask .unsqueeze (1 )).unsqueeze (- 1 )
24812485
24822486 if reduction == "mean" :
2483- loss = all_losses .mean ()
2487+ num_valid = (word_mask .unsqueeze (- 1 ) * prompts_embedding_mask .unsqueeze (1 )).sum ()
2488+ loss = masked_loss .sum () / num_valid if num_valid > 0 else torch .tensor (0.0 , device = scores .device )
24842489 elif reduction == "sum" :
2485- loss = all_losses .sum ()
2490+ loss = masked_loss .sum ()
24862491 else :
24872492 warnings .warn (
24882493 f"Invalid Value for config 'loss_reduction': '{ reduction } ' \n Supported reduction modes:"
24892494 f" 'none', 'mean', 'sum'. It will be used 'sum' instead." ,
24902495 stacklevel = 2 ,
24912496 )
2492- loss = all_losses .sum ()
2497+ loss = masked_loss .sum ()
24932498 return loss
24942499
24952500 def represent_spans (
0 commit comments