Skip to content

Commit c81ef26

Browse files
authored
fix: pass embedding dimensions to provider apis (#5411)
1 parent a5ae27c commit c81ef26

2 files changed

Lines changed: 16 additions & 2 deletions

File tree

astrbot/core/provider/sources/gemini_embedding_source.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ async def get_embedding(self, text: str) -> list[float]:
4848
result = await self.client.models.embed_content(
4949
model=self.model,
5050
contents=text,
51+
config=types.EmbedContentConfig(
52+
output_dimensionality=self.get_dim(),
53+
),
5154
)
5255
assert result.embeddings is not None
5356
assert result.embeddings[0].values is not None
@@ -61,6 +64,9 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]:
6164
result = await self.client.models.embed_content(
6265
model=self.model,
6366
contents=cast(types.ContentListUnion, text),
67+
config=types.EmbedContentConfig(
68+
output_dimensionality=self.get_dim(),
69+
),
6470
)
6571
assert result.embeddings is not None
6672

astrbot/core/provider/sources/openai_embedding_source.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,20 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
3636

3737
async def get_embedding(self, text: str) -> list[float]:
3838
"""获取文本的嵌入"""
39-
embedding = await self.client.embeddings.create(input=text, model=self.model)
39+
embedding = await self.client.embeddings.create(
40+
input=text,
41+
model=self.model,
42+
dimensions=self.get_dim(),
43+
)
4044
return embedding.data[0].embedding
4145

4246
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
4347
"""批量获取文本的嵌入"""
44-
embeddings = await self.client.embeddings.create(input=text, model=self.model)
48+
embeddings = await self.client.embeddings.create(
49+
input=text,
50+
model=self.model,
51+
dimensions=self.get_dim(),
52+
)
4553
return [item.embedding for item in embeddings.data]
4654

4755
def get_dim(self) -> int:

0 commit comments

Comments
 (0)