|
1 | 1 | from typing import TypedDict, Literal |
2 | 2 |
|
3 | 3 | from pydantic import BaseModel |
4 | | -from openai import OpenAI |
5 | | - |
| 4 | +from ollama import chat |
6 | 5 | from .common import get_prompt as get_task_prompt, BenchmarkTask |
7 | 6 | from .lm import LMAnswer |
8 | 7 |
|
|
57 | 56 | The prompt asked to only output the type signature, |
58 | 57 | but the answer contains additional text or explanation. |
59 | 58 | Choose one category from the above. |
60 | | -Only output the one-word classification and a short explanation of the why this category fits. |
61 | 59 | """ |
62 | 60 |
|
63 | 61 | ErrorCategories = Literal[ |
@@ -92,26 +90,34 @@ def get_error_analysis_prompt( |
92 | 90 |
|
93 | 91 |
|
94 | 92 | def error_analysis( |
95 | | - client: OpenAI, |
96 | 93 | task: BenchmarkTask, |
97 | 94 | answer: LMAnswer | None, |
98 | 95 | error_msg: str, |
| 96 | + model: str = "qwen3:235b", |
99 | 97 | ) -> ErrorAnalysisResponse: |
100 | | - """classify errors for all incorrect answers in the run_result""" |
| 98 | + """classify the error for a single task answer |
| 99 | + NOTE: this function uses the `chat` API of the ollama package. |
| 100 | + Which model to use is determined by the `model` argument. |
| 101 | + """ |
101 | 102 | if answer is None: |
102 | 103 | return ErrorAnalysisResponse( |
103 | 104 | category="ResponseError", explanation="No answer provided." |
104 | 105 | ) |
105 | 106 |
|
106 | | - response = client.responses.parse( |
107 | | - model="gpt-5", |
108 | | - instructions=INSTRUCTION, |
109 | | - input=get_error_analysis_prompt(task, answer, error_msg=error_msg), |
110 | | - reasoning={"effort": "medium"}, |
111 | | - text_format=ErrorAnalysisResponse, |
| 107 | + response = chat( |
| 108 | + model=model, |
| 109 | + messages=[ |
| 110 | + {"role": "system", "content": INSTRUCTION}, |
| 111 | + { |
| 112 | + "role": "user", |
| 113 | + "content": get_error_analysis_prompt(task, answer, error_msg=error_msg), |
| 114 | + }, |
| 115 | + ], |
| 116 | + format=ErrorAnalysisResponse.model_json_schema(), |
112 | 117 | ) |
113 | | - assert response.output_parsed is not None |
114 | | - return response.output_parsed |
| 118 | + content = response.message.content # type: ignore |
| 119 | + err = ErrorAnalysisResponse.model_validate_json(content) |
| 120 | + return err |
115 | 121 |
|
116 | 122 |
|
117 | 123 | class ErrorAnalysisResult(TypedDict): |
|
0 commit comments