Skip to content

Commit 8d465bc

Browse files
committed
ifbench updates
1 parent fe95eb5 commit 8d465bc

4 files changed

Lines changed: 81 additions & 99 deletions

File tree

eval_protocol/rewards/ifeval/README.md

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,11 @@ score = ifeval_partial_credit_reward(response, ground_truth)
2222
## Dependencies
2323

2424
```bash
25-
pip install spacy nltk langdetect emoji syllapy immutabledict
26-
python -m spacy download en_core_web_sm
25+
pip install nltk langdetect emoji syllapy immutabledict absl-py
2726
```
2827

28+
NLTK resources are downloaded automatically on first use.
29+
2930
## Notes
3031

3132
- Automatically strips `<think>...</think>` tags before evaluation
@@ -39,7 +40,7 @@ python -m spacy download en_core_web_sm
3940
- `ifeval_registry.py` (from `instructions_registry.py`)
4041
- `ifeval_util.py` (from `instructions_util.py`)
4142

42-
**Copied from `IFBench/`:**
43+
**Copied from `IFBench/` (commit 8e6a9be, 2025-01):**
4344
- `ifbench_instructions.py` (from `instructions.py`)
4445
- `ifbench_registry.py` (from `instructions_registry.py`)
4546
- `ifbench_util.py` (from `instructions_util.py`)

eval_protocol/rewards/ifeval/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,6 @@
11
"""IFEval reward function for evaluating instruction-following capabilities.
22
33
Usage:
4-
# Option 1: Import spacy first to avoid cupy conflicts in some Docker environments
5-
import spacy
6-
from eval_protocol.rewards.ifeval import ifeval_partial_credit_reward
7-
8-
# Option 2: Direct import (add ifeval dir to path)
94
import sys
105
sys.path.insert(0, '/path/to/eval_protocol/rewards/ifeval')
116
from reward import ifeval_partial_credit_reward

eval_protocol/rewards/ifeval/ifbench_instructions.py

Lines changed: 60 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -15,26 +15,28 @@
1515
"""Library of instructions."""
1616

1717
import logging
18+
import os
1819
import random
1920
import re
2021
import string
22+
from pathlib import Path
2123
from typing import Dict, Optional, Sequence, Union
24+
25+
# Set NLTK data path to local directory before importing nltk
26+
_nltk_data_dir = Path(__file__).parent / ".nltk_data"
27+
_nltk_data_dir.mkdir(exist_ok=True)
28+
os.environ.setdefault("NLTK_DATA", str(_nltk_data_dir))
29+
2230
import nltk
23-
import spacy
24-
from spacy.cli import download
31+
nltk.data.path.insert(0, str(_nltk_data_dir))
2532
import emoji
2633
import syllapy
2734
import unicodedata
2835
from collections import Counter
2936
import csv
3037
import io
3138

32-
try:
33-
from . import ifbench_util as instructions_util
34-
except ImportError:
35-
import ifbench_util as instructions_util
36-
37-
download('en_core_web_sm')
39+
import ifbench_util as instructions_util
3840

3941
logger = logging.getLogger(__name__)
4042

@@ -208,6 +210,8 @@ def get_instruction_args_keys(self):
208210
def check_following(self, value):
209211
"""Checks if the response contains the expected percentage of stop words."""
210212
num_words = instructions_util.count_words(value)
213+
if num_words == 0:
214+
return False
211215
num_stopwords = instructions_util.count_stopwords(value)
212216
stopword_percentage = (num_stopwords / num_words) * 100
213217
return stopword_percentage <= self._percentage
@@ -219,7 +223,7 @@ class SentTypeRatioChecker(Instruction):
219223
def build_description(self):
220224
"""Build the instruction description."""
221225
self._description_pattern = "Maintain a 2:1 ratio of declarative to interrogative sentences."
222-
nltk.download('punkt_tab')
226+
223227
return self._description_pattern
224228

225229
def get_instruction_args(self):
@@ -245,7 +249,7 @@ class SentBalanceChecker(Instruction):
245249

246250
def build_description(self):
247251
"""Build the instruction description."""
248-
nltk.download('punkt_tab')
252+
249253
self._description_pattern = "Ensure that the ratio of sentence types (declarative, interrogative, exclamatory) is balanced."
250254
return self._description_pattern
251255

@@ -310,7 +314,7 @@ def check_following(self, value):
310314

311315

312316
class PersonNameCountChecker(Instruction):
313-
"""Mention at least {N} different person names in the response."""
317+
"""Mention at least {N} different person names in the response, from this list of person names: Emma, Liam, Sophia..."""
314318

315319
def build_description(self, *, N=None):
316320
"""Build the instruction description.
@@ -326,8 +330,6 @@ def build_description(self, *, N=None):
326330
if self._num_person_names is None or self._num_person_names < 0:
327331
self._num_person_names = random.randint(1, 50)
328332

