11import logging
22import os
3+ from collections .abc import AsyncIterator , Iterator
34from typing import Any , Optional
45
56import httpx
89 CallbackManagerForLLMRun ,
910)
1011from langchain_core .messages import BaseMessage
11- from langchain_core .outputs import ChatResult
12+ from langchain_core .outputs import ChatGenerationChunk , ChatResult
1213from tenacity import AsyncRetrying , Retrying
1314from uipath ._utils import resource_override
1415from uipath ._utils ._ssl_context import get_httpx_client_kwargs
1516from uipath .utils import EndpointManager
1617
18+ from .header_capture import HeaderCapture
1719from .retryers .vertex import AsyncVertexRetryer , VertexRetryer
1820from .supported_models import GeminiModels
1921from .types import APIFlavor , LLMProvider
@@ -70,9 +72,15 @@ def _rewrite_vertex_url(original_url: str, gateway_url: str) -> httpx.URL | None
7072class _UrlRewriteTransport (httpx .HTTPTransport ):
7173 """Transport that rewrites URLs to redirect to UiPath gateway."""
7274
def __init__(
    self,
    gateway_url: str,
    verify: bool = True,
    header_capture: HeaderCapture | None = None,
) -> None:
    """Set up the synchronous URL-rewriting transport.

    Args:
        gateway_url: Gateway URL that outgoing requests are redirected to.
        verify: Forwarded unchanged to ``httpx.HTTPTransport`` as its
            ``verify`` argument.
        header_capture: Optional capture object. When provided,
            ``handle_request`` stores each response's headers in it.
    """
    # Let the base transport initialise first, then record our own state.
    super().__init__(verify=verify)
    self.header_capture = header_capture
    self.gateway_url = gateway_url
7684
7785 def handle_request (self , request : httpx .Request ) -> httpx .Response :
7886 original_url = str (request .url )
@@ -86,15 +94,26 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
8694 # Update host header to match the new URL
8795 request .headers ["host" ] = new_url .host
8896 request .url = new_url
89- return super ().handle_request (request )
97+
98+ response = super ().handle_request (request )
99+ if self .header_capture :
100+ self .header_capture .set (dict (response .headers ))
101+
102+ return response
90103
91104
92105class _AsyncUrlRewriteTransport (httpx .AsyncHTTPTransport ):
93106 """Async transport that rewrites URLs to redirect to UiPath gateway."""
94107
def __init__(
    self,
    gateway_url: str,
    verify: bool = True,
    header_capture: HeaderCapture | None = None,
) -> None:
    """Set up the asynchronous URL-rewriting transport.

    Args:
        gateway_url: Gateway URL that outgoing requests are redirected to.
        verify: Forwarded unchanged to ``httpx.AsyncHTTPTransport`` as its
            ``verify`` argument.
        header_capture: Optional capture object. When provided,
            ``handle_async_request`` stores each response's headers in it.
    """
    # Let the base transport initialise first, then record our own state.
    super().__init__(verify=verify)
    self.header_capture = header_capture
    self.gateway_url = gateway_url
