Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/lighteval/metrics/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
import re
import string
import sys
Expand All @@ -31,6 +32,8 @@
from lighteval.utils.imports import Extra, requires
from lighteval.utils.language import Language

logger = logging.getLogger(__name__)


# From HELM
def helm_normalizer(text: str) -> str:
Expand Down Expand Up @@ -523,8 +526,14 @@ def normalize_log_probs(
normalized_log_probs = [choices_logprob[ix] / len(choice) for ix, choice in enumerate(choices_text)]
case LogProbTokenNorm():
assert choices_tokens is not None, "choices_tokens must be provided for token normalization"
n = min(len(choices_logprob), len(choices_tokens))
if n < len(choices_logprob):
logger.warning(
f"choices_tokens length ({len(choices_tokens)}) is less than choices_logprob length "
f"({len(choices_logprob)}). This may indicate corrupted cache data. Truncating to {n} elements."
)
normalized_log_probs = [
choices_logprob[ix] / len(choices_tokens[ix]) for ix in range(len(choices_logprob))
choices_logprob[ix] / len(choices_tokens[ix]) for ix in range(n)
]
case LogProbPMINorm():
assert unconditioned_logprob is not None, "unconditioned_logprob must be provided for PMI normalization"
Expand Down