Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions metrics/bertscore/bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
""" BERTScore metric. """

import functools
import sys
from contextlib import contextmanager

import bert_score
Expand Down Expand Up @@ -79,6 +80,11 @@ def filter_log(record):
rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline.
baseline_path (str): Customized baseline file.
use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer. New in version 0.3.10.
max_length (int): Maximum sequence length for tokenization. Useful when the model does not
define ``model_max_length`` in its tokenizer config (e.g. DeBERTa variants), which causes
transformers to fall back to a huge sentinel value that overflows the Rust tokenizers backend.
When not set, the metric automatically caps any sentinel value larger than ``sys.maxsize``
to 512 to prevent the ``OverflowError``.

Returns:
precision: Precision.
Expand Down Expand Up @@ -142,6 +148,7 @@ def _compute(
rescale_with_baseline=False,
baseline_path=None,
use_fast_tokenizer=False,
max_length=None,
):

if isinstance(references[0], str):
Expand Down Expand Up @@ -200,6 +207,16 @@ def _compute(
baseline_path=baseline_path,
)

# Some models (e.g. DeBERTa) omit model_max_length from their tokenizer config.
# Transformers then sets it to a huge sentinel (~1e30) that overflows the Rust
# tokenizers backend when passed to enable_truncation(). Cap it here so that
# bert_score's internal encode calls stay within a safe integer range.
_tokenizer = self.cached_bertscorer._tokenizer
if max_length is not None:
_tokenizer.model_max_length = max_length
elif _tokenizer.model_max_length > sys.maxsize:
_tokenizer.model_max_length = 512

(P, R, F) = self.cached_bertscorer.score(
cands=predictions,
refs=references,
Expand Down
76 changes: 76 additions & 0 deletions tests/test_metric_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,82 @@ def predict(self, data, *args, **kwargs):
yield


def test_bertscore_large_model_max_length_does_not_overflow():
    """Regression test for https://github.com/huggingface/evaluate/issues/739.

    When a tokenizer config omits model_max_length, transformers substitutes a huge
    sentinel (~1e30), and that sentinel overflows the Rust tokenizers backend once
    bert_score hands it to enable_truncation(). Here BERTScore is given a fake scorer
    whose tokenizer carries such a sentinel; after ``compute()`` the tokenizer's
    ``model_max_length`` must be no larger than ``sys.maxsize``.
    """
    import sys

    import torch

    oversized_max_length = int(1e30)

    def fake_cos_score_idf(model, refs, *args, **kwargs):
        # One row of perfect scores per reference.
        return torch.tensor([[1.0, 1.0, 1.0]] * len(refs))

    class FakeTokenizer:
        # Stands in for a tokenizer whose config did not define model_max_length.
        model_max_length = oversized_max_length

    class FakeScorer:
        # Must equal the patched get_hash() return value used below.
        hash = "fakehash"
        _tokenizer = FakeTokenizer()

        def score(self, cands, refs, **kwargs):
            # (P, R, F) triple of perfect single-sample scores.
            return tuple(torch.tensor([1.0]) for _ in range(3))

    get_model_patch = patch("bert_score.scorer.get_model")
    cos_score_patch = patch("bert_score.scorer.bert_cos_score_idf")
    hash_patch = patch("bert_score.utils.get_hash", return_value="fakehash")
    scorer_patch = patch("bert_score.BERTScorer", return_value=FakeScorer())

    with get_model_patch, cos_score_patch as mock_score, hash_patch, scorer_patch:
        mock_score.side_effect = fake_cos_score_idf
        metric = load(os.path.join("metrics", "bertscore"))
        result = metric.compute(
            predictions=["hello there"],
            references=["hello there"],
            lang="en",
        )

    # model_max_length must have been capped to a safe value
    capped_max_length = FakeScorer._tokenizer.model_max_length
    assert capped_max_length <= sys.maxsize
    assert result["f1"] == [1.0]


def test_bertscore_explicit_max_length_is_honoured():
    """An explicit ``max_length`` argument ends up on the scorer's tokenizer.

    The fake tokenizer starts with ``model_max_length = 512``; after calling
    ``compute(..., max_length=256)`` it must read 256.
    """
    import torch

    def fake_cos_score_idf(model, refs, *args, **kwargs):
        # One row of perfect scores per reference.
        return torch.tensor([[1.0, 1.0, 1.0]] * len(refs))

    class FakeTokenizer:
        # Starting value, distinct from the max_length requested below.
        model_max_length = 512

    class FakeScorer:
        # Must equal the patched get_hash() return value used below.
        hash = "fakehash"
        _tokenizer = FakeTokenizer()

        def score(self, cands, refs, **kwargs):
            # (P, R, F) triple of perfect single-sample scores.
            return tuple(torch.tensor([1.0]) for _ in range(3))

    get_model_patch = patch("bert_score.scorer.get_model")
    cos_score_patch = patch("bert_score.scorer.bert_cos_score_idf")
    hash_patch = patch("bert_score.utils.get_hash", return_value="fakehash")
    scorer_patch = patch("bert_score.BERTScorer", return_value=FakeScorer())

    with get_model_patch, cos_score_patch as mock_score, hash_patch, scorer_patch:
        mock_score.side_effect = fake_cos_score_idf
        metric = load(os.path.join("metrics", "bertscore"))
        metric.compute(
            predictions=["hello there"],
            references=["hello there"],
            lang="en",
            max_length=256,
        )

    applied_max_length = FakeScorer._tokenizer.model_max_length
    assert applied_max_length == 256


def test_seqeval_raises_when_incorrect_scheme():
metric = load(os.path.join("metrics", "seqeval"))
wrong_scheme = "ERROR"
Expand Down