Skip to content

Commit 64ae120

Browse files
committed
fix(embedding): 自动检测改为探测 OpenAI embedding 最大可用维度
1 parent b13bf36 commit 64ae120

4 files changed

Lines changed: 87 additions & 3 deletions

File tree

astrbot/core/provider/provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,10 @@ def get_dim(self) -> int:
305305
"""获取向量的维度"""
306306
...
307307

308+
async def detect_dim(self) -> int:
    """Probe the model's native embedding dimension (default implementation).

    Embeds the fixed sample text ``"astrbot"`` via ``get_embedding`` and
    reports the length of the returned vector.

    Returns:
        The number of components in the embedding returned for the sample.
    """
    sample_vector = await self.get_embedding("astrbot")
    return len(sample_vector)
311+
308312
async def test(self) -> None:
    """Check the provider by requesting one embedding for the text "astrbot".

    The returned vector is discarded; any exception raised by
    ``get_embedding`` propagates to the caller.
    """
    await self.get_embedding("astrbot")
310314

astrbot/core/provider/sources/gemini_embedding_source.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,19 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
7878
except APIError as e:
7979
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
8080

81+
async def detect_dim(self) -> int:
    """Probe the model's native embedding dimension.

    The request deliberately omits ``output_dimensionality`` so the size
    reported is the one the model produces by default.

    Returns:
        The length of the first embedding vector in the response.

    Raises:
        Exception: If the API call fails with ``APIError``, or if the
            response carries no embedding values.
    """
    # Keep the try body to the single call that can raise APIError.
    try:
        result = await self.client.models.embed_content(
            model=self.model,
            contents="echo",
        )
    except APIError as e:
        raise Exception(f"Gemini Embedding 维度探测失败: {e.message}") from e

    # Explicit checks instead of `assert`: assertions are removed under
    # `python -O`, and an empty list would otherwise surface as IndexError.
    embeddings = result.embeddings
    if not embeddings or embeddings[0].values is None:
        raise Exception("Gemini Embedding 维度探测失败: 响应中未包含向量数据")
    return len(embeddings[0].values)
93+
8194
def get_dim(self) -> int:
    """Return the configured embedding dimension.

    Reads ``embedding_dimensions`` from ``provider_config`` and falls back
    to 768 when the key is absent; the value is coerced to ``int``.
    """
    configured = self.provider_config.get("embedding_dimensions", 768)
    return int(configured)

astrbot/core/provider/sources/openai_embedding_source.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,64 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
5252
)
5353
return [item.embedding for item in embeddings.data]
5454

55+
async def detect_dim(self) -> int:
56+
"""探测模型可用的最大向量维度"""
57+
58+
async def _request_dim(dimensions: int | None) -> int:
59+
kwargs = {
60+
"input": "echo",
61+
"model": self.model,
62+
}
63+
if dimensions is not None:
64+
kwargs["dimensions"] = dimensions
65+
embedding = await self.client.embeddings.create(**kwargs)
66+
return len(embedding.data[0].embedding)
67+
68+
# 1) 默认调用,获取当前默认维度
69+
base_dim = await _request_dim(None)
70+
71+
# 2) 先判断 dimensions 参数是否可调
72+
probe_dim = base_dim + 1
73+
try:
74+
probe_result = await _request_dim(probe_dim)
75+
if probe_result != probe_dim:
76+
return base_dim
77+
except Exception:
78+
return base_dim
79+
80+
# 3) 可调时探测上界:指数扩张 + 二分
81+
max_cap = 32768
82+
low = probe_dim
83+
high = max(base_dim * 2, probe_dim + 1)
84+
if high > max_cap:
85+
high = max_cap
86+
87+
while high < max_cap:
88+
try:
89+
result_dim = await _request_dim(high)
90+
if result_dim != high:
91+
break
92+
low = high
93+
high = min(high * 2, max_cap)
94+
except Exception:
95+
break
96+
97+
left = low + 1
98+
right = high - 1
99+
while left <= right:
100+
mid = (left + right) // 2
101+
try:
102+
result_dim = await _request_dim(mid)
103+
if result_dim == mid:
104+
low = mid
105+
left = mid + 1
106+
else:
107+
right = mid - 1
108+
except Exception:
109+
right = mid - 1
110+
111+
return low
112+
55113
def get_dim(self) -> int:
    """Return the configured embedding dimension.

    Reads ``embedding_dimensions`` from ``provider_config`` and falls back
    to 1024 when the key is absent; the value is coerced to ``int``.
    """
    configured = self.provider_config.get("embedding_dimensions", 1024)
    return int(configured)

astrbot/dashboard/routes/config.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,16 @@ async def get_embedding_dim(self):
754754
if not provider_type:
755755
return Response().error("provider_config 缺少 type 字段").__dict__
756756

757+
# 首次添加某类提供商时,provider_cls_map 可能尚未注册该适配器
758+
if provider_type not in provider_cls_map:
759+
try:
760+
self.core_lifecycle.provider_manager.dynamic_import_provider(
761+
provider_type,
762+
)
763+
except ImportError as e:
764+
logger.error(traceback.format_exc())
765+
return Response().error(f"动态导入提供商适配器失败: {e!s}").__dict__
766+
757767
# 获取对应的 provider 类
758768
if provider_type not in provider_cls_map:
759769
return (
@@ -779,9 +789,8 @@ async def get_embedding_dim(self):
779789
if inspect.iscoroutinefunction(init_fn):
780790
await init_fn()
781791

782-
# 获取嵌入向量维度
783-
vec = await inst.get_embedding("echo")
784-
dim = len(vec)
792+
# 探测嵌入向量维度(优先使用 provider 的原生探测逻辑)
793+
dim = await inst.detect_dim()
785794

786795
logger.info(
787796
f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}",

0 commit comments

Comments
 (0)