Skip to content

Commit ea9fb13

Browse files
committed
fix: add support for multimodal embeddings in AliyunBaiLianEmbedding class
1 parent cf637a8 commit ea9fb13

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

  • apps/models_provider/impl/aliyun_bai_lian_model_provider/model

apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
@date:2024/10/16 16:34
77
@desc:
88
"""
9+
from http import HTTPStatus
910
from typing import Dict, List
1011

1112
from openai import OpenAI
@@ -16,11 +17,15 @@
1617
class AliyunBaiLianEmbedding(MaxKBBaseModel):
1718
model_name: str
1819
optional_params: dict
20+
api_base: str
21+
api_key: str
1922

2023
def __init__(self, api_key, model_name: str, api_base: str, optional_params: dict):
2124
self.client = OpenAI(api_key=api_key, base_url=api_base).embeddings
2225
self.model_name = model_name
2326
self.optional_params = optional_params
27+
self.api_key = api_key
28+
self.api_base = api_base
2429

2530
def is_cache_model(self):
2631
return False
@@ -42,6 +47,24 @@ def embed_query(self, text: str):
4247
def embed_documents(
4348
self, texts: List[str], chunk_size: int | None = None
4449
) -> List[List[float]]:
50+
# 处理多模态的向量化
51+
if 'vl-embedding' in self.model_name or 'embedding-vision' in self.model_name or 'multimodal' in self.model_name:
52+
import dashscope
53+
dashscope.api_key = self.api_key
54+
dashscope.base_http_api_url = self.api_base
55+
multimodal_input = [{"text": text} for text in texts]
56+
resp = dashscope.MultiModalEmbedding.call(
57+
model="tongyi-embedding-vision-plus",
58+
input=multimodal_input, # type: ignore
59+
**self.optional_params
60+
)
61+
62+
if resp.status_code == HTTPStatus.OK:
63+
embeddings_data = resp.output.get('embeddings', [])
64+
return [item.get('embedding', []) for item in embeddings_data]
65+
else:
66+
raise Exception(f'MultiModalEmbedding call failed: status={resp.status_code}, message={resp.message}')
67+
4568
if len(self.optional_params) > 0:
4669
res = self.client.create(
4770
input=texts, model=self.model_name, encoding_format="float",

0 commit comments

Comments
 (0)