Skip to content

Commit 4d5422d

Browse files
authored
🐛 Bugfix: Batch add dashscope embedding model failed to calculate embedding dimension (#3172)
* ✨ Support preview epub, html, json and xml files ✨ Now support add multimodal embedding models from qwen dashscope * 🐛 Bugfix: Batch add dashscope embedding model failed to calculate embedding dimension * 🧪 Add test files
1 parent b478f58 commit 4d5422d

14 files changed

Lines changed: 2190 additions & 150 deletions

File tree

backend/services/model_health_service.py

Lines changed: 95 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from nexent.core import MessageObserver
55
from nexent.core.models import OpenAIModel, OpenAIVLModel
6-
from nexent.core.models.embedding_model import JinaEmbedding, OpenAICompatibleEmbedding
6+
from nexent.core.models.embedding_model import JinaEmbedding, OpenAICompatibleEmbedding, DashScopeMultimodalEmbedding
77
from nexent.monitor import set_monitoring_context, set_monitoring_operation
88
from nexent.core.models.rerank_model import OpenAICompatibleRerank
99

@@ -20,12 +20,33 @@
2020
PROVIDER_CATALOG_HEALTHCHECK_FACTORIES = {DASHSCOPE_MODEL_FACTORY, TOKENPONY_MODEL_FACTORY}
2121
PROVIDER_CATALOG_HEALTHCHECK_TYPES = {"vlm", "vlm2", "vlm3"}
2222

23+
EMBEDDING_TYPES = {"embedding", "multi_embedding"}
2324

24-
def _mask_secret(value: Optional[str]) -> str:
25-
"""Mask a secret value, showing only first and last 4 characters."""
26-
if not value or len(value) <= 8:
27-
return "***"
28-
return value[:4] + "****" + value[-4:]
25+
26+
def _normalize_embedding_url(base_url: str) -> str:
27+
"""Append /embeddings suffix to base_url if not already present.
28+
29+
For embedding and multimodal embedding models, the base_url should contain /embeddings.
30+
If the user provides a base URL without the endpoint (e.g., https://api.jina.ai/v1),
31+
this function normalizes it to include /embeddings (e.g., https://api.jina.ai/v1/embeddings).
32+
"""
33+
if not base_url or "/embeddings" in base_url:
34+
return base_url
35+
return f"{base_url.rstrip('/')}/embeddings"
36+
37+
38+
def _infer_model_factory(model_type: str, base_url: str, current_factory: Optional[str] = None) -> Optional[str]:
39+
"""Infer model_factory from base_url if not already set or is generic.
40+
41+
Currently handles:
42+
- multi_embedding with dashscope URL -> "dashscope"
43+
- embedding with dashscope URL -> "dashscope" (uses OpenAI-compatible endpoint)
44+
"""
45+
base_url_lower = base_url.lower()
46+
if "dashscope" in base_url_lower:
47+
return DASHSCOPE_MODEL_FACTORY
48+
49+
return current_factory
2950

3051

3152
async def _embedding_dimension_check(
@@ -34,36 +55,51 @@ async def _embedding_dimension_check(
3455
model_base_url: str,
3556
model_api_key: str,
3657
ssl_verify: bool = True,
58+
model_factory: Optional[str] = None,
3759
timeout_seconds: Optional[float] = None,
3860
):
39-
# Test connectivity based on different model types
61+
if model_type in EMBEDDING_TYPES:
62+
model_base_url = _normalize_embedding_url(model_base_url)
63+
64+
effective_timeout = timeout_seconds if timeout_seconds else 5.0
65+
4066
if model_type == "embedding":
67+
# DashScope text embedding models use OpenAI-compatible endpoint, same as generic
4168
embedding = await OpenAICompatibleEmbedding(
4269
model_name=model_name,
4370
base_url=model_base_url,
4471
api_key=model_api_key,
4572
embedding_dim=0,
4673
ssl_verify=ssl_verify,
47-
timeout_seconds=timeout_seconds,
48-
).dimension_check()
74+
).dimension_check(timeout=effective_timeout)
4975
if len(embedding) > 0:
5076
return len(embedding[0])
5177
logging.warning(
5278
f"Embedding dimension check for {model_name} gets empty response")
5379
return 0
5480
elif model_type == "multi_embedding":
55-
embedding = await JinaEmbedding(
56-
model_name=model_name,
57-
base_url=model_base_url,
58-
api_key=model_api_key,
59-
embedding_dim=0,
60-
ssl_verify=ssl_verify,
61-
timeout_seconds=timeout_seconds,
62-
).dimension_check()
63-
if len(embedding) > 0:
81+
model_factory_lower = (model_factory or "").lower()
82+
if model_factory_lower == "dashscope":
83+
embedding_instance = DashScopeMultimodalEmbedding(
84+
api_key=model_api_key,
85+
base_url=model_base_url,
86+
model_name=model_name,
87+
embedding_dim=0,
88+
ssl_verify=ssl_verify,
89+
)
90+
else:
91+
embedding_instance = JinaEmbedding(
92+
api_key=model_api_key,
93+
base_url=model_base_url,
94+
model_name=model_name,
95+
embedding_dim=0,
96+
ssl_verify=ssl_verify,
97+
)
98+
embedding = await embedding_instance.dimension_check(timeout=effective_timeout)
99+
if isinstance(embedding, list) and len(embedding) > 0 and isinstance(embedding[0], list):
64100
return len(embedding[0])
65101
logging.warning(
66-
f"Embedding dimension check for {model_name} gets empty response")
102+
f"Embedding dimension check for {model_name} gets unexpected response: {type(embedding)}, value: {embedding}")
67103
return 0
68104
else:
69105
raise ValueError(f"Unsupported model type: {model_type}")
@@ -123,27 +159,42 @@ async def _perform_connectivity_check(
123159
model_base_url = model_base_url.replace(
124160
LOCALHOST_NAME, DOCKER_INTERNAL_HOST).replace(LOCALHOST_IP, DOCKER_INTERNAL_HOST)
125161

162+
# Normalize embedding URLs by appending /embeddings if not present
163+
if model_type in EMBEDDING_TYPES:
164+
model_base_url = _normalize_embedding_url(model_base_url)
165+
166+
effective_timeout = timeout_seconds if timeout_seconds else 5.0
126167
connectivity: bool
127168

128-
# Test connectivity based on different model types
129169
if model_type == "embedding":
130-
embedding = OpenAICompatibleEmbedding(
170+
emb = await OpenAICompatibleEmbedding(
131171
model_name=model_name,
132172
base_url=model_base_url,
133173
api_key=model_api_key,
134174
embedding_dim=0,
135175
ssl_verify=ssl_verify,
136-
)
137-
connectivity = len(await embedding.dimension_check(timeout=timeout_seconds if timeout_seconds else 5.0)) > 0
176+
).dimension_check(timeout=effective_timeout)
177+
connectivity = len(emb) > 0 and len(emb[0]) > 0
138178
elif model_type == "multi_embedding":
139-
embedding = JinaEmbedding(
140-
model_name=model_name,
141-
base_url=model_base_url,
142-
api_key=model_api_key,
143-
embedding_dim=0,
144-
ssl_verify=ssl_verify,
145-
)
146-
connectivity = len(await embedding.dimension_check(timeout=timeout_seconds if timeout_seconds else 5.0)) > 0
179+
model_factory_lower = (model_factory or "").lower()
180+
if model_factory_lower == "dashscope":
181+
embedding = DashScopeMultimodalEmbedding(
182+
api_key=model_api_key,
183+
base_url=model_base_url,
184+
model_name=model_name,
185+
embedding_dim=0,
186+
ssl_verify=ssl_verify,
187+
)
188+
else:
189+
embedding = JinaEmbedding(
190+
api_key=model_api_key,
191+
base_url=model_base_url,
192+
model_name=model_name,
193+
embedding_dim=0,
194+
ssl_verify=ssl_verify,
195+
)
196+
emb = await embedding.dimension_check(timeout=effective_timeout)
197+
connectivity = len(emb) > 0 and len(emb[0]) > 0
147198
elif model_type == "llm":
148199
observer = MessageObserver()
149200
set_monitoring_operation("connectivity_check",
@@ -335,6 +386,9 @@ async def verify_model_config_connectivity(model_config: dict):
335386
# Get timeout from model config if present
336387
timeout_seconds = model_config.get("timeout_seconds")
337388

389+
# Infer model_factory from base_url when not provided
390+
model_factory = _infer_model_factory(model_type, model_base_url, model_config.get("model_factory"))
391+
338392
try:
339393
connectivity = await _perform_connectivity_check(
340394
model_name, model_type, model_base_url, model_api_key, ssl_verify,
@@ -385,22 +439,26 @@ async def embedding_dimension_check(model_config: dict):
385439

386440
try:
387441
ssl_verify = model_config.get("ssl_verify", True)
442+
model_factory = _infer_model_factory(model_type, model_base_url, model_config.get("model_factory"))
388443
timeout_seconds = model_config.get("timeout_seconds")
389444
dimension = await _embedding_dimension_check(
390445
model_name, model_type, model_base_url, model_api_key, ssl_verify,
391-
timeout_seconds=timeout_seconds
446+
model_factory=model_factory, timeout_seconds=timeout_seconds
392447
)
393448
# Fallback to ssl_verify=False if initial check fails
394449
if dimension == 0 and ssl_verify:
395450
dimension = await _embedding_dimension_check(
396451
model_name, model_type, model_base_url, model_api_key, False,
397-
timeout_seconds=timeout_seconds
452+
model_factory=model_factory, timeout_seconds=timeout_seconds
398453
)
454+
if dimension == 0:
455+
logger.error(f"Embedding dimension check returned 0 for model: {model_name}")
456+
return None
399457
return dimension
400458
except ValueError as e:
401-
logger.error(f"Error checking embedding dimension: {str(e)}")
402-
return 0
459+
logger.error(f"Error checking embedding dimension for {model_name}: {str(e)}")
460+
return None
403461
except Exception as e:
404462
logger.error(
405-
f"Error checking embedding dimension: {model_name}; Error: {str(e)}")
406-
return 0
463+
f"Error checking embedding dimension for {model_name}: {str(e)}")
464+
return None

backend/services/model_management_service.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
merge_existing_model_attributes,
2020
get_provider_models,
2121
)
22-
from services.model_health_service import embedding_dimension_check
22+
from services.model_health_service import embedding_dimension_check, _infer_model_factory
2323
from utils.model_name_utils import (
2424
add_repo_to_name,
2525
split_repo_name,
@@ -101,9 +101,23 @@ async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict
101101
raise ValueError(
102102
f"Name {model_data['display_name']} is already in use, please choose another display name")
103103

104-
# If embedding or multi_embedding, set max_tokens via embedding dimension check
104+
# If embedding or multi_embedding, ensure base_url ends with /embeddings
105105
if model_data.get("model_type") in ("embedding", "multi_embedding"):
106-
model_data["max_tokens"] = await embedding_dimension_check(model_data)
106+
base_url = model_data.get("base_url", "")
107+
if base_url and "/embeddings" not in base_url:
108+
model_data["base_url"] = f"{base_url.rstrip('/')}/embeddings"
109+
# Infer model_factory from base_url if not set
110+
model_data["model_factory"] = _infer_model_factory(
111+
model_data["model_type"], model_data["base_url"], model_data.get("model_factory")
112+
)
113+
# Get embedding dimension
114+
dimension = await embedding_dimension_check(model_data)
115+
if dimension is None:
116+
raise ValueError(
117+
f"Failed to get embedding dimension for model '{model_data.get('display_name', model_data.get('model_name'))}'. "
118+
"Please verify the URL, API key, and network connection."
119+
)
120+
model_data["max_tokens"] = dimension
107121
# Set default chunk_batch if not provided
108122
if model_data.get("chunk_batch") is None:
109123
model_data["chunk_batch"] = 10

backend/services/model_provider_service.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,15 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a
127127
# Determine the correct base_url and, for embeddings, update the actual
128128
# dimension by performing a real connectivity check.
129129
if model["model_type"] in ["embedding", "multi_embedding"]:
130-
if provider != ProviderEnum.MODELENGINE.value:
131-
# Ensure proper slash between base URL and endpoint
130+
if provider == ProviderEnum.DASHSCOPE.value and model["model_type"] == "embedding":
132131
model_dict["base_url"] = f"{model_url.rstrip('/')}/embeddings"
133-
else:
134-
# For ModelEngine embedding models, append the embeddings path
132+
elif provider == ProviderEnum.MODELENGINE.value:
135133
model_dict["base_url"] = f"{model_url.rstrip('/')}/{MODEL_ENGINE_NORTH_PREFIX}/embeddings"
136-
# The embedding dimension might differ from the provided max_tokens.
134+
elif "/embeddings" in model_url:
135+
# URL already contains /embeddings endpoint, use as-is
136+
model_dict["base_url"] = model_url.rstrip('/')
137+
else:
138+
model_dict["base_url"] = f"{model_url.rstrip('/')}/embeddings"
137139
model_dict["max_tokens"] = await embedding_dimension_check(model_dict)
138140
elif model["model_type"] == "rerank":
139141
if provider == ProviderEnum.DASHSCOPE.value:

backend/services/vectordatabase_service.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from fastapi import Body, Depends, Path, Query
2222
from fastapi.responses import StreamingResponse
23-
from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding
23+
from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, DashScopeMultimodalEmbedding, BaseEmbedding
2424
from nexent.core.models.rerank_model import OpenAICompatibleRerank, BaseRerank
2525
from nexent.vector_database.base import VectorDatabaseCore
2626
from nexent.vector_database.elasticsearch_core import ElasticSearchCore
@@ -335,6 +335,9 @@ def _create_embedding_model(model: dict) -> Any:
335335
"ssl_verify": model_config.get("ssl_verify", True),
336336
}
337337
if model.get("model_type", "embedding") == "multi_embedding":
338+
model_factory = model.get("model_factory", "").lower()
339+
if model_factory == "dashscope":
340+
return DashScopeMultimodalEmbedding(**common_kwargs)
338341
return JinaEmbedding(**common_kwargs)
339342
return OpenAICompatibleEmbedding(**common_kwargs)
340343

0 commit comments

Comments
 (0)