-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
fix: append /v1 for OpenAI embedding api base #6910
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| from urllib.parse import urlsplit, urlunsplit | ||
|
|
||
| import httpx | ||
| from openai import AsyncOpenAI | ||
|
|
||
|
|
@@ -14,6 +16,29 @@ | |
| provider_type=ProviderType.EMBEDDING, | ||
| ) | ||
| class OpenAIEmbeddingProvider(EmbeddingProvider): | ||
| DEFAULT_EMBEDDING_API_BASE = "https://api.openai.com/v1" | ||
|
|
||
| @staticmethod | ||
| def _normalize_embedding_api_base(api_base: str) -> str: | ||
| """Normalize root-style embedding base URLs while avoiding path-specific ones. | ||
|
|
||
| Auto-append ``/v1`` only for host roots or single-segment paths such as | ||
| ``https://example.com`` or ``https://example.com/openai``. More specific | ||
| paths (for example ``/v1-beta`` or ``/v1/embeddings``) are preserved as-is. | ||
| """ | ||
| normalized_api_base = api_base.rstrip("/") | ||
| parsed = urlsplit(normalized_api_base) | ||
| path_segments = [segment for segment in parsed.path.split("/") if segment] | ||
| has_version_segment = any( | ||
| len(segment) > 1 and segment.startswith("v") and segment[1].isdigit() | ||
| for segment in path_segments | ||
| ) | ||
| if has_version_segment or len(path_segments) > 1: | ||
| return normalized_api_base | ||
|
|
||
| normalized_path = f"{parsed.path.rstrip('/')}/v1" if parsed.path else "/v1" | ||
| return urlunsplit(parsed._replace(path=normalized_path)) | ||
|
|
||
| def __init__(self, provider_config: dict, provider_settings: dict) -> None: | ||
| super().__init__(provider_config, provider_settings) | ||
| self.provider_config = provider_config | ||
|
|
@@ -25,8 +50,12 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: | |
| logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}") | ||
| http_client = httpx.AsyncClient(proxy=proxy) | ||
| api_base = provider_config.get( | ||
| "embedding_api_base", "https://api.openai.com/v1" | ||
| "embedding_api_base", self.DEFAULT_EMBEDDING_API_BASE | ||
| ).strip() | ||
| if api_base: | ||
| api_base = self._normalize_embedding_api_base(api_base) | ||
|
Comment on lines +55 to +56
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
While the current logic correctly normalizes a non-empty `api_base` and falls back to the default when it is blank, the `if`/`else` can be collapsed into a single call:
`api_base = self._normalize_embedding_api_base(api_base or "https://api.openai.com/v1")` |
||
| else: | ||
| api_base = self.DEFAULT_EMBEDDING_API_BASE | ||
| logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}") | ||
| self.client = AsyncOpenAI( | ||
| api_key=provider_config.get("embedding_api_key"), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,9 @@ | |
| from openai.types.chat.chat_completion import ChatCompletion | ||
|
|
||
| from astrbot.core.provider.sources.groq_source import ProviderGroq | ||
| from astrbot.core.provider.sources.openai_embedding_source import ( | ||
| OpenAIEmbeddingProvider, | ||
| ) | ||
| from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial | ||
|
|
||
|
|
||
|
|
@@ -49,6 +52,20 @@ def _make_groq_provider(overrides: dict | None = None) -> ProviderGroq: | |
| ) | ||
|
|
||
|
|
||
| def _make_embedding_provider(overrides: dict | None = None) -> OpenAIEmbeddingProvider: | ||
| provider_config = { | ||
| "id": "test-openai-embedding", | ||
| "type": "openai_embedding", | ||
| "embedding_api_key": "test-key", | ||
| } | ||
| if overrides: | ||
| provider_config.update(overrides) | ||
| return OpenAIEmbeddingProvider( | ||
| provider_config=provider_config, | ||
| provider_settings={}, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_handle_api_error_content_moderated_removes_images(): | ||
| provider = _make_provider( | ||
|
Comment on lines 69 to 71
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
suggestion (testing): Cover the edge case where `embedding_api_base` has a trailing slash.
Given the current normalization (trailing slashes are stripped before `/v1` is appended), a base such as `https://example.com/openai/` should produce exactly one slash before `v1`. Suggested implementation:
```python
    finally:
        await provider.terminate()


def test_embedding_api_base_trailing_slash_normalized():
    provider = _make_provider(
        overrides={"embedding_api_base": "https://example.com/openai/"}
    )
    # The provider should normalize the embedding API base by removing any
    # trailing slash and then appending `/v1`, resulting in a single slash.
    # This asserts we do *not* end up with `https://example.com/openai//v1/`.
    base_url = str(provider.client._client.base_url)
    assert base_url == "https://example.com/openai/v1/"
```
Depending on how the provider exposes its OpenAI client, the `base_url` attribute access above may need to be adjusted.
The key behavior to lock in is that a trailing slash in `embedding_api_base` never results in a double slash before `/v1`. |
||
|
|
@@ -234,7 +251,9 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history(): | |
| provider._finally_convert_payload(payloads) | ||
|
|
||
| assistant_message = payloads["messages"][0] | ||
| assert assistant_message["content"] == [{"type": "text", "text": "final answer"}] | ||
| assert assistant_message["content"] == [ | ||
| {"type": "text", "text": "final answer"} | ||
| ] | ||
| assert assistant_message["reasoning_content"] == "step 1" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
@@ -259,7 +278,9 @@ async def test_groq_payload_drops_reasoning_content_from_assistant_history(): | |
| provider._finally_convert_payload(payloads) | ||
|
|
||
| assistant_message = payloads["messages"][0] | ||
| assert assistant_message["content"] == [{"type": "text", "text": "final answer"}] | ||
| assert assistant_message["content"] == [ | ||
| {"type": "text", "text": "final answer"} | ||
| ] | ||
| assert "reasoning_content" not in assistant_message | ||
| assert "reasoning" not in assistant_message | ||
| finally: | ||
|
|
@@ -533,3 +554,60 @@ async def fake_create(**kwargs): | |
| assert extra_body["temperature"] == 0.1 | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_appends_v1_to_base_url_when_missing(): | ||
| provider = _make_embedding_provider( | ||
| {"embedding_api_base": "https://example.com/openai"} | ||
| ) | ||
| try: | ||
| assert str(provider.client.base_url) == "https://example.com/openai/v1/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_preserves_existing_v1_suffix(): | ||
| provider = _make_embedding_provider( | ||
| {"embedding_api_base": "https://example.com/openai/v1/"} | ||
| ) | ||
| try: | ||
| assert str(provider.client.base_url) == "https://example.com/openai/v1/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_normalizes_trailing_slash_without_double_slash(): | ||
| provider = _make_embedding_provider( | ||
| {"embedding_api_base": "https://example.com/openai/"} | ||
| ) | ||
| try: | ||
| assert str(provider.client.base_url) == "https://example.com/openai/v1/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_falls_back_to_default_base_for_blank_config(): | ||
| provider = _make_embedding_provider({"embedding_api_base": " "}) | ||
| try: | ||
| assert str(provider.client.base_url) == "https://api.openai.com/v1/" | ||
| finally: | ||
| await provider.terminate() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_openai_embedding_provider_preserves_versioned_or_specific_paths(): | ||
| base_urls = [ | ||
| "https://example.com/v1-beta", | ||
| "https://example.com/v1/embeddings", | ||
| ] | ||
|
|
||
| for base_url in base_urls: | ||
| provider = _make_embedding_provider({"embedding_api_base": base_url}) | ||
| try: | ||
| assert str(provider.client.base_url) == f"{base_url.rstrip('/')}/" | ||
| finally: | ||
| await provider.terminate() | ||
Uh oh!
There was an error while loading. Please reload this page.