329-
self.nlp = spacy.load("en_core_web_sm")
330-
331333
self._description_pattern = "Mention at least {N} different person names in the response, from this list of person names: Emma, Liam, Sophia, Jackson, Olivia, Noah, Ava, Lucas, Isabella, Mason, Mia, Ethan, Charlotte, Alexander, Amelia, Benjamin, Harper, Leo, Zoe, Daniel, Chloe, Samuel, Lily, Matthew, Grace, Owen, Abigail, Gabriel, Ella, Jacob, Scarlett, Nathan, Victoria, Elijah, Layla, Nicholas, Audrey, David, Hannah, Christopher, Penelope, Thomas, Nora, Andrew, Aria, Joseph, Claire, Ryan, Stella, Jonathan ."
332334
return self._description_pattern.format(N=self._num_person_names)
333335

@@ -384,7 +386,9 @@ def check_following(self, value):
384386
# Extract the named entities
385387
person_names = []
386388
for name in person_name_list:
387-
if name in value:
389+
# Use regex with word boundaries
390+
pattern = r'\b{}\b'.format(re.escape(name))
391+
if re.search(pattern, value):
388392
person_names.append(name)
389393
unique_person_names = set(person_names)
390394

@@ -426,6 +430,8 @@ def check_following(self, value):
426430
n = 3
427431
ngrams = set(nltk.ngrams(value, n))
428432
ref_ngrams = set(nltk.ngrams(self._reference_text, n))
433+
if not ngrams:
434+
return False
429435
overlap = len(ngrams.intersection(ref_ngrams)) / len(ngrams)
430436
return self._percentage - 2 <= overlap * 100 <= self._percentage + 2
431437

@@ -486,6 +492,8 @@ def check_following(self, value):
486492
"""Checks if each word of the response starts with the next letter of the alphabet."""
487493
value = value.translate(str.maketrans('', '', string.punctuation))
488494
words = value.strip(''.join(string.punctuation) + ' ').split()
495+
if not words:
496+
return False
489497
alphabet = string.ascii_lowercase
490498
correct_letter = words[0][0].lower()
491499
if correct_letter not in alphabet: # numbers are fails
@@ -564,7 +572,7 @@ class IncrementingAlliterationChecker(Instruction):
564572

565573
def build_description(self):
566574
"""Build the instruction description."""
567-
nltk.download('punkt_tab')
575+
568576
self._description_pattern = "Each sentence must have a longer sequence of consecutive alliterative words than the previous one."
569577
return self._description_pattern
570578

@@ -851,7 +859,7 @@ class EmojiSentenceChecker(Instruction):
851859

852860
def build_description(self):
853861
"""Build the instruction description."""
854-
nltk.download('punkt_tab')
862+
855863
self._description_pattern = "Please use an emoji at the end of every sentence."
856864
return self._description_pattern
857865

@@ -869,6 +877,9 @@ def check_following(self, value):
869877
sentences = instructions_util.split_into_sentences(value)
870878
for i, sentence in enumerate(sentences):
871879
stripped = sentence.translate(str.maketrans('', '', string.punctuation)).strip()
880+
# check for empty string
881+
if not stripped:
882+
return False
872883
last_char = stripped[-1]
873884
# because blank spaces are treated oddly
874885
second_last_char = stripped[-2] if len(stripped) > 1 else stripped[-1]
@@ -891,7 +902,7 @@ class CharacterCountUniqueWordsChecker(Instruction):
891902

