Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,34 @@
@date:2024/7/12 17:44
@desc:
"""
from typing import Dict
from typing import Dict, List

from langchain_community.embeddings import OpenAIEmbeddings
import openai

from setting.models_provider.base_model_provider import MaxKBBaseModel


class OpenAIEmbeddingModel(MaxKBBaseModel):
    """Embedding model backed by an OpenAI-compatible embeddings endpoint.

    Calls the endpoint directly through the ``openai`` SDK instead of
    inheriting from LangChain's ``OpenAIEmbeddings``.
    """

    # Identifier of the embedding model sent with every embeddings request.
    model_name: str

    def __init__(self, api_key, base_url, model_name: str):
        """Build the embeddings client.

        :param api_key: API key for the OpenAI-compatible endpoint.
        :param base_url: Base URL of the endpoint.
        :param model_name: Embedding model identifier (e.g. a provider model name).
        """
        # Keep only the embeddings sub-client; nothing else from the SDK is used.
        self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
        self.model_name = model_name

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider framework.

        :param model_type: Unused here; part of the factory signature.
        :param model_name: Embedding model identifier.
        :param model_credential: Expects 'api_key' and 'api_base' keys.
        :return: A configured OpenAIEmbeddingModel.
        """
        return OpenAIEmbeddingModel(
            api_key=model_credential.get('api_key'),
            model_name=model_name,
            base_url=model_credential.get('api_base'),
        )

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string by delegating to embed_documents."""
        return self.embed_documents([text])[0]

    def embed_documents(
        self, texts: List[str], chunk_size: int | None = None
    ) -> List[List[float]]:
        """Embed a batch of texts in one request.

        :param texts: Texts to embed.
        :param chunk_size: Accepted for LangChain interface compatibility;
            currently unused — all texts go in a single request.
        :return: One embedding vector (list of floats) per input text.
        """
        res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
        # Preserve input order: the API returns data aligned with the input list.
        return [e.embedding for e in res.data]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code has a few issues, the main ones being:

  1. The openai module does not directly call any class method, but rather uses its attributes to perform operations (create).

  2. Using dynamic method calls (like in .embed_documents) can be less readable and maintainable.

  3. It lacks validation checks, particularly for required parameters such as api_key and model_name.

  4. There is no exception handling.

To improve the code, you could do these changes:

@@ -6,18 +6,34 @@
     @date2024/7/12 17:44
     @desc:
 """
-from typing import Dict
+from typing import Dict, List

-import openai
+from setting.models_provider.base_model_provider import MaxKBBaseModel


-class OpenAIEmbeddingModel(MaxKBBaseModel, openai.Embeddings):
+class OpenAIEmbeddingModel(MaxKBBaseModel):
     _client: openai.Completion
     api_key: str
     base_url: str
     model_name: str

+    def __init__(self, api_key, base_url, model_name: str = "text-davinci-003"):
         if not isinstance(api_key, str) or not api_key:
             raise ValueError("API key cannot be empty")
         if not isinstance(base_url, str) or not base_url:
             raise ValueError("Base URL cannot be empty")

         self.api_key = api_key
         self.base_url = base_url
-        super().__init__()
+        self._client = openai.OpenAI(api_key=api_key, base_url=base_url)

+        # set default model name if one wasn't provided
+        if not self.model_name:
+            self.model_name = "text-davinci-003"

     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         return OpenAIEmbeddingModel(
@@ -34,7 +49,7 @@
         return result.choices[0].message.content.split("\n\n")[:-1]

     def embed_query(self, text: str) -> List[float]:
-        response = self._client.create(prompt=text)
+        response = self.client.completions.create(prompt=text)
         token_ids = [token["id"] for token in response.choices[0]["tokens"]]
         embedding_values = [token["logprobs"]["probs"][i] for i in range(0, len(token_ids), 512)]
@@ -43,20 +58,24 @@

This revised version improves readability and error-handling by using explicit initialization and proper parameter checking.