98117
99118 async def handle_async_request (self , request : httpx .Request ) -> httpx .Response :
100119 original_url = str (request .url )
@@ -108,7 +127,12 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
108127 # Update host header to match the new URL
109128 request .headers ["host" ] = new_url .host
110129 request .url = new_url
111- return await super ().handle_async_request (request )
130+
131+ response = await super ().handle_async_request (request )
132+ if self .header_capture :
133+ self .header_capture .set (dict (response .headers ))
134+
135+ return response
112136
113137
114138class UiPathChatVertex (ChatGoogleGenerativeAI ):
@@ -162,17 +186,22 @@ def __init__(
162186 uipath_url = self ._build_base_url (model_name )
163187 headers = self ._build_headers (token , agenthub_config , byo_connection_id )
164188
189+ header_capture = HeaderCapture (name = f"vertex_headers_{ id (self )} " )
165190 client_kwargs = get_httpx_client_kwargs ()
166191 verify = client_kwargs .get ("verify" , True )
167192
168193 http_options = genai_types .HttpOptions (
169194 httpx_client = httpx .Client (
170- transport = _UrlRewriteTransport (uipath_url , verify = verify ),
195+ transport = _UrlRewriteTransport (
196+ uipath_url , verify = verify , header_capture = header_capture
197+ ),
171198 headers = headers ,
172199 ** client_kwargs ,
173200 ),
174201 httpx_async_client = httpx .AsyncClient (
175- transport = _AsyncUrlRewriteTransport (uipath_url , verify = verify ),
202+ transport = _AsyncUrlRewriteTransport (
203+ uipath_url , verify = verify , header_capture = header_capture
204+ ),
176205 headers = headers ,
177206 ** client_kwargs ,
178207 ),
@@ -205,6 +234,7 @@ def __init__(
205234 self ._byo_connection_id = byo_connection_id
206235 self ._retryer = retryer
207236 self ._aretryer = aretryer
237+ self ._header_capture = header_capture
208238
209239 if self .temperature is not None and not 0 <= self .temperature <= 2.0 :
210240 raise ValueError ("temperature must be in the range [0.0, 2.0]" )
@@ -295,7 +325,10 @@ def _generate(
295325 result = super ()._generate (
296326 messages , stop = stop , run_manager = run_manager , ** kwargs
297327 )
298- return self ._merge_finish_reason_to_response_metadata (result )
328+ result = self ._merge_finish_reason_to_response_metadata (result )
329+ self ._header_capture .attach_to_chat_result (result )
330+ self ._header_capture .clear ()
331+ return result
299332
300333 async def _agenerate (
301334 self ,
@@ -308,7 +341,40 @@ async def _agenerate(
308341 result = await super ()._agenerate (
309342 messages , stop = stop , run_manager = run_manager , ** kwargs
310343 )
311- return self ._merge_finish_reason_to_response_metadata (result )
344+ result = self ._merge_finish_reason_to_response_metadata (result )
345+ self ._header_capture .attach_to_chat_result (result )
346+ self ._header_capture .clear ()
347+ return result
348+
def _stream(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
    """Stream chunks from the parent model, attaching captured headers.

    Each chunk produced by the parent ``_stream`` is passed to
    ``self._header_capture.attach_to_chat_generation`` before being
    yielded to the caller.

    Args:
        messages: Conversation messages forwarded to the parent stream.
        stop: Optional stop sequences forwarded to the parent stream.
        run_manager: Optional callback manager forwarded to the parent.
        **kwargs: Extra keyword arguments forwarded to the parent stream.

    Yields:
        The parent's chunks, after the header capture has been applied.
    """
    try:
        for chunk in super()._stream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        ):
            self._header_capture.attach_to_chat_generation(chunk)
            yield chunk
    finally:
        # Clear even when the consumer stops iterating early (generator
        # close) or the parent stream raises; otherwise the captured
        # headers would outlive this call.
        # NOTE(review): assumes clear() is safe to call when nothing was
        # captured - confirm against HeaderCapture.
        self._header_capture.clear()
363+
async def _astream(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: AsyncCallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
    """Asynchronously stream parent chunks, attaching captured headers.

    Each chunk produced by the parent ``_astream`` is passed to
    ``self._header_capture.attach_to_chat_generation`` before being
    yielded to the caller.

    Args:
        messages: Conversation messages forwarded to the parent stream.
        stop: Optional stop sequences forwarded to the parent stream.
        run_manager: Optional async callback manager forwarded to the parent.
        **kwargs: Extra keyword arguments forwarded to the parent stream.

    Yields:
        The parent's chunks, after the header capture has been applied.
    """
    try:
        async for chunk in super()._astream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        ):
            self._header_capture.attach_to_chat_generation(chunk)
            yield chunk
    finally:
        # Clear even when the consumer abandons the stream (aclose) or the
        # parent stream raises; otherwise the captured headers would
        # outlive this call. clear() is invoked synchronously, matching
        # its use elsewhere in this class, so no await runs during
        # finalisation.
        # NOTE(review): assumes clear() tolerates being called from the
        # finalising context of an abandoned async generator - confirm
        # against HeaderCapture.
        self._header_capture.clear()
312378
313379
314380def _get_default_retryer () -> VertexRetryer :
0 commit comments