|
1 | | -from typing import Dict |
2 | | - |
3 | | -from langchain_openai import OpenAIEmbeddings |
| 1 | +from typing import Dict, List |
4 | 2 |
|
5 | 3 | from models_provider.base_model_provider import MaxKBBaseModel |
| 4 | +from volcenginesdkarkruntime import Ark |
| 5 | + |
| 6 | + |
class VolcanicEngineEmbeddingModel(MaxKBBaseModel):
    """Embedding model backed by the Volcengine Ark SDK.

    Handles both the plain text embedding endpoint and the
    ``doubao-embedding-vision-*`` multimodal family.
    """

    api_key: str
    model_name: str
    api_base: str
    params: Dict[str, object]

    def __init__(self, api_key: str, model: str, api_base: str, params: Dict[str, object] = None):
        """
        :param api_key: Ark API key used to authenticate the client.
        :param model: model name, e.g. ``doubao-embedding-...``.
        :param api_base: base URL of the Ark endpoint.
        :param params: extra keyword arguments forwarded to the SDK ``create`` calls.
        """
        self.client = Ark(
            api_key=api_key,
            base_url=api_base
        )
        # Populate the declared class attributes (the original left
        # api_key/api_base unset despite annotating them).
        self.api_key = api_key
        self.api_base = api_base
        self.model_name = model
        # Normalize None to {} so the `**self.params` expansions in
        # embed_documents never raise TypeError when params is omitted.
        self.params = params if params is not None else {}

    @staticmethod
    def is_cache_model():
        # Instances are cheap to build; do not cache them.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider framework to build a model from credentials."""
        return VolcanicEngineEmbeddingModel(
            api_key=model_credential.get("api_key"),
            model=model_name,
            api_base=model_credential.get("api_base"),
            **model_kwargs
        )

    def embed_query(self, text: str):
        """Embed a single query string and return its vector."""
        res = self.embed_documents([text])
        return res[0]

    def embed_documents(
            self, texts: List[str], chunk_size: int | None = None
    ) -> List[List[float]]:
        """Embed every text in *texts* and return one vector per input.

        ``chunk_size`` is accepted for interface compatibility and ignored.
        """
        if self.model_name.startswith("doubao-embedding-vision-"):
            # BUG FIX: the multimodal endpoint fuses all inputs of a single
            # request into ONE embedding. The original sent every text in one
            # request and returned a length-1 list, misaligning the result
            # with `texts`. Issue one request per text to keep the output
            # aligned 1:1 with the input.
            embeddings: List[List[float]] = []
            for text in texts:
                resp = self.client.multimodal_embeddings.create(
                    model=self.model_name,
                    input=[{
                        "type": "text",
                        "text": text
                    }],
                    **self.params
                )
                embeddings.append(resp.data.get('embedding'))
            return embeddings
        else:
            resp = self.client.embeddings.create(
                model=self.model_name,
                input=texts,
                **self.params
            )
            return [e.embedding for e in resp.data]
0 commit comments