Skip to content

Commit 5240b3b

Browse files
authored
Add post-processing functions to extract math500 answers (#225)
* 1> Add mmlu dataset; 2> fix minor bug for mmlu prompt; 3> add scikit-learn to requirements.in
* update the full dataset to a sampled dataset (1200 examples from 2 subjects)
* add overall_throughput to metrics; add flag to filter data based on min_input_length
* fix pylint issues
* add postprocessing functions for math500 answers
* add sympy to requirements.txt; move some functions to math_utils.py
* fix lint warnings
* fix pylint warnings
1 parent f5c19bc commit 5240b3b

4 files changed

Lines changed: 372 additions & 27 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,4 +51,4 @@ unit-tests:
5151
coverage run -m unittest -v
5252

5353
check-test-coverage:
54-
coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py,benchmarks/eval_accuracy_mmlu.py" --fail-under=96
54+
coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py,benchmarks/eval_accuracy_mmlu.py,benchmarks/math_utils.py" --fail-under=96

benchmarks/eval_accuracy.py

Lines changed: 70 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -18,35 +18,75 @@
1818
import nltk
1919
import evaluate
2020
import json
21+
import re
2122

2223
import numpy as np
24+
from benchmarks.math_utils import extract_numbers, post_processing_math_ans, sympify_set
2325

2426

2527
def extract_boxed_answers(text):
    """Collect the contents of every ``boxed{...}`` group found in ``text``.

    The text is split on the literal ``boxed{`` and each following piece is
    scanned for the brace that closes the group, keeping track of nested
    braces so that content such as ``\\frac{1}{2}`` is kept whole.

    Args:
        text: The string to search.

    Returns:
        A list with one entry per group whose closing brace was found, in
        order of appearance. Pieces that never close contribute nothing.
        If no group could be extracted at all, ``[""]`` is returned, so the
        result is never an empty list.
    """
    pieces = text.split("boxed{")
    if len(pieces) == 1:
        return [""]

    answers = []
    # pieces[0] is whatever precedes the first "boxed{", so skip it.
    for piece in pieces[1:]:
        depth = 0
        for i, char in enumerate(piece):
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth < 0:
                    # This brace closes the boxed group itself.
                    if i + 1 < len(piece) and piece[i + 1] == "%":
                        # A "%" directly follows the closing brace: the slice
                        # is extended by one character, so it keeps the brace.
                        answers.append(piece[: i + 1])
                    else:
                        answers.append(piece[:i])
                    break
    return answers if answers else [""]
51+
52+
53+
def extract_answer(pred_str, exhaust=False):
    """Pull the final answer(s) out of a model response string.

    The first matching strategy wins, tried in this order:

    1. ``boxed{`` present: every boxed group via ``extract_boxed_answers``.
    2. ``Answer:`` present: the last regex match of the text following
       ``Answer:`` is passed through ``extract_numbers``. If the regex does
       not match, no candidate is produced (there is no fall-through).
    3. ``the answer is`` present: the text after its last occurrence is
       passed through ``extract_numbers``.
    4. Both ``final answer is $`` and ``$. I hope`` present: the text
       between the first occurrence of each.
    5. Otherwise: the last number in the string, with commas removed first.

    Each candidate is then cleaned: the ``<|end_of_text|>`` marker is
    removed, only the first line is kept, and leading ``:`` / ``$`` and
    trailing ``$`` / ``.`` / ``/`` characters are stripped, in that order.

    Args:
        pred_str: The raw response text.
        exhaust: If true, return all cleaned candidates as a list;
            otherwise return only the last one.

    Returns:
        With ``exhaust`` true, a (possibly empty) list of strings.
        Otherwise the last cleaned candidate, or ``""`` if there is none.
    """
    pred = []
    if "boxed{" in pred_str:
        pred = extract_boxed_answers(pred_str)
    elif "Answer:" in pred_str:
        matches = re.findall(r"Answer:[\*]*\s+(\S*.*)", pred_str)
        if matches:
            # NOTE(review): extract_numbers lives in benchmarks.math_utils;
            # the cleanup below assumes it returns a string -- confirm.
            pred = [extract_numbers(matches[-1])]
    elif "the answer is" in pred_str:
        pred = [extract_numbers(pred_str.split("the answer is")[-1].strip())]
    elif "final answer is $" in pred_str and "$. I hope" in pred_str:
        tail = pred_str.split("final answer is $", 1)[1]
        pred = [tail.split("$. I hope", 1)[0].strip()]
    else:
        # Fall back to the last number; commas are dropped so that "1,234"
        # is read as a single number.
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        if numbers:
            pred.append(numbers[-1])

    pred_list = []
    for candidate in pred:
        cleaned = candidate.replace("<|end_of_text|>", "")
        # Keep only the first line of a multi-line candidate.
        cleaned = cleaned.strip().split("\n")[0]
        # The strip order matters: each call only sees what the previous
        # one left at that end of the string.
        cleaned = cleaned.lstrip(":").lstrip("$")
        cleaned = cleaned.rstrip("$").rstrip(".").rstrip("/")
        pred_list.append(cleaned)

    if exhaust:
        return pred_list
    return pred_list[-1] if pred_list else ""
5090

5191

5292
def postprocess_text(preds, targets):
@@ -71,11 +111,15 @@ def eval_accuracy(request_outputs_dict, match_type):
71111
correct_ans = 0
72112
wrong_ans = 0
73113
for p, t in zip(preds, targets):
74-
ans = extract_boxed_answers(p)
75-
ans = replace_space_answers(ans)
76-
ans = special_handling(ans)
77-
tt = replace_space_answers(t)
78-
if tt == ans:
114+
115+
p = extract_answer(p)
116+
ans_set = post_processing_math_ans(p)
117+
sympified_ans_set = sympify_set(ans_set)
118+
119+
target_set = post_processing_math_ans(t)
120+
sympified_target_set = sympify_set(target_set)
121+
122+
if sympified_target_set == sympified_ans_set:
79123
correct_ans += 1
80124
continue
81125
wrong_ans += 1

0 commit comments

Comments
 (0)