Skip to content

Commit f38464e

Browse files
authored
Fix pylint complaining about too many if-else branch in get_openai_cl… (#46536)
* Fix pylint complaining about too many if-else branch in get_openai_client implementation * Resolved comments * Remove user-agent construction tests for sync and async clients
1 parent a04a10b commit f38464e

8 files changed

Lines changed: 600 additions & 97 deletions

File tree

sdk/ai/azure-ai-projects/azure/ai/projects/_patch.py

Lines changed: 115 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,86 @@
2424
logger = logging.getLogger(__name__)
2525

2626

27+
# ---------------------------------------------------------------------------
28+
# Shared helpers used by both the sync and async AIProjectClient.get_openai_client()
29+
# implementations. Defined at module level so the async client can import and reuse
30+
# them without duplicating the logic.
31+
# ---------------------------------------------------------------------------
32+
33+
34+
def _resolve_openai_base_url(config: Any, agent_name: Optional[str], kwargs: dict) -> str:
35+
"""Resolve the base URL for the (Async)OpenAI client.
36+
37+
:param config: Generated client configuration carrying ``endpoint`` and ``allow_preview``.
38+
:type config: Any
39+
:param agent_name: Optional hosted-agent name.
40+
:type agent_name: str or None
41+
:param kwargs: Caller keyword arguments; ``base_url`` is popped when present.
42+
:type kwargs: dict
43+
:return: The base URL to use for the (Async)OpenAI client.
44+
:rtype: str
45+
:raises ValueError: If ``agent_name`` is provided but ``allow_preview=True`` was not set.
46+
"""
47+
if "base_url" in kwargs:
48+
return kwargs.pop("base_url")
49+
if agent_name is not None:
50+
if config.allow_preview:
51+
return config.endpoint.rstrip("/") + f"/agents/{agent_name}/endpoint/protocols/openai"
52+
raise ValueError(
53+
"Calling `get_openai_client` method with an `agent_name` requires you to set `allow_preview=True`"
54+
"\nwhen constructing the AIProjectClient. Note that preview features are under development and "
55+
"\nsubject to change. They should not be used in production environments."
56+
)
57+
return config.endpoint.rstrip("/") + "/openai/v1"
58+
59+
60+
def _resolve_openai_query_params(config: Any, agent_name: Optional[str], kwargs: dict) -> dict:
61+
"""Build the ``default_query`` dict for the (Async)OpenAI client.
62+
63+
:param config: Generated client configuration carrying ``api_version``.
64+
:type config: Any
65+
:param agent_name: Optional hosted-agent name.
66+
:type agent_name: str or None
67+
:param kwargs: Caller keyword arguments; ``default_query`` is popped when present.
68+
:type kwargs: dict
69+
:return: Query parameters to forward to the (Async)OpenAI client.
70+
:rtype: dict
71+
"""
72+
default_query = dict[str, str](kwargs.pop("default_query", None) or {})
73+
if agent_name is not None and "api-version" not in default_query:
74+
default_query["api-version"] = config.api_version
75+
return default_query
76+
77+
78+
def _resolve_openai_default_headers(agent_name: Optional[str], kwargs: dict) -> dict:
79+
"""Build the ``default_headers`` dict for the (Async)OpenAI client.
80+
81+
:param agent_name: Optional hosted-agent name.
82+
:type agent_name: str or None
83+
:param kwargs: Caller keyword arguments; ``default_headers`` is popped when present.
84+
:type kwargs: dict
85+
:return: Headers to forward to the (Async)OpenAI client.
86+
:rtype: dict
87+
"""
88+
default_headers = dict[str, str](kwargs.pop("default_headers", None) or {})
89+
if agent_name is not None and not _has_header_case_insensitive(default_headers, _FOUNDRY_FEATURES_HEADER_NAME):
90+
default_headers[_FOUNDRY_FEATURES_HEADER_NAME] = _BETA_OPERATION_FEATURE_HEADERS["agents"]
91+
return default_headers
92+
93+
94+
def _build_openai_user_agent(custom_user_agent: Optional[str], openai_default_user_agent: str) -> str:
95+
"""Build the SDK-prefixed User-Agent string for the (Async)OpenAI client.
96+
97+
:param custom_user_agent: Caller-supplied user_agent kwarg captured at construction time.
98+
:type custom_user_agent: str or None
99+
:param openai_default_user_agent: The OpenAI client's own default user-agent.
100+
:type openai_default_user_agent: str
101+
:return: Combined User-Agent string.
102+
:rtype: str
103+
"""
104+
return "-".join(ua for ua in [custom_user_agent, "AIProjectClient"] if ua) + " " + openai_default_user_agent
105+
106+
27107
class AIProjectClient(AIProjectClientGenerated): # pylint: disable=too-many-instance-attributes
28108
"""AIProjectClient.
29109
@@ -101,6 +181,35 @@ def __init__(
101181

102182
self.telemetry = TelemetryOperations(self) # type: ignore
103183

184+
def _get_openai_api_key(self, kwargs: dict):
185+
"""Resolve the API key for the OpenAI client.
186+
187+
:param kwargs: Caller keyword arguments; ``api_key`` is popped when present.
188+
:type kwargs: dict
189+
:return: The API key string or a bearer-token-provider callable.
190+
:rtype: str or Callable
191+
"""
192+
if "api_key" in kwargs:
193+
return kwargs.pop("api_key")
194+
return get_bearer_token_provider(
195+
self._config.credential, # pylint: disable=protected-access
196+
"https://ai.azure.com/.default",
197+
)
198+
199+
def _get_openai_http_client(self, kwargs: dict):
200+
"""Resolve the HTTP transport client for the OpenAI client.
201+
202+
:param kwargs: Caller keyword arguments; ``http_client`` is popped when present.
203+
:type kwargs: dict
204+
:return: An httpx.Client instance configured with logging transport, or ``None``.
205+
:rtype: httpx.Client or None
206+
"""
207+
if "http_client" in kwargs:
208+
return kwargs.pop("http_client")
209+
if self._console_logging_enabled:
210+
return httpx.Client(transport=_OpenAILoggingTransport())
211+
return None
212+
104213
@distributed_trace
105214
def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any) -> OpenAI:
106215
"""Get an authenticated OpenAI client from the `openai` package.
@@ -131,51 +240,17 @@ def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any)
131240

