|
3 | 3 | @project: MaxKB |
4 | 4 | @Author:虎 |
5 | 5 | @file: embedding.py |
6 | | - @date:2024/7/12 17:44 |
| 6 | + @date:2024/10/16 16:34 |
7 | 7 | @desc: |
8 | 8 | """ |
9 | | -from typing import Dict |
| 9 | +from typing import Dict, List |
10 | 10 |
|
11 | 11 | import requests |
12 | | -from langchain_community.embeddings import OpenAIEmbeddings |
13 | 12 |
|
14 | | -from common.utils.logger import maxkb_logger |
15 | 13 | from models_provider.base_model_provider import MaxKBBaseModel |
16 | 14 |
|
17 | 15 |
|
18 | | -class SiliconCloudEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings): |
| 16 | +class SiliconCloudEmbeddingModel(MaxKBBaseModel): |
| 17 | + model_name: str |
| 18 | + openai_api_key: str |
| 19 | + base_url: str |
| 20 | + optional_params: dict |
| 21 | + |
def __init__(self, api_key, model_name: str, base_url, optional_params: dict):
    """Store the settings used to call the SiliconCloud embedding endpoint.

    Args:
        api_key: Credential sent as the Bearer token of each request.
        model_name: Identifier placed in the "model" field of the payload.
        base_url: API root; request paths such as '/embeddings' are
            appended to it.
        optional_params: Extra key/value pairs merged into the request
            payload. ``None`` is treated as "no extra parameters".
    """
    self.openai_api_key = api_key
    self.base_url = base_url
    self.model_name = model_name
    # embed_query spreads this mapping with ``**``; unpacking None would
    # raise TypeError, so a missing value is normalised to an empty dict.
    self.optional_params = {} if optional_params is None else optional_params
| 27 | + |
def is_cache_model(self):
    """Return False unconditionally.

    NOTE(review): presumably tells the provider framework not to cache
    instances of this model -- confirm against MaxKBBaseModel.
    """
    return False
| 30 | + |
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    """Create a SiliconCloudEmbeddingModel from a credential mapping.

    Args:
        model_type: Accepted for interface compatibility; not used here.
        model_name: Name forwarded to the model as ``model_name``.
        model_credential: Mapping read for the 'api_key' and 'api_base'
            entries.
        **model_kwargs: Passed through MaxKBBaseModel.filter_optional_params;
            the result is forwarded to the model as ``optional_params``.

    Returns:
        A new SiliconCloudEmbeddingModel instance.
    """
    # Evaluate in the same order as before: filter first, then credentials.
    extra_payload = MaxKBBaseModel.filter_optional_params(model_kwargs)
    credential_key = model_credential.get('api_key')
    endpoint_root = model_credential.get('api_base')
    return SiliconCloudEmbeddingModel(
        api_key=credential_key,
        model_name=model_name,
        optional_params=extra_payload,
        base_url=endpoint_root,
    )
26 | 40 |
|
27 | 41 | def embed_query(self, text: str) -> list: |
28 | 42 | payload = { |
29 | | - "model": self.model, |
30 | | - "input": text |
| 43 | + "model": self.model_name, |
| 44 | + "input": text, |
| 45 | + **self.optional_params |
31 | 46 | } |
32 | 47 | headers = { |
33 | 48 | "Authorization": f"Bearer {self.openai_api_key}", |
34 | 49 | "Content-Type": "application/json" |
35 | 50 | } |
36 | 51 |
|
37 | | - response = requests.post(self.openai_api_base + '/embeddings', json=payload, headers=headers) |
| 52 | + response = requests.post(self.base_url + '/embeddings', json=payload, headers=headers) |
| 53 | + print(response.text) |
38 | 54 | data = response.json() |
39 | | - # print(data) |
40 | 55 | if data['data'] is None or 'code' in data: |
41 | 56 | raise ValueError(f"Embedding API returned no data: {data}") |
42 | 57 | # 假设返回结构中有 'data[0].embedding' |
|
0 commit comments