Skip to content

Commit 9d8e083

Browse files
jnMetaCodeSoulter
authored andcommitted
fix: only pass dimensions when explicitly configured in embedding config (AstrBotDevs#6432)
* fix: only pass dimensions param when explicitly configured Models like bge-m3 don't support the dimensions parameter in the embedding API, causing HTTP 400 errors. Previously dimensions was always sent with a default value of 1024, even when the user never configured it. Now dimensions is only included in the request when embedding_dimensions is explicitly set in provider config. Closes AstrBotDevs#6421 Signed-off-by: JiangNan <1394485448@qq.com> * fix: handle invalid dimensions config and align get_dim return - Add try-except around int() conversion in _embedding_kwargs to gracefully handle invalid embedding_dimensions config values - Update get_dim() to return 0 when embedding_dimensions is not explicitly configured, so callers know dimensions weren't specified and can handle it accordingly - Both methods now share consistent logic for reading the config Signed-off-by: JiangNan <1394485448@qq.com> * fix: improve logging for invalid embedding_dimensions configuration --------- Signed-off-by: JiangNan <1394485448@qq.com> Co-authored-by: Soulter <905617992@qq.com>
1 parent 9ffb0bd commit 9d8e083

1 file changed

Lines changed: 24 additions & 3 deletions

File tree

astrbot/core/provider/sources/openai_embedding_source.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,25 +40,46 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
4040

4141
async def get_embedding(self, text: str) -> list[float]:
4242
"""获取文本的嵌入"""
43+
kwargs = self._embedding_kwargs()
4344
embedding = await self.client.embeddings.create(
4445
input=text,
4546
model=self.model,
46-
dimensions=self.get_dim(),
47+
**kwargs,
4748
)
4849
return embedding.data[0].embedding
4950

5051
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
5152
"""批量获取文本的嵌入"""
53+
kwargs = self._embedding_kwargs()
5254
embeddings = await self.client.embeddings.create(
5355
input=text,
5456
model=self.model,
55-
dimensions=self.get_dim(),
57+
**kwargs,
5658
)
5759
return [item.embedding for item in embeddings.data]
5860

61+
def _embedding_kwargs(self) -> dict:
62+
"""构建嵌入请求的可选参数"""
63+
kwargs = {}
64+
if "embedding_dimensions" in self.provider_config:
65+
try:
66+
kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"])
67+
except (ValueError, TypeError):
68+
logger.warning(
69+
f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored."
70+
)
71+
return kwargs
72+
5973
def get_dim(self) -> int:
6074
"""获取向量的维度"""
61-
return int(self.provider_config.get("embedding_dimensions", 1024))
75+
if "embedding_dimensions" in self.provider_config:
76+
try:
77+
return int(self.provider_config["embedding_dimensions"])
78+
except (ValueError, TypeError):
79+
logger.warning(
80+
f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored."
81+
)
82+
return 0
6283

6384
async def terminate(self):
6485
if self.client:

0 commit comments

Comments
 (0)