2424logger = logging .getLogger (__name__ )
2525
2626
27+ # ---------------------------------------------------------------------------
28+ # Shared helpers used by both the sync and async AIProjectClient.get_openai_client()
29+ # implementations. Defined at module level so the async client can import and reuse
30+ # them without duplicating the logic.
31+ # ---------------------------------------------------------------------------
32+
33+
34+ def _resolve_openai_base_url (config : Any , agent_name : Optional [str ], kwargs : dict ) -> str :
35+ """Resolve the base URL for the (Async)OpenAI client.
36+
37+ :param config: Generated client configuration carrying ``endpoint`` and ``allow_preview``.
38+ :type config: Any
39+ :param agent_name: Optional hosted-agent name.
40+ :type agent_name: str or None
41+ :param kwargs: Caller keyword arguments; ``base_url`` is popped when present.
42+ :type kwargs: dict
43+ :return: The base URL to use for the (Async)OpenAI client.
44+ :rtype: str
45+ :raises ValueError: If ``agent_name`` is provided but ``allow_preview=True`` was not set.
46+ """
47+ if "base_url" in kwargs :
48+ return kwargs .pop ("base_url" )
49+ if agent_name is not None :
50+ if config .allow_preview :
51+ return config .endpoint .rstrip ("/" ) + f"/agents/{ agent_name } /endpoint/protocols/openai"
52+ raise ValueError (
53+ "Calling `get_openai_client` method with an `agent_name` requires you to set `allow_preview=True`"
54+ "\n when constructing the AIProjectClient. Note that preview features are under development and "
55+ "\n subject to change. They should not be used in production environments."
56+ )
57+ return config .endpoint .rstrip ("/" ) + "/openai/v1"
58+
59+
60+ def _resolve_openai_query_params (config : Any , agent_name : Optional [str ], kwargs : dict ) -> dict :
61+ """Build the ``default_query`` dict for the (Async)OpenAI client.
62+
63+ :param config: Generated client configuration carrying ``api_version``.
64+ :type config: Any
65+ :param agent_name: Optional hosted-agent name.
66+ :type agent_name: str or None
67+ :param kwargs: Caller keyword arguments; ``default_query`` is popped when present.
68+ :type kwargs: dict
69+ :return: Query parameters to forward to the (Async)OpenAI client.
70+ :rtype: dict
71+ """
72+ default_query = dict [str , str ](kwargs .pop ("default_query" , None ) or {})
73+ if agent_name is not None and "api-version" not in default_query :
74+ default_query ["api-version" ] = config .api_version
75+ return default_query
76+
77+
78+ def _resolve_openai_default_headers (agent_name : Optional [str ], kwargs : dict ) -> dict :
79+ """Build the ``default_headers`` dict for the (Async)OpenAI client.
80+
81+ :param agent_name: Optional hosted-agent name.
82+ :type agent_name: str or None
83+ :param kwargs: Caller keyword arguments; ``default_headers`` is popped when present.
84+ :type kwargs: dict
85+ :return: Headers to forward to the (Async)OpenAI client.
86+ :rtype: dict
87+ """
88+ default_headers = dict [str , str ](kwargs .pop ("default_headers" , None ) or {})
89+ if agent_name is not None and not _has_header_case_insensitive (default_headers , _FOUNDRY_FEATURES_HEADER_NAME ):
90+ default_headers [_FOUNDRY_FEATURES_HEADER_NAME ] = _BETA_OPERATION_FEATURE_HEADERS ["agents" ]
91+ return default_headers
92+
93+
94+ def _build_openai_user_agent (custom_user_agent : Optional [str ], openai_default_user_agent : str ) -> str :
95+ """Build the SDK-prefixed User-Agent string for the (Async)OpenAI client.
96+
97+ :param custom_user_agent: Caller-supplied user_agent kwarg captured at construction time.
98+ :type custom_user_agent: str or None
99+ :param openai_default_user_agent: The OpenAI client's own default user-agent.
100+ :type openai_default_user_agent: str
101+ :return: Combined User-Agent string.
102+ :rtype: str
103+ """
104+ return "-" .join (ua for ua in [custom_user_agent , "AIProjectClient" ] if ua ) + " " + openai_default_user_agent
105+
106+
27107class AIProjectClient (AIProjectClientGenerated ): # pylint: disable=too-many-instance-attributes
28108 """AIProjectClient.
29109
@@ -101,6 +181,35 @@ def __init__(
101181
102182 self .telemetry = TelemetryOperations (self ) # type: ignore
103183
184+ def _get_openai_api_key (self , kwargs : dict ):
185+ """Resolve the API key for the OpenAI client.
186+
187+ :param kwargs: Caller keyword arguments; ``api_key`` is popped when present.
188+ :type kwargs: dict
189+ :return: The API key string or a bearer-token-provider callable.
190+ :rtype: str or Callable
191+ """
192+ if "api_key" in kwargs :
193+ return kwargs .pop ("api_key" )
194+ return get_bearer_token_provider (
195+ self ._config .credential , # pylint: disable=protected-access
196+ "https://ai.azure.com/.default" ,
197+ )
198+
199+ def _get_openai_http_client (self , kwargs : dict ):
200+ """Resolve the HTTP transport client for the OpenAI client.
201+
202+ :param kwargs: Caller keyword arguments; ``http_client`` is popped when present.
203+ :type kwargs: dict
204+ :return: An httpx.Client instance configured with logging transport, or ``None``.
205+ :rtype: httpx.Client or None
206+ """
207+ if "http_client" in kwargs :
208+ return kwargs .pop ("http_client" )
209+ if self ._console_logging_enabled :
210+ return httpx .Client (transport = _OpenAILoggingTransport ())
211+ return None
212+
104213 @distributed_trace
105214 def get_openai_client (self , * , agent_name : Optional [str ] = None , ** kwargs : Any ) -> OpenAI :
106215 """Get an authenticated OpenAI client from the `openai` package.
@@ -131,51 +240,17 @@ def get_openai_client(self, *, agent_name: Optional[str] = None, **kwargs: Any)
131240
132241 kwargs = kwargs .copy () if kwargs else {}
133242
134- # Allow caller to override base_url
135- if "base_url" in kwargs :
136- base_url = kwargs .pop ("base_url" )
137- elif agent_name is not None :
138- if self ._config .allow_preview :
139- base_url = (
140- self ._config .endpoint .rstrip ("/" ) + f"/agents/{ agent_name } /endpoint/protocols/openai"
141- ) # pylint: disable=protected-access
142- else :
143- raise ValueError (
144- "Calling `get_openai_client` method with an `agent_name` requires you to set `allow_preview=True`"
145- "\n when constructing the AIProjectClient. Note that preview features are under development and "
146- "\n subject to change. They should not be used in production environments."
147- )
148- else :
149- base_url = self ._config .endpoint .rstrip ("/" ) + "/openai/v1" # pylint: disable=protected-access
150-
151- default_query = dict [str , str ](kwargs .pop ("default_query" , None ) or {})
152- if agent_name is not None and "api-version" not in default_query :
153- default_query ["api-version" ] = self ._config .api_version # pylint: disable=protected-access
243+ base_url = _resolve_openai_base_url (self ._config , agent_name , kwargs )
244+ default_query = _resolve_openai_query_params (self ._config , agent_name , kwargs )
154245
155246 logger .debug ( # pylint: disable=specify-parameter-names-in-call
156247 "[get_openai_client] Creating OpenAI client using Entra ID authentication, base_url = `%s`" , # pylint: disable=line-too-long
157248 base_url ,
158249 )
159250
160- # Allow caller to override api_key, otherwise use token provider
161- if "api_key" in kwargs :
162- api_key = kwargs .pop ("api_key" )
163- else :
164- api_key = get_bearer_token_provider (
165- self ._config .credential , # pylint: disable=protected-access
166- "https://ai.azure.com/.default" ,
167- )
168-
169- if "http_client" in kwargs :
170- http_client = kwargs .pop ("http_client" )
171- elif self ._console_logging_enabled :
172- http_client = httpx .Client (transport = _OpenAILoggingTransport ())
173- else :
174- http_client = None
175-
176- default_headers = dict [str , str ](kwargs .pop ("default_headers" , None ) or {})
177- if agent_name is not None and not _has_header_case_insensitive (default_headers , _FOUNDRY_FEATURES_HEADER_NAME ):
178- default_headers [_FOUNDRY_FEATURES_HEADER_NAME ] = _BETA_OPERATION_FEATURE_HEADERS ["agents" ]
251+ api_key = self ._get_openai_api_key (kwargs )
252+ http_client = self ._get_openai_http_client (kwargs )
253+ default_headers = _resolve_openai_default_headers (agent_name , kwargs )
179254
180255 openai_custom_user_agent = default_headers .get ("User-Agent" , None )
181256
@@ -195,11 +270,7 @@ def _create_openai_client(**kwargs) -> OpenAI:
195270 if openai_custom_user_agent :
196271 final_user_agent = openai_custom_user_agent
197272 else :
198- final_user_agent = (
199- "-" .join (ua for ua in [self ._custom_user_agent , "AIProjectClient" ] if ua )
200- + " "
201- + openai_default_user_agent
202- )
273+ final_user_agent = _build_openai_user_agent (self ._custom_user_agent , openai_default_user_agent )
203274
204275 default_headers ["User-Agent" ] = final_user_agent
205276
0 commit comments