diff --git a/scripts/error_analysis.py b/scripts/error_analysis.py
index f7db32d..2d9d7cf 100644
--- a/scripts/error_analysis.py
+++ b/scripts/error_analysis.py
@@ -17,7 +17,6 @@ def analysis(result_file_dir: str, split: Literal["base", "pure"], output_file: str):
     """script to run error analysis fo incorrect TF-Bench tasks"""
-    client = OpenAI()
 
     tasks = load_tfb_from_hf(split)
     model = basename(abspath(result_file_dir))
 
 
@@ -36,7 +35,7 @@ def analysis(result_file_dir: str, split: Literal["base", "pure"], output_file: str):
     print(f"Running error classification on {len(incorrect)} incorrect results")
 
     for task, answer, msg in tqdm(incorrect):
-        error = error_analysis(client, task, answer, error_msg=msg)
+        error = error_analysis(task, answer, error_msg=msg)
         log_obj: ErrorAnalysisResult = {
             "model": model,
             "split": split,
@@ -44,7 +43,6 @@ def analysis(result_file_dir: str, split: Literal["base", "pure"], output_file: str):
             "ground_truth": task.signature,
             "predicted": answer.answer if answer else None,
             "error_category": error.category,
-            "error_explanation": error.explanation,
         }
         orjsonl.append(output_file, log_obj)
 
diff --git a/src/tfbench/error_analysis.py b/src/tfbench/error_analysis.py
index c71e128..9be0c88 100644
--- a/src/tfbench/error_analysis.py
+++ b/src/tfbench/error_analysis.py
@@ -1,8 +1,7 @@
 from typing import TypedDict, Literal
 
 from pydantic import BaseModel
-from openai import OpenAI
-
+from ollama import chat
 from .common import get_prompt as get_task_prompt, BenchmarkTask
 from .lm import LMAnswer
 
@@ -57,7 +56,6 @@
 The prompt asked to only output the type signature, but the answer contains additional text or explanation.
 
 Choose one category from the above.
-Only output the one-word classification and a short explanation of the why this category fits.
 """
 
 ErrorCategories = Literal[
@@ -74,7 +72,6 @@
 
 class ErrorAnalysisResponse(BaseModel):
     category: ErrorCategories
-    explanation: str
 
 
 def get_error_analysis_prompt(
@@ -92,26 +89,32 @@
 def error_analysis(
-    client: OpenAI,
     task: BenchmarkTask,
     answer: LMAnswer | None,
     error_msg: str,
+    model: str = "qwen3:235b",
 ) -> ErrorAnalysisResponse:
-    """classify errors for all incorrect answers in the run_result"""
+    """classify errors for all incorrect answers in the run_result
+    NOTE: this function uses the OpenAI-compatible API of vLLM.
+    Which model to use is determined by how you serve the model.
+    """
 
     if answer is None:
-        return ErrorAnalysisResponse(
-            category="ResponseError", explanation="No answer provided."
-        )
-
-    response = client.responses.parse(
-        model="gpt-5",
-        instructions=INSTRUCTION,
-        input=get_error_analysis_prompt(task, answer, error_msg=error_msg),
-        reasoning={"effort": "medium"},
-        text_format=ErrorAnalysisResponse,
+        return ErrorAnalysisResponse(category="ResponseError")
+
+    response = chat(
+        model=model,
+        messages=[
+            {"role": "system", "content": INSTRUCTION},
+            {
+                "role": "user",
+                "content": get_error_analysis_prompt(task, answer, error_msg=error_msg),
+            },
+        ],
+        format=ErrorAnalysisResponse.model_json_schema(),
     )
 
-    assert response.output_parsed is not None
-    return response.output_parsed
+    content = response.message.content  # type: ignore
+    err = ErrorAnalysisResponse.model_validate_json(content)
+    return err
 
 
 class ErrorAnalysisResult(TypedDict):
@@ -121,4 +124,3 @@
     ground_truth: str
     predicted: str | None
    error_category: ErrorCategories
-    error_explanation: str
diff --git a/src/tfbench/lm/_ollama.py b/src/tfbench/lm/_ollama.py
index 9e848e6..2074d26 100644
--- a/src/tfbench/lm/_ollama.py
+++ b/src/tfbench/lm/_ollama.py
@@ -34,6 +34,7 @@ def _gen(self, prompt: str) -> LMAnswer:
             },
         ],
         think=True,
+        keep_alive=True,
     )
     return LMAnswer(
         answer=response.message.content,  # type: ignore