-
-
Notifications
You must be signed in to change notification settings - Fork 49
Expand file tree
/
Copy pathopenai.py
More file actions
94 lines (74 loc) · 3.39 KB
/
Copy pathopenai.py
File metadata and controls
94 lines (74 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import asyncio
import logging

import openai
from openai.types.chat import ChatCompletion
from pydantic import BaseModel, Field, ValidationError

from vectorcode.cli_utils import Config
from vectorcode.rewriter.base import RewriterBase
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Response schema for the structured-output request made by
# `OpenAIRewriter.rewrite` (passed there as `response_format`): the model must
# reply with an object holding a single `keywords` list of strings.
# NOTE: documented with comments instead of a class docstring on purpose —
# pydantic copies a model docstring into the generated JSON schema, which would
# change the schema sent to the API.
class _NewQuery(BaseModel):
    # The rewritten query. The `description` is emitted into the JSON schema,
    # so the model also reads it as an instruction.
    keywords: list[str] = Field(
        description="Orthogonal keywords for the vector search."
    )
class OpenAIRewriter(RewriterBase):
    """
    OpenAIRewriter class is an adapter for openai-compatible API services that provides
    structured output support. The `rewriter_params` dictionary accepts 3 keys:
    - `client_kwargs`: dictionary, containing arguments that are passed to `openai.Client`.
      See https://github.com/openai/openai-python/blob/67997a4ec1ebcdf8e740afb0d0b2e37897657bde/src/openai/_client.py#L80;
    - `completion_kwargs`: dictionary, containing arguments that are passed to `openai.Client.beta.chat.completions.parse`.
      See https://github.com/openai/openai-python/blob/main/helpers.md#structured-outputs-parsing-helpers.
    - `system_prompt`: string, the system prompt that contains the guidelines for rewriting the query.
      When omitted, a built-in code-aware rewriting prompt is used.

    `rewrite` never fails because of an unusable model answer. It returns the
    original query when the service gives no choices, a refusal, an empty
    keyword list, or a structured response that cannot be parsed.
    """

    def __init__(self, config: Config) -> None:
        super().__init__(config)
        # Synchronous client; `rewrite` runs its requests in a worker thread.
        self.client = openai.Client(
            **self.config.rewriter_params.get("client_kwargs", {})
        )
        self.system_prompt = self.config.rewriter_params.get(
            "system_prompt",
            """
Role:
You are a code-aware rewriter that improves technical queries/docs for retrieval. Never assume a programming language unless the input explicitly includes syntax, APIs, or error messages from one.
Rules:
For Queries:
Fix unambiguous typos (e.g., "Pytoch" → "PyTorch").
Omit langauge-specific keywords (e.g., "async def foo():" → "foo")
Do not include standard libraries in the query.
For Docs/Code:
Clarify ambiguous terms only with explicit context.
Never modify code logic or variable names.
Anti-Goals:
No language assumptions.
No code changes.
No hallucinations.
""",
        )

    async def rewrite(self, original_query: list[str]) -> list[str]:
        """
        Rewrite `original_query` into keywords for the vector search.

        The query tokens are joined with single spaces and sent as the user
        message, after `self.system_prompt`, to the structured-output `parse`
        endpoint with `_NewQuery` as the response format. Extra arguments come
        from `rewriter_params["completion_kwargs"]`.

        :param original_query: the query tokens to rewrite.
        :return: the model's `keywords`, or `original_query` unchanged when no
            usable, non-empty rewrite was produced.
        """
        try:
            # `self.client` is synchronous. Calling it directly would block the
            # event loop for the whole HTTP round-trip, so run it in a thread.
            comp: ChatCompletion = await asyncio.to_thread(
                self.client.beta.chat.completions.parse,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": " ".join(original_query)},
                ],
                response_format=_NewQuery,
                **self.config.rewriter_params.get("completion_kwargs", {}),
            )
        except (
            ValidationError,
            # openai's `parse` helper raises these when the answer was cut off
            # by the token limit or blocked by the content filter. Either way
            # there is no complete structured output to use.
            openai.LengthFinishReasonError,
            openai.ContentFilterFinishReasonError,
        ):
            logger.warning(
                "Failed to parse structured output. Fallingback to original_query."
            )
            return original_query

        if comp is None or len(comp.choices) == 0:
            logger.info(
                "Received no rewritten query. Fallingback to original_query."
            )
            return original_query

        message = comp.choices[0].message
        if not message or not message.parsed:
            # `getattr` keeps this log line from raising when `message` is missing.
            logger.warning(
                f"Failed to parse structured output: {getattr(message, 'refusal', None)}. Fallingback to original_query."
            )
            return original_query

        if not message.parsed.keywords:
            # An empty keyword list would leave the vector search with nothing
            # to query, so keep the user's original query instead.
            logger.warning(
                "Rewriter returned no keywords. Falling back to original_query."
            )
            return original_query

        logger.debug(f"Rewritten queries to: {message.parsed}")
        return message.parsed.keywords