Skip to content

Commit 9bc9978

Browse files
authored
Merge pull request #337 from urchade/fix/mean_loss_reduction
Fix normalization factor for mean loss reduction
2 parents 92850f4 + dab7d00 commit 9bc9978

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

gliner/modeling/base.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,8 @@ def loss(
477477
all_losses = all_losses * span_mask.float()
478478

479479
if reduction == "mean":
480-
loss = all_losses.mean()
480+
num_valid = span_mask.float().sum()
481+
loss = all_losses.sum() / num_valid if num_valid > 0 else torch.tensor(0.0, device=scores.device)
481482
elif reduction == "sum":
482483
loss = all_losses.sum()
483484
else:
@@ -679,7 +680,8 @@ def loss(
679680
all_losses = all_losses * mask
680681

681682
if reduction == "mean":
682-
loss = all_losses.mean()
683+
num_valid = mask.float().sum()
684+
loss = all_losses.sum() / num_valid if num_valid > 0 else torch.tensor(0.0, device=scores.device)
683685
elif reduction == "sum":
684686
loss = all_losses.sum()
685687
else:
@@ -976,7 +978,8 @@ def loss(
976978
all_losses = all_losses * mask_label.float()
977979

978980
if reduction == "mean":
979-
loss = all_losses.mean()
981+
num_valid = mask_label.float().sum()
982+
loss = all_losses.sum() / num_valid if num_valid > 0 else torch.tensor(0.0, device=scores.device)
980983
elif reduction == "sum":
981984
loss = all_losses.sum()
982985
else:
@@ -1588,7 +1591,8 @@ def loss(
15881591
all_losses = all_losses * mask_label.float()
15891592

15901593
if reduction == "mean":
1591-
loss = all_losses.mean()
1594+
num_valid = mask_label.float().sum()
1595+
loss = all_losses.sum() / num_valid if num_valid > 0 else torch.tensor(0.0, device=scores.device)
15921596
elif reduction == "sum":
15931597
loss = all_losses.sum()
15941598
else:
@@ -2477,19 +2481,20 @@ def loss(
24772481
"""
24782482
all_losses = self._loss(scores, labels, alpha, gamma, prob_margin, label_smoothing, negatives)
24792483

2480-
all_losses = all_losses * (word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1)
2484+
masked_loss = all_losses * (word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).unsqueeze(-1)
24812485

24822486
if reduction == "mean":
2483-
loss = all_losses.mean()
2487+
num_valid = (word_mask.unsqueeze(-1) * prompts_embedding_mask.unsqueeze(1)).sum()
2488+
loss = masked_loss.sum() / num_valid if num_valid > 0 else torch.tensor(0.0, device=scores.device)
24842489
elif reduction == "sum":
2485-
loss = all_losses.sum()
2490+
loss = masked_loss.sum()
24862491
else:
24872492
warnings.warn(
24882493
f"Invalid Value for config 'loss_reduction': '{reduction}' \n Supported reduction modes:"
24892494
f" 'none', 'mean', 'sum'. It will be used 'sum' instead.",
24902495
stacklevel=2,
24912496
)
2492-
loss = all_losses.sum()
2497+
loss = masked_loss.sum()
24932498
return loss
24942499

24952500
def represent_spans(

0 commit comments

Comments
 (0)