Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 115 additions & 44 deletions sdk/ai/azure-ai-projects/azure/ai/projects/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,86 @@
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Shared helpers used by both the sync and async AIProjectClient.get_openai_client()
# implementations. Defined at module level so the async client can import and reuse
# them without duplicating the logic.
# ---------------------------------------------------------------------------


def _resolve_openai_base_url(config: Any, agent_name: Optional[str], kwargs: dict) -> str:
"""Resolve the base URL for the (Async)OpenAI client.

:param config: Generated client configuration carrying ``endpoint`` and ``allow_preview``.
:type config: Any
:param agent_name: Optional hosted-agent name.
:type agent_name: str or None
:param kwargs: Caller keyword arguments; ``base_url`` is popped when present.
:type kwargs: dict
:return: The base URL to use for the (Async)OpenAI client.
:rtype: str
:raises ValueError: If ``agent_name`` is provided but ``allow_preview=True`` was not set.
"""
if "base_url" in kwargs:
return kwargs.pop("base_url")
if agent_name is not None:
if config.allow_preview:
return config.endpoint.rstrip("/") + f"/agents/{agent_name}/endpoint/protocols/openai"
raise ValueError(
"Calling `get_openai_client` method with an `agent_name` requires you to set `allow_preview=True`"
"\nwhen constructing the AIProjectClient. Note that preview features are under development and "
"\nsubject to change. They should not be used in production environments."
)
return config.endpoint.rstrip("/") + "/openai/v1"


def _resolve_openai_query_params(config: Any, agent_name: Optional[str], kwargs: dict) -> dict:
"""Build the ``default_query`` dict for the (Async)OpenAI client.

:param config: Generated client configuration carrying ``api_version``.
:type config: Any
:param agent_name: Optional hosted-agent name.
:type agent_name: str or None
:param kwargs: Caller keyword arguments; ``default_query`` is popped when present.
:type kwargs: dict
:return: Query parameters to forward to the (Async)OpenAI client.
:rtype: dict
"""
default_query = dict[str, str](kwargs.pop("default_query", None) or {})
if agent_name is not None and "api-version" not in default_query:
default_query["api-version"] = config.api_version
return default_query


def _resolve_openai_default_headers(agent_name: Optional[str], kwargs: dict) -> dict:
"""Build the ``default_headers`` dict for the (Async)OpenAI client.

:param agent_name: Optional hosted-agent name.
:type agent_name: str or None
:param kwargs: Caller keyword arguments; ``default_headers`` is popped when present.
:type kwargs: dict
:return: Headers to forward to the (Async)OpenAI client.
:rtype: dict
"""
default_headers = dict[str, str](kwargs.pop("default_headers", None) or {})
if agent_name is not None and not _has_header_case_insensitive(default_headers, _FOUNDRY_FEATURES_HEADER_NAME):
default_headers[_FOUNDRY_FEATURES_HEADER_NAME] = _BETA_OPERATION_FEATURE_HEADERS["agents"]
return default_headers


def _build_openai_user_agent(custom_user_agent: Optional[str], openai_default_user_agent: str) -> str:
"""Build the SDK-prefixed User-Agent string for the (Async)OpenAI client.

:param custom_user_agent: Caller-supplied user_agent kwarg captured at construction time.
:type custom_user_agent: str or None
:param openai_default_user_agent: The OpenAI client's own default user-agent.
:type openai_default_user_agent: str
:return: Combined User-Agent string.
:rtype: str
"""
return "-".join(ua for ua in [custom_user_agent, "AIProjectClient"] if ua) + " " + openai_default_user_agent


class AIProjectClient(AIProjectClientGenerated): # pylint: disable=too-many-instance-attributes
"""AIProjectClient.

Expand Down Expand Up @@ -101,6 +181,35 @@ def __init__(

self.telemetry = TelemetryOperations(self) # type: ignore

def _get_openai_api_key(self, kwargs: dict):
"""Resolve the API key for the OpenAI client.

