Skip to content

Commit 3ccc4cb

Browse files
committed
fix: append /v1 for openai embedding api base
1 parent c6f4dd1 commit 3ccc4cb

2 files changed

Lines changed: 54 additions & 2 deletions

File tree

astrbot/core/provider/sources/openai_embedding_source.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
provider_type=ProviderType.EMBEDDING,
1515
)
1616
class OpenAIEmbeddingProvider(EmbeddingProvider):
17+
@staticmethod
18+
def _normalize_embedding_api_base(api_base: str) -> str:
19+
normalized_api_base = api_base.rstrip("/")
20+
if not normalized_api_base.endswith("/v1"):
21+
normalized_api_base = f"{normalized_api_base}/v1"
22+
return normalized_api_base
23+
1724
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
1825
super().__init__(provider_config, provider_settings)
1926
self.provider_config = provider_config
@@ -27,6 +34,8 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
2734
api_base = provider_config.get(
2835
"embedding_api_base", "https://api.openai.com/v1"
2936
).strip()
37+
if api_base:
38+
api_base = self._normalize_embedding_api_base(api_base)
3039
logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}")
3140
self.client = AsyncOpenAI(
3241
api_key=provider_config.get("embedding_api_key"),

tests/test_openai_source.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from openai.types.chat.chat_completion import ChatCompletion
55

66
from astrbot.core.provider.sources.groq_source import ProviderGroq
7+
from astrbot.core.provider.sources.openai_embedding_source import (
8+
OpenAIEmbeddingProvider,
9+
)
710
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
811

912

@@ -49,6 +52,20 @@ def _make_groq_provider(overrides: dict | None = None) -> ProviderGroq:
4952
)
5053

5154

55+
def _make_embedding_provider(overrides: dict | None = None) -> OpenAIEmbeddingProvider:
56+
provider_config = {
57+
"id": "test-openai-embedding",
58+
"type": "openai_embedding",
59+
"embedding_api_key": "test-key",
60+
}
61+
if overrides:
62+
provider_config.update(overrides)
63+
return OpenAIEmbeddingProvider(
64+
provider_config=provider_config,
65+
provider_settings={},
66+
)
67+
68+
5269
@pytest.mark.asyncio
5370
async def test_handle_api_error_content_moderated_removes_images():
5471
provider = _make_provider(
@@ -234,7 +251,9 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history():
234251
provider._finally_convert_payload(payloads)
235252

236253
assistant_message = payloads["messages"][0]
237-
assert assistant_message["content"] == [{"type": "text", "text": "final answer"}]
254+
assert assistant_message["content"] == [
255+
{"type": "text", "text": "final answer"}
256+
]
238257
assert assistant_message["reasoning_content"] == "step 1"
239258
finally:
240259
await provider.terminate()
@@ -259,7 +278,9 @@ async def test_groq_payload_drops_reasoning_content_from_assistant_history():
259278
provider._finally_convert_payload(payloads)
260279

261280
assistant_message = payloads["messages"][0]
262-
assert assistant_message["content"] == [{"type": "text", "text": "final answer"}]
281+
assert assistant_message["content"] == [
282+
{"type": "text", "text": "final answer"}
283+
]
263284
assert "reasoning_content" not in assistant_message
264285
assert "reasoning" not in assistant_message
265286
finally:
@@ -533,3 +554,25 @@ async def fake_create(**kwargs):
533554
assert extra_body["temperature"] == 0.1
534555
finally:
535556
await provider.terminate()
557+
558+
559+
@pytest.mark.asyncio
560+
async def test_openai_embedding_provider_appends_v1_to_base_url_when_missing():
561+
provider = _make_embedding_provider(
562+
{"embedding_api_base": "https://example.com/openai"}
563+
)
564+
try:
565+
assert str(provider.client.base_url) == "https://example.com/openai/v1/"
566+
finally:
567+
await provider.terminate()
568+
569+
570+
@pytest.mark.asyncio
571+
async def test_openai_embedding_provider_preserves_existing_v1_suffix():
572+
provider = _make_embedding_provider(
573+
{"embedding_api_base": "https://example.com/openai/v1/"}
574+
)
575+
try:
576+
assert str(provider.client.base_url) == "https://example.com/openai/v1/"
577+
finally:
578+
await provider.terminate()

0 commit comments

Comments
 (0)