Skip to content
255 changes: 197 additions & 58 deletions astrbot/core/config/default.py

Large diffs are not rendered by default.

28 changes: 23 additions & 5 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ def dynamic_import_provider(self, type: str) -> None:
)
case "longcat_chat_completion":
from .sources.longcat_source import ProviderLongCat as ProviderLongCat
case "minimax_token_plan":
from .sources.minimax_token_plan_source import (
ProviderMiniMaxTokenPlan as ProviderMiniMaxTokenPlan,
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
Expand Down Expand Up @@ -465,6 +469,18 @@ def dynamic_import_provider(self, type: str) -> None:
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
case "nvidia_embedding":
from .sources.nvidia_embedding_source import (
NvidiaEmbeddingProvider as NvidiaEmbeddingProvider,
)
case "ollama_embedding":
from .sources.ollama_embedding_source import (
OllamaEmbeddingProvider as OllamaEmbeddingProvider,
)
case "vllm_embedding":
from .sources.vllm_embedding_source import (
VLLMEmbeddingProvider as VLLMEmbeddingProvider,
)
case "vllm_rerank":
from .sources.vllm_rerank_source import (
VLLMRerankProvider as VLLMRerankProvider,
Expand Down Expand Up @@ -566,7 +582,9 @@ async def load_provider(self, provider_config: dict) -> None:
return

logger.info(
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...",
"Loading model %s(%s) ...",
provider_config["type"],
provider_config["id"],
)

# 动态导入
Expand All @@ -587,7 +605,7 @@ async def load_provider(self, provider_config: dict) -> None:

if provider_config["type"] not in provider_cls_map:
logger.error(
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。",
f"Provider adapter not found: {provider_config['type']}({provider_config['id']}). Skipped.",
exc_info=True,
)
return
Expand Down Expand Up @@ -621,7 +639,7 @@ async def load_provider(self, provider_config: dict) -> None:
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
f"Selected {provider_config['type']}({provider_config['id']}) as default STT provider",
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
Expand All @@ -644,7 +662,7 @@ async def load_provider(self, provider_config: dict) -> None:
):
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
f"Selected {provider_config['type']}({provider_config['id']}) as default TTS provider",
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
Expand All @@ -670,7 +688,7 @@ async def load_provider(self, provider_config: dict) -> None:
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
f"Selected {provider_config['type']}({provider_config['id']}) as default chat model provider",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
Expand Down
46 changes: 46 additions & 0 deletions astrbot/core/provider/sources/embedding_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from typing import Any

from astrbot import logger


# Output vector sizes for embedding models, keyed by a lowercase fragment of
# the model name. infer_embedding_dimension_from_model() matches these keys as
# substrings of the configured model name, in insertion order.
COMMON_MODEL_DIMENSIONS: dict[str, int] = {
    "bge-m3": 1024,
    "bge-large-en-v1.5": 1024,
    "bge-large-zh-v1.5": 1024,
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
}


def parse_configured_embedding_dimension(
raw_dimension: Any,
*,
provider_label: str,
provider_id: str,
) -> int | None:
if raw_dimension in (None, ""):
return None

try:
dimension = int(raw_dimension)
except (TypeError, ValueError):
logger.warning(
"[%s] %s 的 embedding_dimensions 不是有效整数: %r",
provider_label,
provider_id,
raw_dimension,
)
return None

return dimension if dimension > 0 else None


def infer_embedding_dimension_from_model(model_name: Any) -> int | None:
    """Look up a known vector size from an embedding model name.

    The name is stringified, stripped and lowercased, then compared against
    the keys of ``COMMON_MODEL_DIMENSIONS`` by substring containment.

    Returns:
        The dimension of the first key contained in the name, or ``None``
        when no key matches (including an empty or missing name).
    """
    candidate = str(model_name or "").strip().lower()
    return next(
        (
            size
            for known_name, size in COMMON_MODEL_DIMENSIONS.items()
            if known_name in candidate
        ),
        None,
    )
179 changes: 155 additions & 24 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from ..entities import ProviderType
from ..provider import EmbeddingProvider
from ..register import register_provider_adapter
from .embedding_utils import (
infer_embedding_dimension_from_model,
parse_configured_embedding_dimension,
)


@register_provider_adapter(
Expand All @@ -18,12 +22,14 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config, provider_settings)
self.provider_config = provider_config
self.provider_settings = provider_settings

proxy = provider_config.get("proxy", "")
provider_id = provider_config.get("id", "unknown_id")
http_client = None
if proxy:
logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}")
http_client = httpx.AsyncClient(proxy=proxy)

