Skip to content

Commit 73433ee

Browse files
authored
Merge pull request #222 from howard0su/test_correctness
bench: add GSM8K/HumanEval correctness scoring
2 parents 97197a9 + 882ca14 commit 73433ee

5 files changed

Lines changed: 393 additions & 51 deletions

File tree

dflash/scripts/bench_llm.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,18 @@
3838
BUDGET = 22 # default; overridden by --budget CLI arg
3939
N_SAMPLE = 10
4040

41+
def _gsm_gold(x):
42+
"""Extract numeric answer after #### from GSM8K answer field."""
43+
ans = x["answer"]
44+
idx = ans.rfind("####")
45+
if idx >= 0:
46+
return ans[idx + 4:].strip().replace(",", "")
47+
return ans.strip()
48+
49+
4150
BENCHES = [
4251
("HumanEval", "openai_humaneval", None, "test", lambda x: x["prompt"], None, N_GEN),
43-
("GSM8K", "gsm8k", "main", "test", lambda x: f"Question: {x['question']}\nAnswer: ", None, N_GEN),
52+
("GSM8K", "gsm8k", "main", "test", lambda x: f"Question: {x['question']}\nAnswer: ", _gsm_gold, 1024),
4453
("Math500", "HuggingFaceH4/MATH-500", None, "test", lambda x: f"Problem: {x['problem']}\nSolution: Put your final answer in \\boxed{{}}.\n", lambda x: x["answer"], 2048),
4554
]
4655

@@ -182,6 +191,9 @@ def _normalize_math(s: str) -> str:
182191
s = s.strip()
183192
if s.startswith("$") and s.endswith("$"):
184193
s = s[1:-1].strip()
194+
# Strip currency $ (e.g. "$18" → "18")
195+
if re.match(r'^\$\d', s):
196+
s = s[1:]
185197
s = re.sub(r"\\text\s*\{([^}]*)\}", r"\1", s)
186198
s = re.sub(r"\\mathrm\s*\{([^}]*)\}", r"\1", s)
187199
for cmd in [r"\left", r"\right", r"\displaystyle", r"\tfrac", r"\dfrac"]:
@@ -240,23 +252,14 @@ def _math_equiv(pred: str, gold: str) -> bool:
240252

241253

242254
def score_math(output_bin: Path, gold_answer: str, tok) -> tuple[bool, str]:
243-
"""Score a Math500 output against the gold answer.
244-
245-
Extracts \\boxed{} answers from model output (after </think> for thinking
246-
models), compares against gold with normalized string matching + numeric/
247-
fraction equivalence. Returns (correct, detail_str).
248-
"""
255+
"""Score a Math500 output against the gold answer. Returns (correct, detail_str)."""
249256
ids = _read_ids(output_bin)
250257
text = tok.decode(ids)
251258

252259
think_end = text.rfind("</think>")
253260
answer_text = text[think_end + len("</think>"):] if think_end >= 0 else text
254261

255262
pred = _extract_boxed(answer_text)
256-
if not pred:
257-
pred = _extract_boxed(text)
258-
if not pred:
259-
pred = None
260263

261264
# Fallback: "the answer is **X**" patterns
262265
if pred is None:
@@ -285,6 +288,66 @@ def score_math(output_bin: Path, gold_answer: str, tok) -> tuple[bool, str]:
285288
return correct, detail
286289

287290