892903
def build_description(self):
893904
"""Build the instruction description."""
894-
nltk.download('punkt_tab')
905+
895906
self._description_pattern = "Respond with three sentences, all containing the same number of characters but using all different words."
896907
return self._description_pattern
897908

@@ -980,7 +991,7 @@ class StartWithVerbChecker(Instruction):
980991
def build_description(self):
981992
"""Build the instruction description."""
982993
self._description_pattern = "The response must start with a verb."
983-
nltk.download('averaged_perceptron_tagger_eng')
994+
984995
return self._description_pattern
985996

986997
def get_instruction_args(self):
@@ -1050,7 +1061,7 @@ def build_description(self, *, word=None, N=None):
10501061
Returns:
10511062
A string representing the instruction description.
10521063
"""
1053-
nltk.download('punkt_tab')
1064+
10541065

10551066
if not word:
10561067
self._keyword = instructions_util.generate_keywords(
@@ -1078,7 +1089,9 @@ def check_following(self, value):
10781089
sentences = instructions_util.split_into_sentences(value)
10791090
if len(sentences) < self._keyword_position:
10801091
return False
1081-
return self._keyword.lower() in sentences[int(self._keyword_position - 1)].lower()
1092+
# Use regex with word boundaries for robust matching
1093+
pattern = r'\b{}\b'.format(re.escape(self._keyword))
1094+
return bool(re.search(pattern, sentences[int(self._keyword_position - 1)], re.IGNORECASE))
10821095

10831096

10841097
class PronounCountChecker(Instruction):
@@ -1117,8 +1130,8 @@ def check_following(self, value):
11171130
'itself', 'they', 'them', 'their', 'theirs', 'themselves'])
11181131
value = value.replace('/',
11191132
' ') # to correctly count pronoun sets like she/her/hers, a common use case of pronouns
1120-
value = value.lower().translate(str.maketrans('', '', string.punctuation))
1121-
words = value.split()
1133+
# Use NLTK word_tokenize for better tokenization
1134+
words = nltk.word_tokenize(value.lower())
11221135
pronoun_count = sum(1 for word in words if word in pronouns)
11231136
return pronoun_count >= self._num_pronouns
11241137

@@ -1151,7 +1164,7 @@ class LastWordFirstNextChecker(Instruction):
11511164

11521165
def build_description(self):
11531166
"""Build the instruction description."""
1154-
nltk.download('punkt_tab')
1167+
11551168
self._description_pattern = "The last word of each sentence must become the first word of the next sentence."
11561169
return self._description_pattern
11571170

@@ -1167,9 +1180,11 @@ def check_following(self, value):
11671180
"""Checks if the last word of each sentence in the response is the first word of the next sentence."""
11681181
sentences = instructions_util.split_into_sentences(value)
11691182
for i in range(len(sentences) - 1):
1170-
last_word = sentences[i].rstrip(''.join(string.punctuation) + ' ').split()[-1]
1171-
first_word = sentences[i + 1].lstrip(''.join(string.punctuation) + ' ').split()[0]
1172-
if last_word.lower() != first_word.lower():
1183+
last_words = sentences[i].rstrip(''.join(string.punctuation) + ' ').split()
1184+
first_words = sentences[i + 1].lstrip(''.join(string.punctuation) + ' ').split()
1185+
if not last_words or not first_words:
1186+
return False
1187+
if last_words[-1].lower() != first_words[0].lower():
11731188
return False
11741189
return True
11751190

@@ -1222,7 +1237,7 @@ def build_description(self, *, small_n=None):
12221237
if self._num_increment is None or self._num_increment < 0:
12231238
self._num_increment = random.randint(1, _NUM_INCREMENT)
12241239

1225-
nltk.download('punkt_tab')
1240+
12261241

12271242
self._description_pattern = "Each sentence must contain exactly {small_n} more words than the previous one."
12281243
return self._description_pattern.format(small_n=self._num_increment)
@@ -1326,12 +1341,13 @@ def get_instruction_args_keys(self):
13261341
def check_following(self, value):
13271342
"""Checks if there are no quotes next to each other
13281343
and the passage does not end with a quote."""
1329-
value = value.replace('', '"').replace('', '"')
1344+
value = value.replace('“', '"').replace('”', '"')
13301345
value = value.replace("'\"'", '') # remove references to the character '"'
13311346
value = ''.join(value.split()) # remove all whitespace
13321347
if '""' in value:
13331348
return False
1334-
if value.strip(string.digits + string.punctuation.replace('"', ''))[-1] == '"':
1349+
stripped = value.strip(string.digits + string.punctuation.replace('"', ''))
1350+
if stripped and stripped[-1] == '"':
13351351
return False
13361352
return True
13371353

@@ -1605,7 +1621,7 @@ class WordReverseOrderChecker(Instruction):
16051621
"""What animal is the national symbol of the US? Respond to this query, but make your sentence in reverse order of what it should be, per word."""
16061622

16071623
def build_description(self, **kwargs):
1608-
nltk.download('punkt_tab')
1624+
16091625
self._description_pattern = "What animal is the national symbol of the US? Respond to this query, but make your sentence in reverse order of what it should be, per word."
16101626
return self._description_pattern
16111627

@@ -1650,7 +1666,7 @@ class SentenceAlphabetChecker(Instruction):
16501666
"""Tell me a 26-sentence story where each sentence's first word starts with the letters of the alphabet in order."""
16511667

