Skip to content

Commit b50c21d

Browse files
committed
extract incorrect task-answer pairs
1 parent ebb259e commit b50c21d

2 files changed

Lines changed: 142 additions & 0 deletions

File tree

scripts/error_cls.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from os.path import abspath, dirname, basename, join as pjoin
2+
import os
3+
4+
import orjson
5+
from pydantic import BaseModel
6+
from openai import OpenAI
7+
import fire
8+
from tfbench import (
9+
analysis_multi_runs,
10+
load_tfb_from_hf,
11+
load_gen_results_jsonl,
12+
evaluate,
13+
LMAnswer,
14+
)
15+
from tqdm import tqdm
16+
17+
from tfbench.evaluation import get_incorrect
18+
from tfbench.common import get_prompt as get_task_prompt, BenchmarkTask
19+
20+
21+
PROMPT_TEMPLATE = """
22+
The Haskell type inference task is as follows:
23+
{task}
24+
25+
The ground-truth correct answer is:
26+
{correct_answer}
27+
28+
My incorrect answer is:
29+
{wrong_answer}
30+
31+
My reasoning behind my answer is:
32+
{reasoning}
33+
34+
What mistake did I make?
35+
"""
36+
37+
INSTRUCTION = """
38+
You are a programming language and logic expert.
39+
You will be shown a Haskell type inference task,
40+
an incorrect answer, and the reasoning behind it.
41+
Your job is to identify the mistake in the answer,
42+
and classify the mistake in one word.
43+
The current error categories are:
44+
{categories}.
45+
Choose one category, or construct a new one if you are sure that
46+
none of the current categories fit.
47+
Only output the one-word classification and a short definition of the class.
48+
NOTE that the short definition should be generalized to other tasks that fall in the same category.
49+
"""
50+
51+
52+
class ClsResponse(BaseModel):
    """A one-word error category plus a short, generalized definition.

    Instances are collected in a ``set`` to deduplicate categories across
    runs, so equality and hashing are both keyed on ``category`` alone: two
    responses naming the same category are the same category even when the
    model phrased the definition slightly differently.
    """

    # One-word name of the error class, e.g. "precedence".
    category: str
    # Short, task-independent description of the error class.
    definition: str

    def __eq__(self, other: object) -> bool:
        # BUGFIX: pydantic's default __eq__ compares all fields, while
        # __hash__ below uses only `category`. With field-wise equality, the
        # same category with a reworded definition survived set
        # deduplication, so the category set kept growing with near
        # duplicates. Compare by category only, consistent with __hash__.
        if isinstance(other, ClsResponse):
            return self.category == other.category
        return NotImplemented

    def __hash__(self):
        return hash(self.category)
58+
59+
60+
def get_prompt(task: BenchmarkTask, answer: LMAnswer) -> str:
    """Render the classification prompt for one incorrectly answered task.

    Fills PROMPT_TEMPLATE with the rendered task, the ground-truth type
    signature, the model's wrong answer, and its reasoning trace.
    """
    return PROMPT_TEMPLATE.format(
        task=get_task_prompt(task),
        correct_answer=task.signature,
        wrong_answer=answer.answer,
        reasoning=answer.reasoning_steps,
    )
68+
69+
70+
def categories_str(categories: set[ClsResponse]) -> str:
    """Serialize all known categories, one JSON object per line (jsonl)."""
    lines = [orjson.dumps(cat.__dict__).decode() for cat in categories]
    return "\n".join(lines)
73+
74+
75+
def classify_run(
    client: OpenAI,
    categories: set[ClsResponse],
    tasks: list[BenchmarkTask],
    run_result: list[LMAnswer | None],
) -> set[ClsResponse]:
    """Classify every incorrect answer of one run into an error category.

    Args:
        client: OpenAI client used for the classification calls.
        categories: categories discovered so far; not mutated.
        tasks: the benchmark tasks of this run.
        run_result: per-task answers, aligned with ``tasks``; ``None`` marks
            a task the model produced no answer for.

    Returns:
        A copy of ``categories`` extended with any categories produced while
        classifying this run.
    """
    incorrect = get_incorrect(tasks, run_result)

    categories_: set[ClsResponse] = categories.copy()
    for task, answer in tqdm(incorrect):
        if answer is None:
            # BUGFIX: get_incorrect also yields tasks with no answer at all
            # (result is None); the previous `assert answer is not None`
            # aborted the whole run on the first such task. There is nothing
            # to classify without an answer, so skip it.
            continue
        response = client.responses.parse(
            model="gpt-5",
            # Re-render the instruction every iteration so categories minted
            # earlier in this run are offered to the model for later answers.
            instructions=INSTRUCTION.format(categories=categories_str(categories_)),
            input=get_prompt(task, answer),
            reasoning={"effort": "medium"},
            text_format=ClsResponse,
        )
        assert response.output_parsed is not None
        categories_.add(response.output_parsed)
    return categories_
96+
97+
98+
def main(result_file_dir: str):
    """Classify the errors of every run result (*.jsonl) in a directory.

    The directory's base name is used as the benchmark split to load; the
    accumulated category set is written to error_categories.json.
    """
    client = OpenAI()
    categories: set[ClsResponse] = set()

    # The result directory is named after the benchmark split it was run on.
    split = basename(abspath(result_file_dir))
    print(split)
    base = load_tfb_from_hf(split)

    for entry in os.listdir(result_file_dir):
        if not entry.endswith(".jsonl"):
            continue
        path = pjoin(result_file_dir, entry)
        results = load_gen_results_jsonl(path)
        print(f"Processing {path}")
        categories.update(classify_run(client, categories, base, results))

    payload = orjson.dumps(
        [c.model_dump(mode="json") for c in categories],
        option=orjson.OPT_INDENT_2,
    )
    with open("error_categories.json", "wb") as f:
        f.write(payload)
128+
129+
130+
if __name__ == "__main__":
    # Expose main() as a CLI: `python error_cls.py RESULT_DIR`.
    fire.Fire(main)

src/tfbench/evaluation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,14 @@ def analysis_multi_runs(results: list[EvalResult]) -> tuple[float, float]:
9898
"""calculate mean and std of accuracy of multiple runs"""
9999
accs = list(map(lambda r: r["accuracy"], results))
100100
return np.mean(accs).item(), np.std(accs).item()
101+
102+
103+
def get_incorrect(
    tasks: list[BenchmarkTask], results: list[LMAnswer | None]
) -> list[tuple[BenchmarkTask, LMAnswer | None]]:
    """Get a list of tasks that were incorrectly answered.

    Pairs each task with its (possibly ``None``) answer and keeps those
    that evaluate_one_task rejects.
    """
    return [
        (task, result)
        for task, result in zip(tasks, results)
        if not evaluate_one_task(task, result)
    ]

0 commit comments

Comments
 (0)