|
12 | 12 |
|
13 | 13 | from typing import Any, Final |
14 | 14 |
|
15 | | -from langchain_core.callbacks import Callbacks |
| 15 | +from langchain_core.callbacks import BaseCallbackHandler, Callbacks |
16 | 16 | from langchain_core.language_models import BaseChatModel |
| 17 | +from uipath.llm_client.utils.headers import ( |
| 18 | + get_dynamic_request_headers, |
| 19 | + set_dynamic_request_headers, |
| 20 | +) |
| 21 | +from uipath.platform.chat.llm_trace_context import build_trace_context_headers |
17 | 22 | from uipath_langchain_client.base_client import UiPathBaseChatModel |
18 | 23 | from uipath_langchain_client.factory import get_chat_model as get_chat_model_factory |
19 | 24 | from uipath_langchain_client.settings import ( |
|
23 | 28 | VendorType, |
24 | 29 | ) |
25 | 30 |
|
| 31 | + |
| 32 | +class _TraceContextHeadersCallback(BaseCallbackHandler): |
| 33 | + """Inject W3C-style trace context headers into each LLM gateway request. |
| 34 | +
|
| 35 | + Merges into the existing dynamic-headers ContextVar so that headers |
| 36 | + set by earlier callbacks (e.g. ``LicenseRefIdHeadersCallback``) are |
| 37 | + preserved instead of overwritten. |
| 38 | + """ |
| 39 | + |
| 40 | + run_inline: bool = True |
| 41 | + |
| 42 | + def _merge_headers(self) -> None: |
| 43 | + existing = get_dynamic_request_headers() |
| 44 | + existing.update(build_trace_context_headers(extra_baggage=["source=agents"])) |
| 45 | + set_dynamic_request_headers(existing) |
| 46 | + |
| 47 | + def on_chat_model_start( |
| 48 | + self, serialized: dict[str, Any], messages: list[list[Any]], **kwargs: Any |
| 49 | + ) -> None: |
| 50 | + self._merge_headers() |
| 51 | + |
| 52 | + def on_llm_start( |
| 53 | + self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any |
| 54 | + ) -> None: |
| 55 | + self._merge_headers() |
| 56 | + |
| 57 | + |
26 | 58 | _UNSET: Final[Any] = object() |
27 | 59 | DEFAULT_TIMEOUT_SECONDS: Final[float] = 895.0 |
28 | 60 | DEFAULT_MAX_TOKENS: Final[int] = 1000 |
@@ -84,6 +116,12 @@ def get_chat_model( |
84 | 116 | Returns: |
85 | 117 | A configured ``BaseChatModel`` instance. |
86 | 118 | """ |
| 119 | + # Always inject trace context headers per-request via a dynamic-headers |
| 120 | + # callback. For the new path the UiPathHttpxClient reads the ContextVar |
| 121 | + # set by the callback; for the legacy path the callback is a no-op but |
| 122 | + # keeps the wiring consistent. |
| 123 | + callbacks = _ensure_trace_context_callback(callbacks) |
| 124 | + |
87 | 125 | if not use_new_llm_clients: |
88 | 126 | return _legacy_chat_model( |
89 | 127 | model, |
@@ -120,6 +158,17 @@ def get_chat_model( |
120 | 158 | ) |
121 | 159 |
|
122 | 160 |
|
| 161 | +def _ensure_trace_context_callback(callbacks: Callbacks) -> list[BaseCallbackHandler]: |
| 162 | + """Append a ``_TraceContextHeadersCallback`` if one is not already present.""" |
| 163 | + if callbacks is _UNSET or callbacks is None: |
| 164 | + cb_list: list[BaseCallbackHandler] = [] |
| 165 | + else: |
| 166 | + cb_list = list(callbacks) # type: ignore[arg-type] |
| 167 | + if not any(isinstance(cb, _TraceContextHeadersCallback) for cb in cb_list): |
| 168 | + cb_list.append(_TraceContextHeadersCallback()) |
| 169 | + return cb_list |
| 170 | + |
| 171 | + |
123 | 172 | def _legacy_chat_model( |
124 | 173 | model: str, |
125 | 174 | *, |
|
0 commit comments