132241
kwargs = kwargs.copy() if kwargs else {}
133242

134-
# Allow caller to override base_url
135-
if "base_url" in kwargs:
136-
base_url = kwargs.pop("base_url")
137-
elif agent_name is not None:
138-
if self._config.allow_preview:
139-
base_url = (
140-
self._config.endpoint.rstrip("/") + f"/agents/{agent_name}/endpoint/protocols/openai"
141-
) # pylint: disable=protected-access
142-
else:
143-
raise ValueError(
144-
"Calling `get_openai_client` method with an `agent_name` requires you to set `allow_preview=True`"
145-
"\nwhen constructing the AIProjectClient. Note that preview features are under development and "
146-
"\nsubject to change. They should not be used in production environments."
147-
)
148-
else:
149-
base_url = self._config.endpoint.rstrip("/") + "/openai/v1" # pylint: disable=protected-access
150-
151-
default_query = dict[str, str](kwargs.pop("default_query", None) or {})
152-
if agent_name is not None and "api-version" not in default_query:
153-
default_query["api-version"] = self._config.api_version # pylint: disable=protected-access
243+
base_url = _resolve_openai_base_url(self._config, agent_name, kwargs)
244+
default_query = _resolve_openai_query_params(self._config, agent_name, kwargs)
154245

155246
logger.debug( # pylint: disable=specify-parameter-names-in-call
156247
"[get_openai_client] Creating OpenAI client using Entra ID authentication, base_url = `%s`", # pylint: disable=line-too-long
157248
base_url,
158249
)
159250

