Skip to content

Commit b2fdda7

Browse files
committed
fixes
1 parent 376e97d commit b2fdda7

4 files changed

Lines changed: 548 additions & 99 deletions

File tree

packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
from typing import Any, Literal
3131

3232
from httpx import URL, Response
33+
from langchain_core.callbacks import (
34+
AsyncCallbackManagerForLLMRun,
35+
CallbackManagerForLLMRun,
36+
)
3337
from langchain_core.embeddings import Embeddings
3438
from langchain_core.language_models.chat_models import BaseChatModel
3539
from langchain_core.messages import BaseMessage
@@ -322,72 +326,128 @@ class UiPathBaseChatModel(UiPathBaseLLMClient, BaseChatModel):
322326
from the ContextVar (populated by the httpx client's send()) and inject them into
323327
the AIMessage's response_metadata under the 'uipath_llmgateway_headers' key.
324328
329+
Dynamic request headers are injected via UiPathDynamicHeadersCallback: keep
330+
``run_inline = True`` (that class's default) so LangChain calls
331+
``on_chat_model_start`` in the same coroutine as ``_agenerate``, ensuring the
332+
ContextVar is visible when the httpx client's ``send()`` fires.
333+
325334
Passthrough clients that delegate to vendor SDKs should inherit from this class
326335
so that headers are captured transparently.
327336
"""
328337

329338
def _generate(
330339
self,
331340
messages: list[BaseMessage],
332-
*args: Any,
341+
stop: list[str] | None = None,
342+
run_manager: CallbackManagerForLLMRun | None = None,
333343
**kwargs: Any,
334344
) -> ChatResult:
335345
set_captured_response_headers({})
336346
try:
337-
result = super()._generate(messages, *args, **kwargs)
347+
result = self._uipath_generate(messages, stop=stop, run_manager=run_manager, **kwargs)
338348
self._inject_gateway_headers(result.generations)
339349
return result
340350
finally:
341351
set_captured_response_headers({})
342352

353+
def _uipath_generate(
354+
self,
355+
messages: list[BaseMessage],
356+
stop: list[str] | None = None,
357+
run_manager: CallbackManagerForLLMRun | None = None,
358+
**kwargs: Any,
359+
) -> ChatResult:
360+
"""Override in subclasses to provide the core (non-wrapped) generate logic."""
361+
return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
362+
343363
async def _agenerate(
344364
self,
345365
messages: list[BaseMessage],
346-
*args: Any,
366+
stop: list[str] | None = None,
367+
run_manager: AsyncCallbackManagerForLLMRun | None = None,
347368
**kwargs: Any,
348369
) -> ChatResult:
349370
set_captured_response_headers({})
350371
try:
351-
result = await super()._agenerate(messages, *args, **kwargs)
372+
result = await self._uipath_agenerate(
373+
messages, stop=stop, run_manager=run_manager, **kwargs
374+
)
352375
self._inject_gateway_headers(result.generations)
353376
return result
354377
finally:
355378
set_captured_response_headers({})
356379

380+
async def _uipath_agenerate(
381+
self,
382+
messages: list[BaseMessage],
383+
stop: list[str] | None = None,
384+
run_manager: AsyncCallbackManagerForLLMRun | None = None,
385+
**kwargs: Any,
386+
) -> ChatResult:
387+
"""Override in subclasses to provide the core (non-wrapped) async generate logic."""
388+
return await super()._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs)
389+
357390
def _stream(
358391
self,
359392
messages: list[BaseMessage],
360-
*args: Any,
393+
stop: list[str] | None = None,
394+
run_manager: CallbackManagerForLLMRun | None = None,
361395
**kwargs: Any,
362396
) -> Iterator[ChatGenerationChunk]:
363397
set_captured_response_headers({})
364398
try:
365399
first = True
366-
for chunk in super()._stream(messages, *args, **kwargs):
400+
for chunk in self._uipath_stream(
401+
messages, stop=stop, run_manager=run_manager, **kwargs
402+
):
367403
if first:
368404
self._inject_gateway_headers([chunk])
369405
first = False
370406
yield chunk
371407
finally:
372408
set_captured_response_headers({})
373409

410+
def _uipath_stream(
411+
self,
412+
messages: list[BaseMessage],
413+
stop: list[str] | None = None,
414+
run_manager: CallbackManagerForLLMRun | None = None,
415+
**kwargs: Any,
416+
) -> Iterator[ChatGenerationChunk]:
417+
"""Override in subclasses to provide the core (non-wrapped) stream logic."""
418+
yield from super()._stream(messages, stop=stop, run_manager=run_manager, **kwargs)
419+
374420
async def _astream(
375421
self,
376422
messages: list[BaseMessage],
377-
*args: Any,
423+
stop: list[str] | None = None,
424+
run_manager: AsyncCallbackManagerForLLMRun | None = None,
378425
**kwargs: Any,
379426
) -> AsyncIterator[ChatGenerationChunk]:
380427
set_captured_response_headers({})
381428
try:
382429
first = True
383-
async for chunk in super()._astream(messages, *args, **kwargs):
430+
async for chunk in self._uipath_astream(
431+
messages, stop=stop, run_manager=run_manager, **kwargs
432+
):
384433
if first:
385434
self._inject_gateway_headers([chunk])
386435
first = False
387436
yield chunk
388437
finally:
389438
set_captured_response_headers({})
390439

440+
async def _uipath_astream(
441+
self,
442+
messages: list[BaseMessage],
443+
stop: list[str] | None = None,
444+
run_manager: AsyncCallbackManagerForLLMRun | None = None,
445+
**kwargs: Any,
446+
) -> AsyncIterator[ChatGenerationChunk]:
447+
"""Override in subclasses to provide the core (non-wrapped) async stream logic."""
448+
async for chunk in super()._astream(messages, stop=stop, run_manager=run_manager, **kwargs):
449+
yield chunk
450+
391451
def _inject_gateway_headers(self, generations: Sequence[ChatGeneration]) -> None:
392452
"""Inject captured gateway headers into each generation's response_metadata."""
393453
if not self.captured_headers:

packages/uipath_langchain_client/src/uipath_langchain_client/callbacks.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,19 @@
22

33
from abc import abstractmethod
44
from typing import Any
5-
from uuid import UUID
65

76
from langchain_core.callbacks import BaseCallbackHandler
8-
from langchain_core.messages import BaseMessage
97

10-
from uipath.llm_client.utils.headers import (
11-
set_dynamic_request_headers,
12-
)
8+
from uipath.llm_client.utils.headers import set_dynamic_request_headers
139

1410

1511
class UiPathDynamicHeadersCallback(BaseCallbackHandler):
1612
"""Base callback for injecting dynamic headers into each LLM gateway request.
1713
1814
Extend this class and implement ``get_headers()`` to return the headers to
19-
inject. The headers are stored in a ContextVar before each LLM call and read
20-
by the httpx client's ``send()`` method, so they flow transparently through
21-
the call stack regardless of which vendor SDK is in use.
15+
inject. ``run_inline = True`` ensures ``on_chat_model_start`` is called
16+
directly in the caller's coroutine (not via ``asyncio.gather``), so the
17+
ContextVar mutation is visible when the httpx client's ``send()`` fires.
2218
2319
Example (OTEL trace propagation)::
2420
@@ -34,6 +30,8 @@ def get_headers(self) -> dict[str, str]:
3430
response = chat.invoke("Hello!", config={"callbacks": [OtelHeadersCallback()]})
3531
"""
3632

33+
run_inline: bool = True # dispatch in the caller's coroutine, not via asyncio.gather
34+
3735
@abstractmethod
3836
def get_headers(self) -> dict[str, str]:
3937
"""Return headers to inject into the next LLM gateway request."""
@@ -42,9 +40,7 @@ def get_headers(self) -> dict[str, str]:
4240
def on_chat_model_start(
4341
self,
4442
serialized: dict[str, Any],
45-
messages: list[list[BaseMessage]],
46-
*,
47-
run_id: UUID,
43+
messages: list[list[Any]],
4844
**kwargs: Any,
4945
) -> None:
5046
set_dynamic_request_headers(self.get_headers())
@@ -53,14 +49,12 @@ def on_llm_start(
5349
self,
5450
serialized: dict[str, Any],
5551
prompts: list[str],
56-
*,
57-
run_id: UUID,
5852
**kwargs: Any,
5953
) -> None:
6054
set_dynamic_request_headers(self.get_headers())
6155

62-
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
56+
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
6357
set_dynamic_request_headers({})
6458

65-
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
59+
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
6660
set_dynamic_request_headers({})

packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py

Lines changed: 42 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,6 @@
5656
)
5757
from pydantic import Field
5858

59-
from uipath.llm_client.utils.headers import (
60-
extract_matching_headers,
61-
set_captured_response_headers,
62-
)
6359
from uipath_langchain_client.base_client import UiPathBaseChatModel
6460
from uipath_langchain_client.settings import ApiType, RoutingMode, UiPathAPIConfig
6561

@@ -311,39 +307,27 @@ def _postprocess_response(self, response: dict[str, Any]) -> ChatResult:
311307
llm_output=llm_output,
312308
)
313309

