Skip to content

Commit 690b184

Browse files
authored
fix: preserve embedding api version suffixes (#8736)
1 parent f19f623 commit 690b184

2 files changed

Lines changed: 28 additions & 7 deletions

File tree

astrbot/core/provider/sources/openai_embedding_source.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
import httpx
24
from openai import AsyncOpenAI
35

@@ -8,6 +10,13 @@
810
from ..register import register_provider_adapter
911

1012

13+
def _normalize_api_base(api_base: str) -> str:
14+
api_base = api_base.strip().removesuffix("/").removesuffix("/embeddings")
15+
if api_base and not re.search(r"/v\d+$", api_base):
16+
api_base = api_base + "/v1"
17+
return api_base
18+
19+
1120
@register_provider_adapter(
1221
"openai_embedding",
1322
"OpenAI API Embedding 提供商适配器",
@@ -24,15 +33,9 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
2433
if proxy:
2534
logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}")
2635
http_client = httpx.AsyncClient(proxy=proxy)
27-
api_base = (
36+
api_base = _normalize_api_base(
2837
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
29-
.strip()
30-
.removesuffix("/")
31-
.removesuffix("/embeddings")
3238
)
33-
if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"):
34-
# /v4 see #5699
35-
api_base = api_base + "/v1"
3639
logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}")
3740
self.client = AsyncOpenAI(
3841
api_key=provider_config.get("embedding_api_key"),
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from astrbot.core.provider.sources.openai_embedding_source import _normalize_api_base
2+
3+
4+
def test_openai_embedding_api_base_keeps_version_suffixes():
5+
assert (
6+
_normalize_api_base("https://ark.cn-beijing.volces.com/api/plan/v3")
7+
== "https://ark.cn-beijing.volces.com/api/plan/v3"
8+
)
9+
assert _normalize_api_base("https://example.test/v4") == "https://example.test/v4"
10+
11+
12+
def test_openai_embedding_api_base_adds_default_version():
13+
assert _normalize_api_base("https://example.test/openai") == (
14+
"https://example.test/openai/v1"
15+
)
16+
assert _normalize_api_base("https://example.test/v1/embeddings") == (
17+
"https://example.test/v1"
18+
)

0 commit comments

Comments
 (0)