160-
# Allow caller to override api_key, otherwise use token provider
161-
if "api_key" in kwargs:
162-
api_key = kwargs.pop("api_key")
163-
else:
164-
api_key = get_bearer_token_provider(
165-
self._config.credential, # pylint: disable=protected-access
166-
"https://ai.azure.com/.default",
167-
)
168-
169-
if "http_client" in kwargs:
170-
http_client = kwargs.pop("http_client")
171-
elif self._console_logging_enabled:
172-
http_client = httpx.Client(transport=_OpenAILoggingTransport())
173-
else:
174-
http_client = None
175-
176-
default_headers = dict[str, str](kwargs.pop("default_headers", None) or {})
177-
if agent_name is not None and not _has_header_case_insensitive(default_headers, _FOUNDRY_FEATURES_HEADER_NAME):
178-
default_headers[_FOUNDRY_FEATURES_HEADER_NAME] = _BETA_OPERATION_FEATURE_HEADERS["agents"]
251+
api_key = self._get_openai_api_key(kwargs)
252+
http_client = self._get_openai_http_client(kwargs)
253+
default_headers = _resolve_openai_default_headers(agent_name, kwargs)
179254

180255
openai_custom_user_agent = default_headers.get("User-Agent", None)
181256

@@ -195,11 +270,7 @@ def _create_openai_client(**kwargs) -> OpenAI:
195270
if openai_custom_user_agent:
196271
final_user_agent = openai_custom_user_agent
197272
else:
198-
final_user_agent = (
199-
"-".join(ua for ua in [self._custom_user_agent, "AIProjectClient"] if ua)
200-
+ " "
201-
+ openai_default_user_agent
202-
)
273+
final_user_agent = _build_openai_user_agent(self._custom_user_agent, openai_default_user_agent)
203274

204275
default_headers["User-Agent"] = final_user_agent
205276

sdk/ai/azure-ai-projects/azure/ai/projects/_patch.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ class AIProjectClient(AIProjectClientGenerated):
108108
# To make mypy happy... otherwise imports of the below result in mypy "attr-defined" error
109109
class _AuthSecretsFilter(logging.Filter): ...
110110

111+
def _resolve_openai_base_url(config: Any, agent_name: Optional[str], kwargs: dict) -> str: ...
112+
def _resolve_openai_query_params(config: Any, agent_name: Optional[str], kwargs: dict) -> dict: ...
113+
def _resolve_openai_default_headers(agent_name: Optional[str], kwargs: dict) -> dict: ...
114+
def _build_openai_user_agent(custom_user_agent: Optional[str], openai_default_user_agent: str) -> str: ...
115+
111116
__all__: List[str] = ["AIProjectClient"]
112117

113118
def patch_sdk() -> None: ...

sdk/ai/azure-ai-projects/azure/ai/projects/aio/_patch.py

Lines changed: 42 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
from azure.core.tracing.decorator import distributed_trace
1717
from azure.core.credentials_async import AsyncTokenCredential
1818
from azure.identity.aio import get_bearer_token_provider
19-
from .._patch import _AuthSecretsFilter
20-
from ..models._patch import _BETA_OPERATION_FEATURE_HEADERS, _FOUNDRY_FEATURES_HEADER_NAME, _has_header_case_insensitive
19+
from .._patch import (
20+
_AuthSecretsFilter,
21+
_build_openai_user_agent,
22+
_resolve_openai_base_url,
23+
_resolve_openai_default_headers,
24+
_resolve_openai_query_params,
25+
)
2126
from ._client import AIProjectClient as AIProjectClientGenerated
2227
from .operations import TelemetryOperations
2328

