Skip to content

Commit 3a3e4d0

Browse files
committed
feat: implement Volcanic Engine embedding model with support for document embedding
1 parent a183e0f commit 3a3e4d0

1 file changed

Lines changed: 51 additions & 7 deletions

File tree

  • apps/models_provider/impl/volcanic_engine_model_provider/model
Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,60 @@
1-
from typing import Dict
2-
3-
from langchain_openai import OpenAIEmbeddings
1+
from typing import Dict, List
42

53
from models_provider.base_model_provider import MaxKBBaseModel
4+
from volcenginesdkarkruntime import Ark
5+
6+
7+
class VolcanicEngineEmbeddingModel(MaxKBBaseModel):
8+
api_key: str
9+
model_name: str
10+
api_base: str
11+
params: Dict[str, object]
12+
13+
def __init__(self, api_key: str, model: str, api_base: str, params: Dict[str, object] = None):
14+
self.client = Ark(
15+
api_key=api_key,
16+
base_url=api_base
17+
)
18+
self.model_name = model
19+
self.params = params
620

21+
@staticmethod
22+
def is_cache_model():
23+
return False
724

8-
class VolcanicEngineEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
925
@staticmethod
1026
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
1127
return VolcanicEngineEmbeddingModel(
12-
openai_api_key=model_credential.get('api_key'),
28+
api_key=model_credential.get("api_key"),
1329
model=model_name,
14-
openai_api_base=model_credential.get('api_base'),
15-
check_embedding_ctx_length=False,
30+
api_base=model_credential.get("api_base"),
31+
**model_kwargs
1632
)
33+
34+
def embed_query(self, text: str):
35+
res = self.embed_documents([text])
36+
return res[0]
37+
38+
def embed_documents(
39+
self, texts: List[str], chunk_size: int | None = None
40+
) -> List[List[float]]:
41+
if self.model_name.startswith("doubao-embedding-vision-"):
42+
multimodal_inputs = []
43+
for text in texts:
44+
multimodal_inputs.append({
45+
"type": "text",
46+
"text": text
47+
})
48+
resp = self.client.multimodal_embeddings.create(
49+
model=self.model_name,
50+
input=multimodal_inputs,
51+
**self.params
52+
)
53+
return [resp.data.get('embedding')]
54+
else:
55+
resp = self.client.embeddings.create(
56+
model=self.model_name,
57+
input=texts,
58+
**self.params
59+
)
60+
return [e.embedding for e in resp.data]

0 commit comments

Comments
 (0)