:param kwargs: Caller keyword arguments; ``api_key`` is popped when present.
:type kwargs: dict
:return: The API key string or a bearer-token-provider callable.
:rtype: str or Callable
"""
if "api_key" in kwargs:
return kwargs.pop("api_key")
return get_bearer_token_provider(
self._config.credential, # pylint: disable=protected-access
"https://ai.azure.com/.default",
)

def _get_openai_http_client(self, kwargs: dict):
"""Resolve the HTTP transport client for the OpenAI client.

:param kwargs: Caller keyword arguments; ``http_client`` is popped when present.
:type kwargs: dict
:return: An httpx.Client instance configured with logging transport, or ``None``.
:rtype: httpx.Client or None
"""
if "http_client" in kwargs:
return kwargs.pop("http_client")
if self._console_logging_enabled:
return httpx.Client(transport=_OpenAILoggingTransport())
return None

@distributed_trace
def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any) -> OpenAI:
"""Get an authenticated OpenAI client from the `openai` package.
Expand Down Expand Up @@ -131,51 +240,17 @@ def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any)

kwargs = kwargs.copy() if kwargs else {}

# Allow caller to override base_url
if "base_url" in kwargs:
base_url = kwargs.pop("base_url")
elif agent_name is not None:
if self._config.allow_preview:
base_url = (
self._config.endpoint.rstrip("/") + f"/agents/{agent_name}/endpoint/protocols/openai"
) # pylint: disable=protected-access
else:
raise ValueError(
"Calling `get_openai_client` method with an `agent_name` requires you to set `allow_preview=True`"
"\nwhen constructing the AIProjectClient. Note that preview features are under development and "
"\nsubject to change. They should not be used in production environments."
)
else:
base_url = self._config.endpoint.rstrip("/") + "/openai/v1" # pylint: disable=protected-access

default_query = dict[str, str](kwargs.pop("default_query", None) or {})
if agent_name is not None and "api-version" not in default_query:
Comment thread
howieleung marked this conversation as resolved.
default_query["api-version"] = self._config.api_version # pylint: disable=protected-access
base_url = _resolve_openai_base_url(self._config, agent_name, kwargs)
default_query = _resolve_openai_query_params(self._config, agent_name, kwargs)

logger.debug( # pylint: disable=specify-parameter-names-in-call
"[get_openai_client] Creating OpenAI client using Entra ID authentication, base_url = `%s`", # pylint: disable=line-too-long
base_url,
)

# Allow caller to override api_key, otherwise use token provider
if "api_key" in kwargs:
api_key = kwargs.pop("api_key")
else:
api_key = get_bearer_token_provider(
self._config.credential, # pylint: disable=protected-access
"https://ai.azure.com/.default",
)

if "http_client" in kwargs:
http_client = kwargs.pop("http_client")
elif self._console_logging_enabled:
http_client = httpx.Client(transport=_OpenAILoggingTransport())
else:
http_client = None

default_headers = dict[str, str](kwargs.pop("default_headers", None) or {})
if agent_name is not None and not _has_header_case_insensitive(default_headers, _FOUNDRY_FEATURES_HEADER_NAME):
default_headers[_FOUNDRY_FEATURES_HEADER_NAME] = _BETA_OPERATION_FEATURE_HEADERS["agents"]
api_key = self._get_openai_api_key(kwargs)
http_client = self._get_openai_http_client(kwargs)
default_headers = _resolve_openai_default_headers(agent_name, kwargs)

openai_custom_user_agent = default_headers.get("User-Agent", None)

Expand All @@ -195,11 +270,7 @@ def _create_openai_client(**kwargs) -> OpenAI:
if openai_custom_user_agent:
final_user_agent = openai_custom_user_agent
else:
final_user_agent = (
"-".join(ua for ua in [self._custom_user_agent, "AIProjectClient"] if ua)
+ " "
+ openai_default_user_agent
)
final_user_agent = _build_openai_user_agent(self._custom_user_agent, openai_default_user_agent)

default_headers["User-Agent"] = final_user_agent

Expand Down
5 changes: 5 additions & 0 deletions sdk/ai/azure-ai-projects/azure/ai/projects/_patch.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ class AIProjectClient(AIProjectClientGenerated):
# To make mypy happy... otherwise imports of the below result in mypy "attr-defined" error
class _AuthSecretsFilter(logging.Filter): ...

