Skip to content

Commit 660dac4

Browse files
committed
feat: refactor SiliconCloudEmbeddingModel to include optional parameters and improve API integration
1 parent 455be0c commit 660dac4

1 file changed

Lines changed: 27 additions & 12 deletions

File tree

  • apps/models_provider/impl/siliconCloud_model_provider/model

apps/models_provider/impl/siliconCloud_model_provider/model/embedding.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,55 @@
33
@project: MaxKB
44
@Author:虎
55
@file: embedding.py
6-
@date:2024/7/12 17:44
6+
@date:2024/10/16 16:34
77
@desc:
88
"""
9-
from typing import Dict
9+
from typing import Dict, List
1010

1111
import requests
12-
from langchain_community.embeddings import OpenAIEmbeddings
1312

14-
from common.utils.logger import maxkb_logger
1513
from models_provider.base_model_provider import MaxKBBaseModel
1614

1715

18-
class SiliconCloudEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
16+
class SiliconCloudEmbeddingModel(MaxKBBaseModel):
17+
model_name: str
18+
openai_api_key: str
19+
base_url: str
20+
optional_params: dict
21+
22+
def __init__(self, api_key, model_name: str, base_url, optional_params: dict):
23+
self.openai_api_key = api_key
24+
self.base_url = base_url
25+
self.model_name = model_name
26+
self.optional_params = optional_params
27+
28+
def is_cache_model(self):
29+
return False
30+
1931
@staticmethod
2032
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
33+
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
2134
return SiliconCloudEmbeddingModel(
22-
openai_api_key=model_credential.get('api_key'),
23-
model=model_name,
24-
openai_api_base=model_credential.get('api_base'),
35+
api_key=model_credential.get('api_key'),
36+
model_name=model_name,
37+
optional_params=optional_params,
38+
base_url=model_credential.get('api_base'),
2539
)
2640

2741
def embed_query(self, text: str) -> list:
2842
payload = {
29-
"model": self.model,
30-
"input": text
43+
"model": self.model_name,
44+
"input": text,
45+
**self.optional_params
3146
}
3247
headers = {
3348
"Authorization": f"Bearer {self.openai_api_key}",
3449
"Content-Type": "application/json"
3550
}
3651

37-
response = requests.post(self.openai_api_base + '/embeddings', json=payload, headers=headers)
52+
response = requests.post(self.base_url + '/embeddings', json=payload, headers=headers)
53+
print(response.text)
3854
data = response.json()
39-
# print(data)
4055
if data['data'] is None or 'code' in data:
4156
raise ValueError(f"Embedding API returned no data: {data}")
4257
# 假设返回结构中有 'data[0].embedding'

0 commit comments

Comments
 (0)