16521668
def build_description(self, **kwargs):
1653-
nltk.download('punkt_tab')
1669+
16541670
self._description_pattern = "Tell me a 26-sentence story where each sentence's first word starts with the letters of the alphabet in order."
16551671
return self._description_pattern
16561672

@@ -1667,7 +1683,10 @@ def check_following(self, value):
16671683
if len(sentences) != 26:
16681684
return False
16691685
for i, sentence in enumerate(sentences):
1670-
if sentence.lstrip().split()[0].lower()[0] != chr(97 + i):
1686+
words = sentence.lstrip().split()
1687+
if not words or not words[0]:
1688+
return False
1689+
if words[0].lower()[0] != chr(97 + i):
16711690
return False
16721691
return True
16731692

@@ -1976,7 +1995,7 @@ def check_following(self, value):
19761995
words = instructions_util.nltk.word_tokenize(sentences[self._n - 1])
19771996
if len(words) < self._m:
19781997
return False
1979-
if words[self._m - 1] == self._keyword:
1998+
if words[self._m - 1].lower() == self._keyword.lower():
19801999
return True
19812000
else:
19822001
return False
@@ -2024,7 +2043,7 @@ def check_following(self, value):
20242043
words = instructions_util.nltk.word_tokenize(value)
20252044
if len(words) < 2:
20262045
return False
2027-
if words[1] == words[-2] == self._keyword:
2046+
if words[1].lower() == words[-2].lower() == self._keyword.lower():
20282047
return True
20292048
else:
20302049
return False
@@ -2109,7 +2128,7 @@ def check_following(self, value):
21092128

21102129

21112130
class RepeatSpanChecker(Instruction):
2112-
"Copy the span of words that lies between (and including) index {n_start} and {n_end}, the indices are word indices, split by whitespace!"
2131+
"Copy the span of words that lies between (and including) index {n_start} and {n_end}, the indices are character indices!"
21132132

21142133
def build_description(self, prompt_to_repeat=None, n_start=None, n_end=None):
21152134
"""Build the instruction description.
@@ -2183,6 +2202,12 @@ def check_following(self, value):
21832202
"""
21842203
words = instructions_util.nltk.word_tokenize(value)
21852204
for word in words:
2205+
if not word or not word[0].isalpha():
2206+
continue
2207+
if len(word) == 1:
2208+
if word[0].islower():
2209+
return False
2210+
continue
21862211
if word[0].isupper() and word[1:].islower():
21872212
continue
21882213
elif word[0].islower() and word[1:].isupper():

0 commit comments

Comments
 (0)