def _resolve_openai_base_url(config: Any, agent_name: Optional[str], kwargs: dict) -> str: ...
def _resolve_openai_query_params(config: Any, agent_name: Optional[str], kwargs: dict) -> dict: ...
def _resolve_openai_default_headers(agent_name: Optional[str], kwargs: dict) -> dict: ...
def _build_openai_user_agent(custom_user_agent: Optional[str], openai_default_user_agent: str) -> str: ...

__all__: List[str] = ["AIProjectClient"]

def patch_sdk() -> None: ...
88 changes: 42 additions & 46 deletions sdk/ai/azure-ai-projects/azure/ai/projects/aio/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
from azure.core.tracing.decorator import distributed_trace
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity.aio import get_bearer_token_provider
from .._patch import _AuthSecretsFilter
from ..models._patch import _BETA_OPERATION_FEATURE_HEADERS, _FOUNDRY_FEATURES_HEADER_NAME, _has_header_case_insensitive
from .._patch import (
_AuthSecretsFilter,
_build_openai_user_agent,
_resolve_openai_base_url,
_resolve_openai_default_headers,
_resolve_openai_query_params,
)
from ._client import AIProjectClient as AIProjectClientGenerated
from .operations import TelemetryOperations

Expand Down Expand Up @@ -101,6 +106,35 @@ def __init__(

self.telemetry = TelemetryOperations(self) # type: ignore

def _get_openai_api_key(self, kwargs: dict):
"""Resolve the API key for the AsyncOpenAI client.

:param kwargs: Caller keyword arguments; ``api_key`` is popped when present.
:type kwargs: dict
:return: The API key string or a bearer-token-provider callable.
:rtype: str or Callable
"""
if "api_key" in kwargs:
return kwargs.pop("api_key")
return get_bearer_token_provider(
self._config.credential, # pylint: disable=protected-access
"https://ai.azure.com/.default",
)

def _get_openai_http_client(self, kwargs: dict):
"""Resolve the HTTP transport client for the AsyncOpenAI client.

:param kwargs: Caller keyword arguments; ``http_client`` is popped when present.
:type kwargs: dict
:return: An httpx.AsyncClient instance configured with logging transport, or ``None``.
:rtype: httpx.AsyncClient or None
"""
if "http_client" in kwargs:
return kwargs.pop("http_client")
if self._console_logging_enabled:
return httpx.AsyncClient(transport=_OpenAILoggingTransport())
return None

@distributed_trace
def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any) -> AsyncOpenAI:
"""Get an authenticated AsyncOpenAI client from the `openai` package.
Expand Down Expand Up @@ -131,51 +165,17 @@ def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any)

kwargs = kwargs.copy() if kwargs else {}

# Allow caller to override base_url
if "base_url" in kwargs:
base_url = kwargs.pop("base_url")
elif agent_name is not None:
if self._config.allow_preview:
base_url = (
self._config.endpoint.rstrip("/") + f"/agents/{agent_name}/endpoint/protocols/openai"
) # pylint: disable=protected-access
else:
raise ValueError(
"Calling `get_openai_client` method with an `agent_name` requires you to set `allow_preview=True`"
"\nwhen constructing the AIProjectClient. Note that preview features are under development and "
"\nsubject to change. They should not be used in production environments."
)
else:
base_url = self._config.endpoint.rstrip("/") + "/openai/v1" # pylint: disable=protected-access

default_query = dict[str, str](kwargs.pop("default_query", None) or {})
if agent_name is not None and "api-version" not in default_query:
default_query["api-version"] = self._config.api_version # pylint: disable=protected-access
base_url = _resolve_openai_base_url(self._config, agent_name, kwargs)
default_query = _resolve_openai_query_params(self._config, agent_name, kwargs)

logger.debug( # pylint: disable=specify-parameter-names-in-call
"[get_openai_client] Creating OpenAI client using Entra ID authentication, base_url = `%s`", # pylint: disable=line-too-long
base_url,
)

# Allow caller to override api_key, otherwise use token provider
if "api_key" in kwargs:
api_key = kwargs.pop("api_key")
else:
api_key = get_bearer_token_provider(
self._config.credential, # pylint: disable=protected-access
"https://ai.azure.com/.default",
)

if "http_client" in kwargs:
http_client = kwargs.pop("http_client")
elif self._console_logging_enabled:
http_client = httpx.AsyncClient(transport=_OpenAILoggingTransport())
else:
http_client = None

default_headers = dict[str, str](kwargs.pop("default_headers", None) or {})
if agent_name is not None and not _has_header_case_insensitive(default_headers, _FOUNDRY_FEATURES_HEADER_NAME):
default_headers[_FOUNDRY_FEATURES_HEADER_NAME] = _BETA_OPERATION_FEATURE_HEADERS["agents"]
api_key = self._get_openai_api_key(kwargs)
http_client = self._get_openai_http_client(kwargs)
default_headers = _resolve_openai_default_headers(agent_name, kwargs)

openai_custom_user_agent = default_headers.get("User-Agent", None)

Expand All @@ -195,11 +195,7 @@ def _create_openai_client(**kwargs) -> AsyncOpenAI:
if openai_custom_user_agent:
final_user_agent = openai_custom_user_agent
else:
final_user_agent = (
"-".join(ua for ua in [self._custom_user_agent, "AIProjectClient"] if ua)
+ " "
+ openai_default_user_agent
)
final_user_agent = _build_openai_user_agent(self._custom_user_agent, openai_default_user_agent)

default_headers["User-Agent"] = final_user_agent

Expand Down
69 changes: 69 additions & 0 deletions sdk/ai/azure-ai-projects/tests/responses/openai_test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# pylint: disable=line-too-long,useless-suppression
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
"""Shared helpers for unit-testing AIProjectClient.get_openai_client (sync and async).

