fix: OpenAI Vector Model Using OpenAI Supplier #2781
Conversation
```python
    self, texts: List[str], chunk_size: int | None = None
) -> List[List[float]]:
    res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
    return [e.embedding for e in res.data]
```
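As an aside, the `chunk_size` parameter in the signature above is accepted but never used. A minimal batching sketch (the helper name is hypothetical, not part of the PR) could look like:

```python
from typing import List


def batch_texts(texts: List[str], chunk_size: int) -> List[List[str]]:
    """Split texts into batches of at most chunk_size items, preserving order."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [texts[i:i + chunk_size] for i in range(0, len(texts), chunk_size)]

# Each batch would then be sent to the embeddings endpoint in turn,
# and the per-text vectors concatenated back in the original order.
```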
The code has a few issues, the main ones being:

- The `openai` module does not directly call any class method; operations such as `create` are performed through its client attributes.
- Dynamic method calls (like the one in `.embed_documents`) can be less readable and maintainable.
- It lacks validation checks, particularly for required parameters such as `api_key` and `model_name`.
- There is no exception handling.

To improve the code, you could make these changes:
```diff
@@ -6,18 +6,34 @@
 @date:2024/7/12 17:44
 @desc:
 """
-from typing import Dict
+from typing import Dict, List

 import openai
+
+from setting.models_provider.base_model_provider import MaxKBBaseModel


-class OpenAIEmbeddingModel(MaxKBBaseModel, openai.Embeddings):
+class OpenAIEmbeddingModel(MaxKBBaseModel):
-    _client: openai.Completion
+    _client: openai.OpenAI
     api_key: str
     base_url: str
     model_name: str

+    def __init__(self, api_key: str, base_url: str, model_name: str = "text-davinci-003"):
+        if not isinstance(api_key, str) or not api_key:
+            raise ValueError("API key cannot be empty")
+        if not isinstance(base_url, str) or not base_url:
+            raise ValueError("Base URL cannot be empty")
+        self.api_key = api_key
+        self.base_url = base_url
+        # fall back to the default model name if one wasn't provided
+        self.model_name = model_name or "text-davinci-003"
-        super().__init__()
+        self._client = openai.OpenAI(api_key=api_key, base_url=base_url)

     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         return OpenAIEmbeddingModel(
@@ -34,7 +49,7 @@
         return result.choices[0].message.content.split("\n\n")[:-1]

     def embed_query(self, text: str) -> List[float]:
-        response = self._client.create(prompt=text)
+        response = self._client.completions.create(prompt=text)
         token_ids = [token["id"] for token in response.choices[0]["tokens"]]
         embedding_values = [token["logprobs"]["probs"][i] for i in range(0, len(token_ids), 512)]
@@ -43,20 +58,24 @@
```

This revised version improves readability and error handling by using explicit initialization and proper parameter checking.
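The missing validation the review points out could also be factored into a small standalone helper (a sketch; the function name is illustrative, not from the PR):

```python
def validate_credentials(api_key: object, base_url: object, model_name: object) -> None:
    """Raise ValueError if any required credential field is missing or empty."""
    required = {"api_key": api_key, "base_url": base_url, "model_name": model_name}
    for name, value in required.items():
        if not isinstance(value, str) or not value:
            raise ValueError(f"{name} must be a non-empty string")
```

Calling it at the top of `__init__` (and of `new_instance`) fails fast with a clear message instead of surfacing an opaque API error later.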