2525
2626import logging
2727from abc import ABC
28- from collections .abc import AsyncIterator , Iterator , Mapping
28+ from collections .abc import AsyncIterator , Iterator , Mapping , Sequence
2929from functools import cached_property
3030from typing import Any , Literal
3131
3232from httpx import URL , Response
3333from langchain_core .embeddings import Embeddings
3434from langchain_core .language_models .chat_models import BaseChatModel
35+ from langchain_core .messages import BaseMessage
36+ from langchain_core .outputs import ChatGeneration , ChatGenerationChunk , ChatResult
3537from pydantic import AliasChoices , BaseModel , ConfigDict , Field
3638
3739from uipath_langchain_client .settings import (
3840 UiPathAPIConfig ,
3941 UiPathBaseSettings ,
4042 get_default_client_settings ,
4143)
42- from uipath_llm_client .httpx_client import UiPathHttpxAsyncClient , UiPathHttpxClient
44+ from uipath_llm_client .httpx_client import (
45+ UiPathHttpxAsyncClient ,
46+ UiPathHttpxClient ,
47+ )
48+ from uipath_llm_client .utils .headers import (
49+ get_captured_response_headers ,
50+ set_captured_response_headers ,
51+ )
4352from uipath_llm_client .utils .retry import RetryConfig
4453
4554
@@ -99,6 +108,13 @@ class UiPathBaseLLMClient(BaseModel, ABC):
99108 },
100109 description = "Default request headers to include in requests" ,
101110 )
111+ captured_headers : tuple [str , ...] = Field (
112+ default = ("x-uipath-" ,),
113+ description = "Case-insensitive response header prefixes to capture from LLM Gateway responses. "
114+ "Captured headers appear in response_metadata under the 'uipath_llmgateway_headers' key. "
115+ "Set to an empty tuple to disable." ,
116+ )
117+
102118 request_timeout : float | None = Field (
103119 alias = "timeout" ,
104120 validation_alias = AliasChoices ("timeout" , "request_timeout" , "default_request_timeout" ),
@@ -113,6 +129,7 @@ class UiPathBaseLLMClient(BaseModel, ABC):
113129 default = None ,
114130 description = "Retry configuration for failed requests" ,
115131 )
132+
116133 logger : logging .Logger | None = Field (
117134 default = None ,
118135 description = "Logger for request/response logging" ,
@@ -135,6 +152,7 @@ def uipath_sync_client(self) -> UiPathHttpxClient:
135152 model_name = self .model_name , api_config = self .api_config
136153 ),
137154 },
155+ captured_headers = self .captured_headers ,
138156 timeout = self .request_timeout ,
139157 max_retries = self .max_retries ,
140158 retry_config = self .retry_config ,
@@ -158,6 +176,7 @@ def uipath_async_client(self) -> UiPathHttpxAsyncClient:
158176 model_name = self .model_name , api_config = self .api_config
159177 ),
160178 },
179+ captured_headers = self .captured_headers ,
161180 timeout = self .request_timeout ,
162181 max_retries = self .max_retries ,
163182 retry_config = self .retry_config ,
@@ -283,7 +302,87 @@ async def uipath_astream(
283302
284303
285304class UiPathBaseChatModel (UiPathBaseLLMClient , BaseChatModel ):
286- pass
305+ """Base chat model that captures LLM Gateway response headers into response_metadata.
306+
307+ Wraps _generate/_agenerate/_stream/_astream to automatically read captured headers
308+ from the ContextVar (populated by the httpx client's send()) and inject them into
309+ the AIMessage's response_metadata under the 'uipath_llmgateway_headers' key.
310+
311+ Passthrough clients that delegate to vendor SDKs should inherit from this class
312+ so that headers are captured transparently.
313+ """
314+
315+ def _generate (
316+ self ,
317+ messages : list [BaseMessage ],
318+ * args : Any ,
319+ ** kwargs : Any ,
320+ ) -> ChatResult :
321+ set_captured_response_headers ({})
322+ try :
323+ result = super ()._generate (messages , * args , ** kwargs )
324+ self ._inject_gateway_headers (result .generations )
325+ return result
326+ finally :
327+ set_captured_response_headers ({})
328+
329+ async def _agenerate (
330+ self ,
331+ messages : list [BaseMessage ],
332+ * args : Any ,
333+ ** kwargs : Any ,
334+ ) -> ChatResult :
335+ set_captured_response_headers ({})
336+ try :
337+ result = await super ()._agenerate (messages , * args , ** kwargs )
338+ self ._inject_gateway_headers (result .generations )
339+ return result
340+ finally :
341+ set_captured_response_headers ({})
342+
343+ def _stream (
344+ self ,
345+ messages : list [BaseMessage ],
346+ * args : Any ,
347+ ** kwargs : Any ,
348+ ) -> Iterator [ChatGenerationChunk ]:
349+ set_captured_response_headers ({})
350+ try :
351+ first = True
352+ for chunk in super ()._stream (messages , * args , ** kwargs ):
353+ if first :
354+ self ._inject_gateway_headers ([chunk ])
355+ first = False
356+ yield chunk
357+ finally :
358+ set_captured_response_headers ({})
359+
360+ async def _astream (
361+ self ,
362+ messages : list [BaseMessage ],
363+ * args : Any ,
364+ ** kwargs : Any ,
365+ ) -> AsyncIterator [ChatGenerationChunk ]:
366+ set_captured_response_headers ({})
367+ try :
368+ first = True
369+ async for chunk in super ()._astream (messages , * args , ** kwargs ):
370+ if first :
371+ self ._inject_gateway_headers ([chunk ])
372+ first = False
373+ yield chunk
374+ finally :
375+ set_captured_response_headers ({})
376+
377+ def _inject_gateway_headers (self , generations : Sequence [ChatGeneration ]) -> None :
378+ """Inject captured gateway headers into each generation's response_metadata."""
379+ if not self .captured_headers :
380+ return
381+ headers = get_captured_response_headers ()
382+ if not headers :
383+ return
384+ for generation in generations :
385+ generation .message .response_metadata ["uipath_llmgateway_headers" ] = headers
287386
288387
289388class UiPathBaseEmbeddings (UiPathBaseLLMClient , Embeddings ):
0 commit comments