These helpers build lightweight client stubs that bypass the real ``__init__`` so unit
tests can target individual branches of ``get_openai_client`` without making any
network calls.
"""

from typing import Optional
from unittest.mock import MagicMock

from azure.ai.projects import AIProjectClient
from azure.ai.projects.aio import AIProjectClient as AsyncAIProjectClient

ENDPOINT = "https://myaccount.services.ai.azure.com/api/projects/myproject"
API_VERSION = "2025-01-01"

# Patch targets used by tests to swap in mocked OpenAI/AsyncOpenAI constructors
# and bearer-token providers.
SYNC_OPENAI_PATCH = "azure.ai.projects._patch.OpenAI"
ASYNC_OPENAI_PATCH = "azure.ai.projects.aio._patch.AsyncOpenAI"
SYNC_TOKEN_PROVIDER_PATCH = "azure.ai.projects._patch.get_bearer_token_provider"
ASYNC_TOKEN_PROVIDER_PATCH = "azure.ai.projects.aio._patch.get_bearer_token_provider"


def make_sync_client(
allow_preview: bool = True,
console_logging: bool = False,
custom_user_agent: Optional[str] = None,
) -> AIProjectClient:
"""Return a minimal sync AIProjectClient stub suitable for unit-testing get_openai_client."""
client = AIProjectClient.__new__(AIProjectClient)
client._config = MagicMock()
client._config.endpoint = ENDPOINT
client._config.allow_preview = allow_preview
client._config.api_version = API_VERSION
client._config.credential = MagicMock()
client._console_logging_enabled = console_logging
client._custom_user_agent = custom_user_agent
return client


def make_async_client(
allow_preview: bool = True,
console_logging: bool = False,
custom_user_agent: Optional[str] = None,
) -> AsyncAIProjectClient:
"""Return a minimal async AIProjectClient stub suitable for unit-testing get_openai_client."""
client = AsyncAIProjectClient.__new__(AsyncAIProjectClient)
client._config = MagicMock()
client._config.endpoint = ENDPOINT
client._config.allow_preview = allow_preview
client._config.api_version = API_VERSION
client._config.credential = MagicMock()
client._console_logging_enabled = console_logging
client._custom_user_agent = custom_user_agent
return client


def mock_openai(user_agent: str = "openai/1.0"):
"""Return ``(mock_class, mock_instance)`` where ``mock_class`` acts as the OpenAI constructor."""
instance = MagicMock()
instance.user_agent = user_agent
mock_cls = MagicMock(return_value=instance)
return mock_cls, instance
Loading
Loading