314-
def _generate(
310+
def _uipath_generate(
315311
self,
316312
messages: list[BaseMessage],
317-
*args: Any,
313+
stop: list[str] | None = None,
318314
run_manager: CallbackManagerForLLMRun | None = None,
319315
**kwargs: Any,
320316
) -> ChatResult:
321-
request_body = self._preprocess_request(messages, **kwargs)
317+
request_body = self._preprocess_request(messages, stop=stop, **kwargs)
322318
response = self.uipath_request(request_body=request_body, raise_status_error=True)
323-
result = self._postprocess_response(response.json())
324-
if self.captured_headers:
325-
captured = extract_matching_headers(response.headers, self.captured_headers)
326-
if captured:
327-
for gen in result.generations:
328-
gen.message.response_metadata["uipath_llmgateway_headers"] = captured
329-
return result
330-
331-
async def _agenerate(
319+
return self._postprocess_response(response.json())
320+
321+
async def _uipath_agenerate(
332322
self,
333323
messages: list[BaseMessage],
334-
*args: Any,
324+
stop: list[str] | None = None,
335325
run_manager: AsyncCallbackManagerForLLMRun | None = None,
336326
**kwargs: Any,
337327
) -> ChatResult:
338-
request_body = self._preprocess_request(messages, **kwargs)
328+
request_body = self._preprocess_request(messages, stop=stop, **kwargs)
339329
response = await self.uipath_arequest(request_body=request_body, raise_status_error=True)
340-
result = self._postprocess_response(response.json())
341-
if self.captured_headers:
342-
captured = extract_matching_headers(response.headers, self.captured_headers)
343-
if captured:
344-
for gen in result.generations:
345-
gen.message.response_metadata["uipath_llmgateway_headers"] = captured
346-
return result
330+
return self._postprocess_response(response.json())
347331

348332
def _generate_chunk(
349333
self, original_message: str, json_data: dict[str, Any]
@@ -402,64 +386,46 @@ def _generate_chunk(
402386
),
403387
)
404388

405-
def _stream(
389+
def _uipath_stream(
406390
self,
407391
messages: list[BaseMessage],
408-
*args: Any,
392+
stop: list[str] | None = None,
409393
run_manager: CallbackManagerForLLMRun | None = None,
410394
**kwargs: Any,
411395
) -> Iterator[ChatGenerationChunk]:
412-
request_body = self._preprocess_request(messages, **kwargs)
413-
set_captured_response_headers({})
414-
try:
415-
first = True
416-
for chunk in self.uipath_stream(
417-
request_body=request_body, stream_type="lines", raise_status_error=True
418-
):
419-
chunk = str(chunk)
420-
if chunk.startswith("data:"):
421-
chunk = chunk.split("data:")[1].strip()
422-
try:
423-
json_data = json.loads(chunk)
424-
except json.JSONDecodeError:
425-
continue
426-
if "id" in json_data and not json_data["id"]:
427-
continue
428-
gen_chunk = self._generate_chunk(chunk, json_data)
429-
if first:
430-
self._inject_gateway_headers([gen_chunk])
431-
first = False
432-
yield gen_chunk
433-
finally:
434-
set_captured_response_headers({})
435-
436-
async def _astream(
396+
request_body = self._preprocess_request(messages, stop=stop, **kwargs)
397+
for chunk in self.uipath_stream(
398+
request_body=request_body, stream_type="lines", raise_status_error=True
399+
):
400+
chunk = str(chunk)
401+
if chunk.startswith("data:"):
402+
chunk = chunk.split("data:", 1)[1].strip()  # maxsplit=1: strip only the SSE prefix; the payload may itself contain "data:"
403+
try:
404+
json_data = json.loads(chunk)
405+
except json.JSONDecodeError:
406+
continue
407+
if "id" in json_data and not json_data["id"]:
408+
continue
409+
yield self._generate_chunk(chunk, json_data)
410+
411+
async def _uipath_astream(
437412
self,
438413
messages: list[BaseMessage],
439-
*args: Any,
414+
stop: list[str] | None = None,
440415
run_manager: AsyncCallbackManagerForLLMRun | None = None,
441416
**kwargs: Any,
442417
) -> AsyncIterator[ChatGenerationChunk]:
443-
request_body = self._preprocess_request(messages, **kwargs)
444-
set_captured_response_headers({})
445-
try:
446-
first = True
447-
async for chunk in self.uipath_astream(
448-
request_body=request_body, stream_type="lines", raise_status_error=True
449-
):
450-
chunk = str(chunk)
451-
if chunk.startswith("data:"):
452-
chunk = chunk.split("data:")[1].strip()
453-
try:
454-
json_data = json.loads(chunk)
455-
except json.JSONDecodeError:
456-
continue
457-
if "id" in json_data and not json_data["id"]:
458-
continue
459-
gen_chunk = self._generate_chunk(chunk, json_data)
460-
if first:
461-
self._inject_gateway_headers([gen_chunk])
462-
first = False
463-
yield gen_chunk
464-
finally:
465-
set_captured_response_headers({})
418+
request_body = self._preprocess_request(messages, stop=stop, **kwargs)
419+
async for chunk in self.uipath_astream(
420+
request_body=request_body, stream_type="lines", raise_status_error=True
421+
):
422+
chunk = str(chunk)
423+
if chunk.startswith("data:"):
424+
chunk = chunk.split("data:", 1)[1].strip()  # maxsplit=1: strip only the SSE prefix; the payload may itself contain "data:"
425+
try:
426+
json_data = json.loads(chunk)
427+
except json.JSONDecodeError:
428+
continue
429+
if "id" in json_data and not json_data["id"]:
430+
continue
431+
yield self._generate_chunk(chunk, json_data)

0 commit comments

Comments
 (0)