Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions scripts/error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

def analysis(result_file_dir: str, split: Literal["base", "pure"], output_file: str):
"""script to run error analysis fo incorrect TF-Bench tasks"""
client = OpenAI()
tasks = load_tfb_from_hf(split)
model = basename(abspath(result_file_dir))

Expand All @@ -36,15 +35,14 @@ def analysis(result_file_dir: str, split: Literal["base", "pure"], output_file:

print(f"Running error classification on {len(incorrect)} incorrect results")
for task, answer, msg in tqdm(incorrect):
error = error_analysis(client, task, answer, error_msg=msg)
error = error_analysis(task, answer, error_msg=msg)
log_obj: ErrorAnalysisResult = {
"model": model,
"split": split,
"task_id": task.task_id,
"ground_truth": task.signature,
"predicted": answer.answer if answer else None,
"error_category": error.category,
"error_explanation": error.explanation,
}
orjsonl.append(output_file, log_obj)

Expand Down
40 changes: 21 additions & 19 deletions src/tfbench/error_analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import TypedDict, Literal

from pydantic import BaseModel
from openai import OpenAI

from ollama import chat
from .common import get_prompt as get_task_prompt, BenchmarkTask
from .lm import LMAnswer

Expand Down Expand Up @@ -57,7 +56,6 @@
The prompt asked to only output the type signature,
but the answer contains additional text or explanation.
Choose one category from the above.
Only output the one-word classification and a short explanation of the why this category fits.
"""

ErrorCategories = Literal[
Expand All @@ -74,7 +72,6 @@

class ErrorAnalysisResponse(BaseModel):
category: ErrorCategories
explanation: str


def get_error_analysis_prompt(
Expand All @@ -92,26 +89,32 @@ def get_error_analysis_prompt(


def error_analysis(
client: OpenAI,
task: BenchmarkTask,
answer: LMAnswer | None,
error_msg: str,
model: str = "qwen3:235b",
) -> ErrorAnalysisResponse:
"""classify errors for all incorrect answers in the run_result"""
"""classify errors for all incorrect answers in the run_result
NOTE: this function uses the OpenAI-compatible API of vLLM.
Which model to use is determined by how you serve the model.
"""
if answer is None:
return ErrorAnalysisResponse(
category="ResponseError", explanation="No answer provided."
)

response = client.responses.parse(
model="gpt-5",
instructions=INSTRUCTION,
input=get_error_analysis_prompt(task, answer, error_msg=error_msg),
reasoning={"effort": "medium"},
text_format=ErrorAnalysisResponse,
return ErrorAnalysisResponse(category="ResponseError")

response = chat(
model=model,
messages=[
{"role": "system", "content": INSTRUCTION},
{
"role": "user",
"content": get_error_analysis_prompt(task, answer, error_msg=error_msg),
},
],
format=ErrorAnalysisResponse.model_json_schema(),
)
assert response.output_parsed is not None
return response.output_parsed
content = response.message.content # type: ignore
err = ErrorAnalysisResponse.model_validate_json(content)
return err


class ErrorAnalysisResult(TypedDict):
Expand All @@ -121,4 +124,3 @@ class ErrorAnalysisResult(TypedDict):
ground_truth: str
predicted: str | None
error_category: ErrorCategories
error_explanation: str
1 change: 1 addition & 0 deletions src/tfbench/lm/_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _gen(self, prompt: str) -> LMAnswer:
},
],
think=True,
keep_alive=True,
)
return LMAnswer(
answer=response.message.content, # type: ignore
Expand Down
Loading