291+
def score_gsm(output_bin: Path, gold_answer: str, tok) -> tuple[bool, str]:
292+
"""Score a GSM8K output against the gold numeric answer. Returns (correct, detail_str)."""
293+
ids = _read_ids(output_bin)
294+
text = tok.decode(ids)
295+
296+
think_end = text.rfind("</think>")
297+
answer_text = text[think_end + len("</think>"):] if think_end >= 0 else text
298+
299+
pred = None
300+
301+
# \boxed{<number>}
302+
boxed = _extract_boxed(answer_text)
303+
if boxed:
304+
cleaned = boxed.replace(",", "").replace("$", "").strip()
305+
if re.match(r'^[+-]?\d+\.?\d*$', cleaned):
306+
pred = cleaned
307+
308+
# #### <number>
309+
if pred is None:
310+
m = re.search(r'####\s*\$?([+-]?\d[\d,]*\.?\d*)', answer_text)
311+
if m:
312+
pred = m.group(1).replace(",", "")
313+
314+
# "the answer is **X**"
315+
if pred is None:
316+
m = re.search(
317+
r'(?:answer\s+is|result\s+is|equals?|there\s+are|we\s+get)\s*\*?\*?\$?([+-]?\d[\d,]*\.?\d*)',
318+
answer_text, re.IGNORECASE)
319+
if m:
320+
pred = m.group(1).replace(",", "")
321+
322+
# **<number>** or **$<number>**
323+
if pred is None:
324+
m = re.search(r'\*\*\$?([+-]?\d[\d,]*\.?\d*)\*\*', answer_text)
325+
if m:
326+
pred = m.group(1).replace(",", "")
327+
328+
# Last standalone number
329+
if pred is None:
330+
nums = re.findall(r'(?<![.\d])([+-]?\d[\d,]*\.?\d*)(?![.\d])', answer_text)
331+
if nums:
332+
pred = nums[-1].replace(",", "")
333+
334+
correct = False
335+
if pred is not None:
336+
try:
337+
correct = abs(float(pred) - float(gold_answer)) < 1e-6
338+
except (ValueError, TypeError):
339+
correct = pred.strip() == gold_answer.strip()
340+
341+
if correct:
342+
detail = f"🎯 {pred}"
343+
elif pred:
344+
detail = f"✗ pred={pred} gold={gold_answer}"
345+
else:
346+
detail = f"✗ no answer found, gold={gold_answer}"
347+
return correct, detail
348+
349+
350+
288351
def main():
289352
global DRAFT, BUDGET
290353

@@ -357,7 +420,10 @@ def _wrap_prompt(raw_prompt: str) -> str:
357420

358421
score_detail = ""
359422
if gold is not None:
360-
correct, score_detail = score_math(df_bin, gold, tok)
423+
if name == "GSM8K":
424+
correct, score_detail = score_gsm(df_bin, gold, tok)
425+
else:
426+
correct, score_detail = score_math(df_bin, gold, tok)
361427
n_scored += 1
362428
if correct:
363429
n_score_correct += 1

harness/benchmarks/generation_benchmark.py

Lines changed: 174 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,150 @@ def approx_token_count(text: str) -> int:
5252
return max(1, len(re.findall(r"\S+", text)))
5353

5454