@@ -101,6 +106,35 @@ def __init__(
101106

102107
self.telemetry = TelemetryOperations(self) # type: ignore
103108

109+
def _get_openai_api_key(self, kwargs: dict):
110+
"""Resolve the API key for the AsyncOpenAI client.
111+
112+
:param kwargs: Caller keyword arguments; ``api_key`` is popped when present.
113+
:type kwargs: dict
114+
:return: The API key string or a bearer-token-provider callable.
115+
:rtype: str or Callable
116+
"""
117+
if "api_key" in kwargs:
118+
return kwargs.pop("api_key")
119+
return get_bearer_token_provider(
120+
self._config.credential, # pylint: disable=protected-access
121+
"https://ai.azure.com/.default",
122+
)
123+
124+
def _get_openai_http_client(self, kwargs: dict):
125+
"""Resolve the HTTP transport client for the AsyncOpenAI client.
126+
127+
:param kwargs: Caller keyword arguments; ``http_client`` is popped when present.
128+
:type kwargs: dict
129+
:return: An httpx.AsyncClient instance configured with logging transport, or ``None``.
130+
:rtype: httpx.AsyncClient or None
131+
"""
132+
if "http_client" in kwargs:
133+
return kwargs.pop("http_client")
134+
if self._console_logging_enabled:
135+
return httpx.AsyncClient(transport=_OpenAILoggingTransport())
136+
return None
137+
104138
@distributed_trace
105139
def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any) -> AsyncOpenAI:
106140
"""Get an authenticated AsyncOpenAI client from the `openai` package.
@@ -131,51 +165,17 @@ def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any)
131165

132166
kwargs = kwargs.copy() if kwargs else {}
133167

134-
# Allow caller to override base_url
135-
if "base_url" in kwargs:
136-
base_url = kwargs.pop("base_url")
137-
elif agent_name is not None:
138-
if self._config.allow_preview:
139-
base_url = (
140-
self._config.endpoint.rstrip("/") + f"/agents/{agent_name}/endpoint/protocols/openai"
141-
) # pylint: disable=protected-access
142-
else:
143-
raise ValueError(
144-
"Calling `get_openai_client` method with an `agent_name` requires you to set `allow_preview=True`"
145-
"\nwhen constructing the AIProjectClient. Note that preview features are under development and "
146-
"\nsubject to change. They should not be used in production environments."
147-
)
148-
else:
149-
base_url = self._config.endpoint.rstrip("/") + "/openai/v1" # pylint: disable=protected-access
150-
151-
default_query = dict[str, str](kwargs.pop("default_query", None) or {})
152-
if agent_name is not None and "api-version" not in default_query:
153-
default_query["api-version"] = self._config.api_version # pylint: disable=protected-access
168+
base_url = _resolve_openai_base_url(self._config, agent_name, kwargs)
169+
default_query = _resolve_openai_query_params(self._config, agent_name, kwargs)
154170

155171
logger.debug( # pylint: disable=specify-parameter-names-in-call
156172
"[get_openai_client] Creating OpenAI client using Entra ID authentication, base_url = `%s`", # pylint: disable=line-too-long
157173
base_url,
158174
)
159175

160-
# Allow caller to override api_key, otherwise use token provider
161-
if "api_key" in kwargs:
162-
api_key = kwargs.pop("api_key")
163-
else:
164-
api_key = get_bearer_token_provider(
165-
self._config.credential, # pylint: disable=protected-access
166-
"https://ai.azure.com/.default",
167-
)
168-
169-
if "http_client" in kwargs:
170-
http_client = kwargs.pop("http_client")
171-
elif self._console_logging_enabled:
172-
http_client = httpx.AsyncClient(transport=_OpenAILoggingTransport())
173-
else:
174-
http_client = None
175-
176-
default_headers = dict[str, str](kwargs.pop("default_headers", None) or {})
177-
if agent_name is not None and not _has_header_case_insensitive(default_headers, _FOUNDRY_FEATURES_HEADER_NAME):
178-
default_headers[_FOUNDRY_FEATURES_HEADER_NAME] = _BETA_OPERATION_FEATURE_HEADERS["agents"]
176+
api_key = self._get_openai_api_key(kwargs)
177+
http_client = self._get_openai_http_client(kwargs)
178+
default_headers = _resolve_openai_default_headers(agent_name, kwargs)
179179

180180
openai_custom_user_agent = default_headers.get("User-Agent", None)
181181

@@ -195,11 +195,7 @@ def _create_openai_client(**kwargs) -> AsyncOpenAI:
195195
if openai_custom_user_agent:
196196
final_user_agent = openai_custom_user_agent
197197
else:
198-
final_user_agent = (
199-
"-".join(ua for ua in [self._custom_user_agent, "AIProjectClient"] if ua)
200-
+ " "
201-
+ openai_default_user_agent
202-
)
198+
final_user_agent = _build_openai_user_agent(self._custom_user_agent, openai_default_user_agent)
203199

204200
default_headers["User-Agent"] = final_user_agent
205201

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# pylint: disable=line-too-long,useless-suppression
2+
# ------------------------------------
3+
# Copyright (c) Microsoft Corporation.
4+
# Licensed under the MIT License.
5+
# ------------------------------------
6+
"""Shared helpers for unit-testing AIProjectClient.get_openai_client (sync and async).
7+
8+
These helpers build lightweight client stubs that bypass the real ``__init__`` so unit
9+
tests can target individual branches of ``get_openai_client`` without making any
10+
network calls.
11+
"""
12+
13+
from typing import Optional
14+
from unittest.mock import MagicMock
15+
16+
from azure.ai.projects import AIProjectClient
17+
from azure.ai.projects.aio import AIProjectClient as AsyncAIProjectClient
18+
19+
ENDPOINT = "https://myaccount.services.ai.azure.com/api/projects/myproject"
20+
API_VERSION = "2025-01-01"
21+
22+
# Patch targets used by tests to swap in mocked OpenAI/AsyncOpenAI constructors
23+
# and bearer-token providers.
24+
SYNC_OPENAI_PATCH = "azure.ai.projects._patch.OpenAI"
25+
ASYNC_OPENAI_PATCH = "azure.ai.projects.aio._patch.AsyncOpenAI"
26+
SYNC_TOKEN_PROVIDER_PATCH = "azure.ai.projects._patch.get_bearer_token_provider"
27+
ASYNC_TOKEN_PROVIDER_PATCH = "azure.ai.projects.aio._patch.get_bearer_token_provider"
28+
29+
30+
def make_sync_client(
31+
allow_preview: bool = True,
32+
console_logging: bool = False,
33+
custom_user_agent: Optional[str] = None,
34+
) -> AIProjectClient:
35+
"""Return a minimal sync AIProjectClient stub suitable for unit-testing get_openai_client."""
36+
client = AIProjectClient.__new__(AIProjectClient)
37+
client._config = MagicMock()
38+
client._config.endpoint = ENDPOINT
39+
client._config.allow_preview = allow_preview
40+
client._config.api_version = API_VERSION
41+
client._config.credential = MagicMock()
42+
client._console_logging_enabled = console_logging
43+
client._custom_user_agent = custom_user_agent
44+
return client
45+
46+
47+
def make_async_client(
48+
allow_preview: bool = True,
49+
console_logging: bool = False,
50+
custom_user_agent: Optional[str] = None,
51+
) -> AsyncAIProjectClient:
52+
"""Return a minimal async AIProjectClient stub suitable for unit-testing get_openai_client."""
53+
client = AsyncAIProjectClient.__new__(AsyncAIProjectClient)
54+
client._config = MagicMock()
55+
client._config.endpoint = ENDPOINT
56+
client._config.allow_preview = allow_preview
57+
client._config.api_version = API_VERSION
58+
client._config.credential = MagicMock()
59+
client._console_logging_enabled = console_logging
60+
client._custom_user_agent = custom_user_agent
61+
return client
62+
63+
64+
def mock_openai(user_agent: str = "openai/1.0"):
65+
"""Return ``(mock_class, mock_instance)`` where ``mock_class`` acts as the OpenAI constructor."""
66+
instance = MagicMock()
67+
instance.user_agent = user_agent
68+
mock_cls = MagicMock(return_value=instance)
69+
return mock_cls, instance

0 commit comments

Comments
 (0)