api_base = (
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
.strip()
Expand All @@ -33,56 +39,181 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"):
# /v4 see #5699
api_base = api_base + "/v1"

# [新增] 保存处理后的 api_base 并转换为小写,用于后续特征比对
self.api_base_normalized = api_base.lower()

logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}")

self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
base_url=api_base,
timeout=int(provider_config.get("timeout", 20)),
http_client=http_client,
)
self.model = provider_config.get("embedding_model", "text-embedding-3-small")

# [新增] 运行时状态标记:一旦触发 400 错误将此设为 True
self._is_vllm_detected = False

def _is_vllm(self) -> bool:
    """Return whether this provider should be treated as a vLLM backend.

    Callers use the result to decide whether the ``dimensions`` request
    parameter must be omitted. Checks run from strongest to weakest signal:

    1. A flag set at runtime after a request failed on ``dimensions``.
    2. The configured API key being the literal string ``"vllm"``.
    3. The provider id or the normalized API base containing ``"vllm"``.

    Returns:
        ``True`` if any signal matches, otherwise ``False``.
    """
    # 1. Runtime evidence wins over any static hint.
    if self._is_vllm_detected:
        return True

    # 2. An API key of "vllm" is an explicit opt-in. The key may be missing,
    #    null or a non-string value, so check the type before lowercasing.
    api_key = self.provider_config.get("embedding_api_key")
    if isinstance(api_key, str) and api_key.lower() == "vllm":
        logger.info("[OpenAI Embedding] vLLM mode enabled by API Key 'vllm'.")
        return True

    # 3. Fall back to an explicit "vllm" marker in the id or base URL.
    #    `.get("id", "")` would return None for an explicit null id, so
    #    coerce to a string first. The API base is already lowercased when
    #    it is stored in __init__.
    provider_id = str(self.provider_config.get("id") or "").lower()
    api_base = self.api_base_normalized
    if "vllm" in provider_id or "vllm" in api_base:
        logger.info(f"[OpenAI Embedding] Detected vLLM by id/api_base: {provider_id}")
        return True

    # 4. Ports such as 8000/8001 are deliberately not used as a signal, to
    #    avoid misclassifying other OpenAI-compatible services.
    return False

def _mark_as_vllm(self) -> None:
    """Flag this instance as vLLM after a runtime error revealed it.

    Sets ``_is_vllm_detected`` so later ``_is_vllm()`` calls return ``True``
    immediately, then logs the detection.
    """
    self._is_vllm_detected = True
    logger.info("[OpenAI Embedding] Marked as vLLM (runtime detection via error)")

async def get_embedding(self, text: str) -> list[float]:
    """Return the embedding vector for a single text.

    The request goes through ``_request_with_vllm_retry`` only. A direct
    ``embeddings.create`` call here would send the request twice and skip
    the retry-without-``dimensions`` fallback.

    Args:
        text: The text to embed.

    Returns:
        The embedding of the first (and only) item in the response.
    """
    kwargs = self._embedding_kwargs()
    response = await self._request_with_vllm_retry(text, kwargs, batch=False)
    return response.data[0].embedding

async def get_embeddings(self, text: list[str]) -> list[list[float]]:
    """Return embedding vectors for a batch of texts.

    The request goes through ``_request_with_vllm_retry`` only. A direct
    ``embeddings.create`` call here would send the batch twice and skip
    the retry-without-``dimensions`` fallback.

    Args:
        text: The texts to embed.

    Returns:
        One embedding per item in the response data, in response order.
    """
    kwargs = self._embedding_kwargs()
    response = await self._request_with_vllm_retry(text, kwargs, batch=True)
    return [item.embedding for item in response.data]

