|
30 | 30 | from typing import Any, Literal |
31 | 31 |
|
32 | 32 | from httpx import URL, Response |
| 33 | +from langchain_core.callbacks import ( |
| 34 | + AsyncCallbackManagerForLLMRun, |
| 35 | + CallbackManagerForLLMRun, |
| 36 | +) |
33 | 37 | from langchain_core.embeddings import Embeddings |
34 | 38 | from langchain_core.language_models.chat_models import BaseChatModel |
35 | 39 | from langchain_core.messages import BaseMessage |
@@ -322,72 +326,128 @@ class UiPathBaseChatModel(UiPathBaseLLMClient, BaseChatModel): |
322 | 326 | from the ContextVar (populated by the httpx client's send()) and inject them into |
323 | 327 | the AIMessage's response_metadata under the 'uipath_llmgateway_headers' key. |
324 | 328 |
|
| 329 | + Dynamic request headers are injected via UiPathDynamicHeadersCallback. Keep its |
| 330 | + ``run_inline = True`` setting (already the default) so that LangChain calls |
| 331 | + ``on_chat_model_start`` in the same coroutine as ``_agenerate``, ensuring the |
| 332 | + ContextVar is visible when ``httpx.send()`` fires. |
| 333 | +
|
325 | 334 | Passthrough clients that delegate to vendor SDKs should inherit from this class |
326 | 335 | so that headers are captured transparently. |
327 | 336 | """ |
328 | 337 |
|
329 | 338 | def _generate( |
330 | 339 | self, |
331 | 340 | messages: list[BaseMessage], |
332 | | - *args: Any, |
| 341 | + stop: list[str] | None = None, |
| 342 | + run_manager: CallbackManagerForLLMRun | None = None, |
333 | 343 | **kwargs: Any, |
334 | 344 | ) -> ChatResult: |
335 | 345 | set_captured_response_headers({}) |
336 | 346 | try: |
337 | | - result = super()._generate(messages, *args, **kwargs) |
| 347 | + result = self._uipath_generate(messages, stop=stop, run_manager=run_manager, **kwargs) |
338 | 348 | self._inject_gateway_headers(result.generations) |
339 | 349 | return result |
340 | 350 | finally: |
341 | 351 | set_captured_response_headers({}) |
342 | 352 |
|
| 353 | + def _uipath_generate( |
| 354 | + self, |
| 355 | + messages: list[BaseMessage], |
| 356 | + stop: list[str] | None = None, |
| 357 | + run_manager: CallbackManagerForLLMRun | None = None, |
| 358 | + **kwargs: Any, |
| 359 | + ) -> ChatResult: |
| 360 | + """Override in subclasses to provide the core (non-wrapped) generate logic.""" |
| 361 | + return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs) |
| 362 | + |
343 | 363 | async def _agenerate( |
344 | 364 | self, |
345 | 365 | messages: list[BaseMessage], |
346 | | - *args: Any, |
| 366 | + stop: list[str] | None = None, |
| 367 | + run_manager: AsyncCallbackManagerForLLMRun | None = None, |
347 | 368 | **kwargs: Any, |
348 | 369 | ) -> ChatResult: |
349 | 370 | set_captured_response_headers({}) |
350 | 371 | try: |
351 | | - result = await super()._agenerate(messages, *args, **kwargs) |
| 372 | + result = await self._uipath_agenerate( |
| 373 | + messages, stop=stop, run_manager=run_manager, **kwargs |
| 374 | + ) |
352 | 375 | self._inject_gateway_headers(result.generations) |
353 | 376 | return result |
354 | 377 | finally: |
355 | 378 | set_captured_response_headers({}) |
356 | 379 |
|
| 380 | + async def _uipath_agenerate( |
| 381 | + self, |
| 382 | + messages: list[BaseMessage], |
| 383 | + stop: list[str] | None = None, |
| 384 | + run_manager: AsyncCallbackManagerForLLMRun | None = None, |
| 385 | + **kwargs: Any, |
| 386 | + ) -> ChatResult: |
| 387 | + """Override in subclasses to provide the core (non-wrapped) async generate logic.""" |
| 388 | + return await super()._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs) |
| 389 | + |
357 | 390 | def _stream( |
358 | 391 | self, |
359 | 392 | messages: list[BaseMessage], |
360 | | - *args: Any, |
| 393 | + stop: list[str] | None = None, |
| 394 | + run_manager: CallbackManagerForLLMRun | None = None, |
361 | 395 | **kwargs: Any, |
362 | 396 | ) -> Iterator[ChatGenerationChunk]: |
363 | 397 | set_captured_response_headers({}) |
364 | 398 | try: |
365 | 399 | first = True |
366 | | - for chunk in super()._stream(messages, *args, **kwargs): |
| 400 | + for chunk in self._uipath_stream( |
| 401 | + messages, stop=stop, run_manager=run_manager, **kwargs |
| 402 | + ): |
367 | 403 | if first: |
368 | 404 | self._inject_gateway_headers([chunk]) |
369 | 405 | first = False |
370 | 406 | yield chunk |
371 | 407 | finally: |
372 | 408 | set_captured_response_headers({}) |
373 | 409 |
|
| 410 | + def _uipath_stream( |
| 411 | + self, |
| 412 | + messages: list[BaseMessage], |
| 413 | + stop: list[str] | None = None, |
| 414 | + run_manager: CallbackManagerForLLMRun | None = None, |
| 415 | + **kwargs: Any, |
| 416 | + ) -> Iterator[ChatGenerationChunk]: |
| 417 | + """Override in subclasses to provide the core (non-wrapped) stream logic.""" |
| 418 | + yield from super()._stream(messages, stop=stop, run_manager=run_manager, **kwargs) |
| 419 | + |
374 | 420 | async def _astream( |
375 | 421 | self, |
376 | 422 | messages: list[BaseMessage], |
377 | | - *args: Any, |
| 423 | + stop: list[str] | None = None, |
| 424 | + run_manager: AsyncCallbackManagerForLLMRun | None = None, |
378 | 425 | **kwargs: Any, |
379 | 426 | ) -> AsyncIterator[ChatGenerationChunk]: |
380 | 427 | set_captured_response_headers({}) |
381 | 428 | try: |
382 | 429 | first = True |
383 | | - async for chunk in super()._astream(messages, *args, **kwargs): |
| 430 | + async for chunk in self._uipath_astream( |
| 431 | + messages, stop=stop, run_manager=run_manager, **kwargs |
| 432 | + ): |
384 | 433 | if first: |
385 | 434 | self._inject_gateway_headers([chunk]) |
386 | 435 | first = False |
387 | 436 | yield chunk |
388 | 437 | finally: |
389 | 438 | set_captured_response_headers({}) |
390 | 439 |
|
| 440 | + async def _uipath_astream( |
| 441 | + self, |
| 442 | + messages: list[BaseMessage], |
| 443 | + stop: list[str] | None = None, |
| 444 | + run_manager: AsyncCallbackManagerForLLMRun | None = None, |
| 445 | + **kwargs: Any, |
| 446 | + ) -> AsyncIterator[ChatGenerationChunk]: |
| 447 | + """Override in subclasses to provide the core (non-wrapped) async stream logic.""" |
| 448 | + async for chunk in super()._astream(messages, stop=stop, run_manager=run_manager, **kwargs): |
| 449 | + yield chunk |
| 450 | + |
391 | 451 | def _inject_gateway_headers(self, generations: Sequence[ChatGeneration]) -> None: |
392 | 452 | """Inject captured gateway headers into each generation's response_metadata.""" |
393 | 453 | if not self.captured_headers: |
|
0 commit comments