Skip to content

Commit cd95675

Browse files
committed
use ollama for error analysis
1 parent 34897e5 commit cd95675

3 files changed

Lines changed: 21 additions & 15 deletions

File tree

scripts/error_analysis.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
def analysis(result_file_dir: str, split: Literal["base", "pure"], output_file: str):
1919
"""script to run error analysis fo incorrect TF-Bench tasks"""
20-
client = OpenAI()
2120
tasks = load_tfb_from_hf(split)
2221
model = basename(abspath(result_file_dir))
2322

@@ -36,7 +35,7 @@ def analysis(result_file_dir: str, split: Literal["base", "pure"], output_file:
3635

3736
print(f"Running error classification on {len(incorrect)} incorrect results")
3837
for task, answer, msg in tqdm(incorrect):
39-
error = error_analysis(client, task, answer, error_msg=msg)
38+
error = error_analysis(task, answer, error_msg=msg)
4039
log_obj: ErrorAnalysisResult = {
4140
"model": model,
4241
"split": split,

src/tfbench/error_analysis.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import TypedDict, Literal
22

33
from pydantic import BaseModel
4-
from openai import OpenAI
5-
4+
from ollama import chat
65
from .common import get_prompt as get_task_prompt, BenchmarkTask
76
from .lm import LMAnswer
87

@@ -57,7 +56,6 @@
5756
The prompt asked to only output the type signature,
5857
but the answer contains additional text or explanation.
5958
Choose one category from the above.
60-
Only output the one-word classification and a short explanation of the why this category fits.
6159
"""
6260

6361
ErrorCategories = Literal[
@@ -92,26 +90,34 @@ def get_error_analysis_prompt(
9290

9391

9492
def error_analysis(
95-
client: OpenAI,
9693
task: BenchmarkTask,
9794
answer: LMAnswer | None,
9895
error_msg: str,
96+
model: str = "qwen3:235b",
9997
) -> ErrorAnalysisResponse:
100-
"""classify errors for all incorrect answers in the run_result"""
98+
"""classify errors for all incorrect answers in the run_result
99+
NOTE: this function uses the OpenAI-compatible API of vLLM.
100+
Which model to use is determined by how you serve the model.
101+
"""
101102
if answer is None:
102103
return ErrorAnalysisResponse(
103104
category="ResponseError", explanation="No answer provided."
104105
)
105106

106-
response = client.responses.parse(
107-
model="gpt-5",
108-
instructions=INSTRUCTION,
109-
input=get_error_analysis_prompt(task, answer, error_msg=error_msg),
110-
reasoning={"effort": "medium"},
111-
text_format=ErrorAnalysisResponse,
107+
response = chat(
108+
model=model,
109+
messages=[
110+
{"role": "system", "content": INSTRUCTION},
111+
{
112+
"role": "user",
113+
"content": get_error_analysis_prompt(task, answer, error_msg=error_msg),
114+
},
115+
],
116+
format=ErrorAnalysisResponse.model_json_schema(),
112117
)
113-
assert response.output_parsed is not None
114-
return response.output_parsed
118+
content = response.message.content # type: ignore
119+
err = ErrorAnalysisResponse.model_validate_json(content)
120+
return err
115121

116122

117123
class ErrorAnalysisResult(TypedDict):

src/tfbench/lm/_ollama.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def _gen(self, prompt: str) -> LMAnswer:
3434
},
3535
],
3636
think=True,
37+
keep_alive=True,
3738
)
3839
return LMAnswer(
3940
answer=response.message.content, # type: ignore

0 commit comments

Comments
 (0)