async def _request_with_vllm_retry(
    self,
    input_data: str | list[str],
    kwargs: dict,
    *,
    batch: bool,
):
    """Send an embeddings request, retrying once without ``dimensions``.

    The first attempt uses ``kwargs`` unchanged. If it fails and
    ``_should_retry_without_dimensions`` accepts the error, the request is
    repeated with the ``dimensions`` key removed. When the retry succeeds
    the instance is marked as vLLM.

    Args:
        input_data: A single text or a list of texts.
        kwargs: Optional request parameters for ``embeddings.create``.
        batch: Only selects which log messages are emitted.

    Returns:
        The response object returned by ``self.client.embeddings.create``.

    Raises:
        Exception: The original error when no retry applies, or the retry's
            own error when the second attempt also fails.
    """
    try:
        return await self.client.embeddings.create(
            input=input_data,
            model=self.model,
            **kwargs,
        )
    except Exception as exc:
        if not self._should_retry_without_dimensions(exc, kwargs):
            raise

        logger.warning(
            f"[OpenAI Embedding] Detected vLLM dimensions error in batch mode, retrying without dimensions: {exc}"
            if batch
            else f"[OpenAI Embedding] Detected vLLM dimensions error, retrying without dimensions parameter: {exc}"
        )

        # Copy so the caller's dict is left untouched.
        kwargs_retry = dict(kwargs)
        kwargs_retry.pop("dimensions", None)
        try:
            embeddings = await self.client.embeddings.create(
                input=input_data,
                model=self.model,
                **kwargs_retry,
            )
        except Exception as retry_error:
            logger.error(
                f"[OpenAI Embedding] Batch retry without dimensions also failed: {retry_error}"
                if batch
                else f"[OpenAI Embedding] Retry without dimensions also failed: {retry_error}"
            )
            raise

        logger.info(
            "[OpenAI Embedding] Successfully retrieved batch embeddings without dimensions parameter"
            if batch
            else "[OpenAI Embedding] Successfully retrieved embedding without dimensions parameter, marking as vLLM"
        )

        self._mark_as_vllm()
        return embeddings

def _should_retry_without_dimensions(self, exc: Exception, kwargs: dict) -> bool:
if not kwargs.get("dimensions"):
return False

error_msg = str(exc).lower()
return "matryoshka" in error_msg or "dimensions" in error_msg

def _configured_dimension(self) -> int | None:
    """Return the validated ``embedding_dimensions`` setting, if any.

    Delegates parsing to ``parse_configured_embedding_dimension``, passing
    the provider id (default ``"unknown"``) for use in its warning.
    """
    config = self.provider_config
    return parse_configured_embedding_dimension(
        config.get("embedding_dimensions", ""),
        provider_label="OpenAI Embedding",
        provider_id=config.get("id", "unknown"),
    )

def _embedding_kwargs(self) -> dict:
    """Build the optional keyword arguments for an embeddings request.

    Returns:
        An empty dict when the provider is treated as vLLM or no valid
        dimension is configured; otherwise ``{"dimensions": <int>}``.
    """
    kwargs = {}
    provider_id = self.provider_config.get("id", "unknown")
    embedding_dim_config = self.provider_config.get("embedding_dimensions", "")

    # The vLLM check must run before anything writes to kwargs. If
    # "dimensions" were set first, this early return would still send it.
    if self._is_vllm():
        logger.info(
            f"[OpenAI Embedding] {provider_id}: Detected vLLM, skipping dimensions parameter (config value: '{embedding_dim_config}')"
        )
        return kwargs

    # Other services receive the validated dimension from the configuration.
    configured_dim = self._configured_dimension()
    if configured_dim is not None:
        kwargs["dimensions"] = configured_dim
        logger.info(
            f"[OpenAI Embedding] {provider_id}: Added dimensions parameter: {configured_dim}"
        )
    elif embedding_dim_config in (None, ""):
        logger.info(
            f"[OpenAI Embedding] {provider_id}: No embedding_dimensions configured, API will use default"
        )
    return kwargs

def get_dim(self) -> int:
    """Return the embedding vector dimension for this provider.

    Resolution order:

    1. A valid, positive ``embedding_dimensions`` value from the config.
    2. A dimension inferred from the model name.
    3. ``0`` when neither is available (a warning is logged).

    Returns:
        The resolved dimension, or ``0`` if it could not be determined.
    """
    provider_id = self.provider_config.get("id", "unknown")
    embedding_dim_config = self.provider_config.get("embedding_dimensions", "")

    # Use only the validated value. A raw int() of the config entry would
    # let 0 or negative numbers through as the reported dimension.
    configured_dim = self._configured_dimension()
    if configured_dim is not None:
        logger.info(
            f"[OpenAI Embedding] {provider_id}: Dimension from config: {configured_dim}"
        )
        return configured_dim

    # Infer from the model that requests actually send (self.model, which
    # has a default when "embedding_model" is absent), not from the raw
    # config entry, which would be empty in that case.
    model = self.model
    inferred_dim = infer_embedding_dimension_from_model(model)
    if inferred_dim:
        logger.info(
            f"[OpenAI Embedding] {provider_id}: Inferred dimension {inferred_dim} from model: {str(model).lower()}"
        )
        return inferred_dim

    logger.warning(
        f"[OpenAI Embedding] {provider_id}: Could not determine dimension (model: {str(model).lower()}, config: '{embedding_dim_config}')"
    )
    return 0

async def terminate(self):
Expand Down
Loading