55+
def _extract_boxed(text: str) -> str | None:
56+
"""Extract the last \\boxed{...} from a string, handling nested braces."""
57+
results = []
58+
i = 0
59+
while i < len(text):
60+
idx = text.find("\\boxed{", i)
61+
if idx == -1:
62+
break
63+
start = idx + len("\\boxed{")
64+
depth = 1
65+
j = start
66+
while j < len(text) and depth > 0:
67+
if text[j] == "{":
68+
depth += 1
69+
elif text[j] == "}":
70+
depth -= 1
71+
j += 1
72+
if depth == 0:
73+
results.append(text[start:j-1].strip())
74+
i = j
75+
return results[-1] if results else None
76+
77+
78+
def _normalize_math(s: str) -> str:
79+
"""Normalize a math answer string for comparison."""
80+
if s is None:
81+
return ""
82+
s = s.strip()
83+
if s.startswith("$") and s.endswith("$"):
84+
s = s[1:-1].strip()
85+
s = re.sub(r"\\text\s*\{([^}]*)\}", r"\1", s)
86+
s = re.sub(r"\\mathrm\s*\{([^}]*)\}", r"\1", s)
87+
for cmd in [r"\left", r"\right", r"\displaystyle", r"\tfrac", r"\dfrac"]:
88+
s = s.replace(cmd, "")
89+
s = re.sub(r"\s+", " ", s).strip()
90+
s = s.rstrip(".,")
91+
return s
92+
93+
94+
def _math_equiv(pred: str, gold: str) -> bool:
95+
"""Check if two math answers are equivalent."""
96+
if pred is None or gold is None:
97+
return False
98+
p = _normalize_math(pred)
99+
g = _normalize_math(gold)
100+
if p == g:
101+
return True
102+
try:
103+
pf = float(p.replace(",", ""))
104+
gf = float(g.replace(",", ""))
105+
return abs(pf - gf) < 1e-6
106+
except (ValueError, TypeError):
107+
pass
108+
frac_pat = re.compile(r"\\?frac\s*\{([^}]+)\}\s*\{([^}]+)\}")
109+
for s, other in [(p, g), (g, p)]:
110+
m = frac_pat.search(s)
111+
if m:
112+
try:
113+
val = float(m.group(1)) / float(m.group(2))
114+
oval = float(other.replace(",", ""))
115+
if abs(val - oval) < 1e-6:
116+
return True
117+
except (ValueError, ZeroDivisionError):
118+
pass
119+
return False
120+
121+
122+
def _extract_numeric_answer(text: str) -> str | None:
123+
"""Extract a numeric answer from model output for GSM-style problems."""
124+
think_end = text.rfind("</think>")
125+
answer_text = text[think_end + len("</think>"):] if think_end >= 0 else text
126+
127+
# #### <number>
128+
m = re.search(r'####\s*([+-]?\d[\d,]*\.?\d*)', answer_text)
129+
if m:
130+
return m.group(1).replace(",", "")
131+
132+
# \boxed{<number>}
133+
boxed = _extract_boxed(answer_text)
134+
if boxed:
135+
cleaned = boxed.replace(",", "").strip()
136+
if re.match(r'^[+-]?\d+\.?\d*$', cleaned):
137+
return cleaned
138+
139+
# "the answer is <number>"
140+
m = re.search(
141+
r'(?:answer\s+is|result\s+is|equals?|there\s+are|we\s+get)\s*\$?\s*\\?(?:boxed\{)?([+-]?\d[\d,]*\.?\d*)',
142+
answer_text, re.IGNORECASE)
143+
if m:
144+
return m.group(1).replace(",", "")
145+
146+
# **<number>**
147+
m = re.search(r'\*\*([+-]?\d[\d,]*\.?\d*)\*\*', answer_text)
148+
if m:
149+
return m.group(1).replace(",", "")
150+
151+
# Last standalone number
152+
nums = re.findall(r'(?<![.\d])([+-]?\d[\d,]*\.?\d*)(?![.\d])', answer_text)
153+
if nums:
154+
return nums[-1].replace(",", "")
155+
156+
return None
157+
158+
159+
def score_gold_answer(case: dict[str, Any], text: str) -> tuple[bool | None, str]:
160+
"""Score model output against gold_answer if present.
161+
162+
Returns (correct_or_None, detail_str). None means no gold_answer to check.
163+
"""
164+
gold = case.get("gold_answer")
165+
if gold is None:
166+
return None, ""
167+
168+
suite = case.get("suite", "")
169+
think_end = text.rfind("</think>")
170+
answer_text = text[think_end + len("</think>"):] if think_end >= 0 else text
171+
172+
if suite == "gsm":
173+
pred = _extract_numeric_answer(text)
174+
if pred is None:
175+
return False, f"no numeric answer found, gold={gold}"
176+
try:
177+
correct = abs(float(pred) - float(gold)) < 1e-6
178+
except (ValueError, TypeError):
179+
correct = pred.strip() == gold.strip()
180+
return correct, f"pred={pred} gold={gold}"
181+
else:
182+
# Math-style: extract \boxed{} and compare
183+
pred = _extract_boxed(answer_text)
184+
if not pred:
185+
pred = _extract_boxed(text)
186+
if not pred:
187+
# Fallback: bold pattern
188+
m = re.search(
189+
r'(?:answer\s+is|result\s+is|equals?)\s*\*\*(.+?)\*\*',
190+
answer_text, re.IGNORECASE)
191+
if m:
192+
pred = m.group(1).strip().rstrip(".")
193+
if not pred:
194+
return False, f"no answer found, gold={gold}"
195+
correct = _math_equiv(pred, gold)
196+
return correct, f"pred={pred} gold={gold}"
197+
198+
55199
def expected_pass(case: dict[str, Any], text: str) -> tuple[bool, list[str]]:
56200
failures: list[str] = []
57201
for needle in case.get("expect_contains", []):
@@ -142,6 +286,7 @@ def run_case(
142286
token_source = "approx_words"
143287
prompt_tokens = usage.get("prompt_tokens")
144288
pass_expected, failures = expected_pass(case, text)
289+
gold_correct, gold_detail = score_gold_answer(case, text)
145290
runs.append(
146291
{
147292
"elapsed_s": elapsed,
@@ -151,23 +296,29 @@ def run_case(
151296
"token_count_source": token_source,
152297
"expected_pass": pass_expected,
153298
"expected_failures": failures,
299+
"gold_correct": gold_correct,
300+
"gold_detail": gold_detail,
154301
"text": text,
155302
"usage": usage,
156303
}
157304
)
158305

159306
tok_s_values = [r["tok_s"] for r in runs]
160307
elapsed_values = [r["elapsed_s"] for r in runs]
308+
gold_results = [r["gold_correct"] for r in runs if r["gold_correct"] is not None]
161309
return {
162310
"id": case["id"],
163311
"description": case.get("description", ""),
164312
"expect_contains": case.get("expect_contains", []),
165313
"expect_regex": case.get("expect_regex", []),
314+
"gold_answer": case.get("gold_answer"),
166315
"runs": runs,
167316
"mean_tok_s": statistics.mean(tok_s_values),
168317
"median_tok_s": statistics.median(tok_s_values),
169318
"mean_elapsed_s": statistics.mean(elapsed_values),
170319
"expected_pass": all(r["expected_pass"] for r in runs),
320+
"gold_correct": all(gold_results) if gold_results else None,
321+
"gold_detail": runs[-1].get("gold_detail", ""),
171322
"text": runs[-1]["text"],
172323
"completion_tokens": runs[-1]["completion_tokens"],
173324
"prompt_tokens": runs[-1]["prompt_tokens"],
@@ -179,20 +330,25 @@ def cmd_run(args: argparse.Namespace) -> int:
179330
cases = load_cases(Path(args.prompts))
180331
results = []
181332
for case in cases:
182-
print(f"[bench] {args.name}: {case['id']}", flush=True)
183-
results.append(
184-
run_case(
185-
case=case,
186-
base_url=args.url,
187-
api_key=args.api_key,
188-
model=args.model,
189-
max_tokens=args.max_tokens,
190-
temperature=args.temperature,
191-
timeout=args.timeout,
192-
repeats=args.repeats,
193-
)
333+
print(f"[bench] {args.name}: {case['id']}", end="", flush=True)
334+
result = run_case(
335+
case=case,
336+
base_url=args.url,
337+
api_key=args.api_key,
338+
model=args.model,
339+
max_tokens=args.max_tokens,
340+
temperature=args.temperature,
341+
timeout=args.timeout,
342+
repeats=args.repeats,
194343
)
195-
344+
results.append(result)
345+
if result["gold_correct"] is not None:
346+
mark = "🎯" if result["gold_correct"] else "✗"
347+
print(f" {mark} {result['gold_detail']}", flush=True)
348+
else:
349+
print(flush=True)
350+
351+
scored = [r for r in results if r["gold_correct"] is not None]
196352
report = {
197353
"name": args.name,
198354
"url": args.url,
@@ -206,13 +362,18 @@ def cmd_run(args: argparse.Namespace) -> int:
206362
"summary": {
207363
"cases": len(results),
208364
"expected_pass": sum(1 for r in results if r["expected_pass"]),
365+
"gold_correct": sum(1 for r in scored if r["gold_correct"]),
366+
"gold_scored": len(scored),
209367
"mean_tok_s": statistics.mean([r["mean_tok_s"] for r in results]) if results else 0.0,
210368
},
211369
}
212370
out = Path(args.json_out)
213371
out.parent.mkdir(parents=True, exist_ok=True)
214372
out.write_text(json.dumps(report, indent=2, sort_keys=True), encoding="utf-8")
215373
print(f"[bench] wrote {out}")
374+
if scored:
375+
print(f"[bench] correctness: {report['summary']['gold_correct']}/{len(scored)}"
376+
f" ({report['summary']['gold_correct']/len(scored)*100:.0f}%)")
216377
return 0 if report["summary"]["expected_pass"] == len(results) else 1
217378

218379

0 commit comments

Comments
 (0)