Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions test/unit/embed/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pydantic import SecretStr

from unstructured_ingest.embed.openai import OpenAIEmbeddingConfig, OpenAIEmbeddingEncoder


Expand Down Expand Up @@ -26,3 +28,49 @@ def test_embed_documents_does_not_break_element_to_dict(mocker):
assert elements[0]["text"] == "This is sentence 1"
assert elements[1]["text"] == "This is sentence 2"
assert mock_client.embeddings.create.call_count == 2


def test_get_client_without_default_headers(mocker):
    """default_headers=None must not pass the kwarg to the OpenAI constructor."""
    # get_client() runs `from openai import DefaultHttpxClient, OpenAI` inside its
    # body, so the names are looked up on the `openai` package at call time.
    # Patching them there is what makes the method pick up the mocks.
    mock_openai = mocker.patch("openai.OpenAI")
    mocker.patch("openai.DefaultHttpxClient")

    config = OpenAIEmbeddingConfig(api_key="key")
    config.get_client()

    # Fail clearly if the mocked constructor was bypassed, rather than erroring
    # on an empty call record.
    mock_openai.assert_called_once()
    assert "default_headers" not in mock_openai.call_args.kwargs


def test_get_client_with_default_headers_extracts_secrets(mocker):
    """default_headers values are unwrapped from SecretStr before reaching the SDK."""
    # get_client() imports OpenAI/DefaultHttpxClient from the `openai` package
    # inside its body, so the mocks must be installed on that package.
    mock_openai = mocker.patch("openai.OpenAI")
    mocker.patch("openai.DefaultHttpxClient")

    config = OpenAIEmbeddingConfig(
        api_key="key",
        default_headers={"X-Custom": SecretStr("token123"), "X-Other": SecretStr("abc")},
    )
    config.get_client()

    # Guard against the mock being bypassed before inspecting its call record.
    mock_openai.assert_called_once()
    forwarded_headers = mock_openai.call_args.kwargs["default_headers"]
    assert forwarded_headers == {"X-Custom": "token123", "X-Other": "abc"}


def test_get_async_client_without_default_headers(mocker):
    """default_headers=None must not pass the kwarg to AsyncOpenAI."""
    # get_async_client() runs `from openai import AsyncOpenAI, DefaultAsyncHttpxClient`
    # inside its body, so the names are looked up on the `openai` package at call
    # time. Patching them there is what makes the method pick up the mocks.
    mock_async_openai = mocker.patch("openai.AsyncOpenAI")
    mocker.patch("openai.DefaultAsyncHttpxClient")

    config = OpenAIEmbeddingConfig(api_key="key")
    config.get_async_client()

    # Fail clearly if the mocked constructor was bypassed, rather than erroring
    # on an empty call record.
    mock_async_openai.assert_called_once()
    assert "default_headers" not in mock_async_openai.call_args.kwargs


def test_get_async_client_with_default_headers_extracts_secrets(mocker):
    """default_headers values are unwrapped from SecretStr in async client path."""
    # get_async_client() imports AsyncOpenAI/DefaultAsyncHttpxClient from the
    # `openai` package inside its body, so the mocks must be installed there.
    mock_async_openai = mocker.patch("openai.AsyncOpenAI")
    mocker.patch("openai.DefaultAsyncHttpxClient")

    config = OpenAIEmbeddingConfig(
        api_key="key",
        default_headers={"Authorization": SecretStr("Bearer tok")},
    )
    config.get_async_client()

    # Guard against the mock being bypassed before inspecting its call record.
    mock_async_openai.assert_called_once()
    forwarded_headers = mock_async_openai.call_args.kwargs["default_headers"]
    assert forwarded_headers == {"Authorization": "Bearer tok"}
30 changes: 24 additions & 6 deletions unstructured_ingest/embed/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class OpenAIEmbeddingConfig(EmbeddingConfig):
default="text-embedding-ada-002", alias="model_name", description="OpenAI model name"
)
base_url: Optional[str] = Field(default=None, description="optional override for the base url")
default_headers: Optional[dict[str, SecretStr]] = Field(
default=None,
description="extra HTTP headers attached to every request",
)

@requires_dependencies(["openai"], extras="openai")
def wrap_error(self, e: Exception) -> Exception:
Expand Down Expand Up @@ -90,18 +94,32 @@ def get_client(self) -> "OpenAI":
from openai import DefaultHttpxClient, OpenAI

client = DefaultHttpxClient(verify=ssl_context_with_optional_ca_override())
return OpenAI(
api_key=self.api_key.get_secret_value(), http_client=client, base_url=self.base_url
)
kwargs = {
"api_key": self.api_key.get_secret_value(),
"http_client": client,
"base_url": self.base_url,
}
if self.default_headers:
kwargs["default_headers"] = {
k: v.get_secret_value() for k, v in self.default_headers.items()
}
return OpenAI(**kwargs)

@requires_dependencies(["openai"], extras="openai")
def get_async_client(self) -> "AsyncOpenAI":
    """Build an ``AsyncOpenAI`` client from this configuration.

    The client receives the unwrapped API key, the configured ``base_url`` and
    an async httpx client whose ``verify`` argument comes from
    ``ssl_context_with_optional_ca_override()``. ``default_headers`` is only
    forwarded when it is set and non-empty, with every value unwrapped via
    ``get_secret_value()``.

    Returns:
        The constructed ``AsyncOpenAI`` instance.
    """
    from openai import AsyncOpenAI, DefaultAsyncHttpxClient

    client = DefaultAsyncHttpxClient(verify=ssl_context_with_optional_ca_override())
    kwargs = {
        "api_key": self.api_key.get_secret_value(),
        "http_client": client,
        "base_url": self.base_url,
    }
    # Omit the kwarg entirely when no headers are configured, instead of
    # passing None or an empty mapping to the constructor.
    if self.default_headers:
        kwargs["default_headers"] = {
            name: secret.get_secret_value() for name, secret in self.default_headers.items()
        }
    return AsyncOpenAI(**kwargs)


@dataclass
Expand Down
Loading