diff --git a/.claude/CLAUDE.md b/.claude/CLAUDE.md index 2bf9e39..f6b90c9 100644 --- a/.claude/CLAUDE.md +++ b/.claude/CLAUDE.md @@ -100,10 +100,10 @@ Follow the existing format — newest version first, grouped by date: Before every commit and before opening a PR, always run: ```bash -ruff check && ruff format --check && pytest tests +ruff check && ruff format . && pyright && pytest tests ``` -All three must pass. Fix any lint, format, or test failures before committing. This applies when working as an AI assistant too — run the checks, fix failures, then commit and push. +All four must pass. Fix any lint, format, type, or test failures before committing. This applies when working as an AI assistant too — run the checks, fix failures, then commit and push. --- diff --git a/CHANGELOG.md b/CHANGELOG.md index a041eec..34b8f44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ All notable changes to `uipath_llm_client` (core package) will be documented in this file. +## [1.6.0] - 2026-04-03 + +### Fixed +- Set `api_flavor` to `None` for ANTHROPIC and AZURE vendor types +- Add ANTHROPIC/AZURE cases to validator and remove unused `original_message` parameter +- Fix VertexAI `default_headers` consistency and demo import path +- Fix LLMGateway singleton cache key to include `base_url` + ## [1.5.10] - 2026-03-26 ### Changed @@ -9,17 +17,17 @@ All notable changes to `uipath_llm_client` (core package) will be documented in ## [1.5.9] - 2026-03-26 -### Fix +### Fixed - Use `availableOperationCodes` field (instead of `operationCodes`) when validating BYOM operation codes ## [1.5.8] - 2026-03-26 -### Fix +### Fixed - Pass `base_url` to `OpenAI` and `AsyncOpenAI` constructors in `UiPathOpenAI` and `UiPathAsyncOpenAI` to ensure the correct endpoint is forwarded to the underlying SDK clients ## [1.5.7] - 2026-03-23 -### Fix +### Fixed - Added mapping api_flavor to vendor_type ## [1.5.6] - 2026-03-21 @@ -30,22 +38,22 @@ All notable changes to `uipath_llm_client` (core 
package) will be documented in ## [1.5.5] - 2026-03-19 -### Fix +### Fixed - Fix headers for Platform Settings ## [1.5.3] - 2026-03-18 -### Fix +### Fixed - Factory function fix ## [1.5.2] - 2026-03-18 -### Fix +### Fixed - Factory function fix ## [1.5.1] - 2026-03-17 -### Fix +### Fixed - Added error message for normalized embeddings on UiPath Platform (AgentHub/Orchestrator) as there is no supported endpoint - Fix endpoints for platform to remove api version @@ -67,7 +75,7 @@ All notable changes to `uipath_llm_client` (core package) will be documented in - test updates - new cassettes -### Fixes +### Fixed - Added constants for VendorType and ApiFlavor ## [1.3.0] - 2026-03-10 @@ -89,41 +97,41 @@ All notable changes to `uipath_llm_client` (core package) will be documented in ## [1.2.2] - 2026-02-23 -### Fix +### Fixed - Fixes to discovery endpoint on LLMGW ## [1.2.1] - 2026-02-18 -### Fix -- TImeout fixes, change typing from int to float +### Fixed +- Timeout fixes, change typing from int to float - remove timeout=None from all clients -> caused overriding the default timeout set up on the UiPathHttpxClient ## [1.2.0] - 2026-02-18 ### Stable release -### Fix -- Fixed agenhub auth when token already exists +### Fixed +- Fixed agenthub auth when token already exists ## [1.1.1] - 2026-02-12 -### Fix +### Fixed - Small fixes on openai client ## [1.1.0] - 2026-02-11 ### Stable release -- Adeed BYOM validation for settings +- Added BYOM validation for settings - Stable release ## [1.0.13] - 2026-02-05 -### Fix +### Fixed - Fixed headers on llmgw settings ## [1.0.12] - 2026-02-05 -### Fix +### Fixed - Added 295 as default llmgateway timeout to avoid problems on the backend side ## [1.0.11] - 2026-02-05 @@ -133,25 +141,25 @@ All notable changes to `uipath_llm_client` (core package) will be documented in ## [1.0.10] - 2026-02-04 -### Type fix +### Fixed - Import TypedDict from typing_extension -### Typing Fix +### Fixed - import @override from typing_extension ## [1.0.8] - 
2026-02-04 -### Bug Fix +### Fixed - Fixed a typing issue of Singleton ## [1.0.7] - 2026-02-04 -### Fix +### Fixed - Added py.typed to the package ## [1.0.6] -### Bug Fix +### Fixed - Fixed model discovery on AgentHub Settings. ## [1.0.5] - 2026-02-03 @@ -161,7 +169,7 @@ All notable changes to `uipath_llm_client` (core package) will be documented in ## [1.0.4] - 2026-02-03 -### Bug Fix +### Fixed - Adjusted retry logic, now 0 means no retries, 1 means one retry ## [1.0.3] - 2026-02-02 @@ -171,12 +179,12 @@ All notable changes to `uipath_llm_client` (core package) will be documented in ## [1.0.2] - 2026-02-02 -### Bug Fixes +### Fixed - Fixed endpoints on AgentHub Settings ## [1.0.1] - 2026-01-30 -### Bug Fixes +### Fixed - Map 400 Bad requests on S2S to 401 Unauthorized for better readability ## [1.0.0] - 2026-01-30 diff --git a/packages/uipath_langchain_client/CHANGELOG.md b/packages/uipath_langchain_client/CHANGELOG.md index 2b7bd3c..6ac4201 100644 --- a/packages/uipath_langchain_client/CHANGELOG.md +++ b/packages/uipath_langchain_client/CHANGELOG.md @@ -2,6 +2,11 @@ All notable changes to `uipath_langchain_client` will be documented in this file. 
+## [1.6.0] - 2026-04-03 + +### Fixed +- Version bump to match core package changes + ## [1.5.10] - 2026-03-26 ### Changed @@ -9,17 +14,17 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.5.9] - 2026-03-26 -### Fix +### Fixed - Remove `fix_host_header` event hooks from `UiPathChatOpenAI`; host header management is handled by the underlying httpx client ## [1.5.8] - 2026-03-26 -### Fix +### Fixed - Pass `base_url` to `OpenAI` and `AsyncOpenAI` constructors in `UiPathChatOpenAI` to ensure the correct endpoint is used by the underlying SDK clients ## [1.5.7] - 2026-03-23 -### Fix +### Fixed - Fix factory for BYO to handle the case where vendor_type is None, but api_flavor is discovered ## [1.5.6] - 2026-03-21 @@ -31,27 +36,27 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.5.5] - 2026-03-19 -### Fix headers +### Fixed - Fix headers ## [1.5.4] - 2026-03-19 -### Fix +### Fixed - Fix bedrock clients with file attachments ## [1.5.3] - 2026-03-18 -### Fix +### Fixed - Factory function fix ## [1.5.2] - 2026-03-18 -### Fix +### Fixed - Factory function fix ## [1.5.1] - 2026-03-17 -### Fixes +### Fixed - Fixes to core package, version bump ## [1.5.0] - 2026-03-16 @@ -63,13 +68,13 @@ All notable changes to `uipath_langchain_client` will be documented in this file ### New client - Added UiPathChatAnthropicBedrock -- refactored factory function to se the new client +- refactored factory function to use the new client - brought the enums from the base client ## [1.3.1] - 2026-03-12 -### Fix +### Fixed - Fix normalized client raise error ## [1.3.0] - 2026-03-10 @@ -80,7 +85,7 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.2.7] - 2026-02-26 -### Fix +### Fixed - Fix Bedrock clients model_id ## [1.2.6] - 2026-02-26 @@ -91,12 +96,12 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.2.5] - 2026-02-26 -### Fix +### Fixed - 
Parameters on factory fix ## [1.2.4] - 2026-02-26 -### Fix +### Fixed - Fix typing on factory method ## [1.2.3] - 2026-02-25 @@ -106,13 +111,13 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.2.2] - 2026-02-23 -### Fix +### Fixed - Fixes to discovery endpoint on LLMGW ## [1.2.1] - 2026-02-18 -### Fix -- TImeout fixes, change typing from int to float +### Fixed +- Timeout fixes, change typing from int to float - remove timeout=None from all clients -> caused overriding the default timeout set up on the UiPathHttpxClient ## [1.2.0] - 2026-02-18 @@ -121,7 +126,7 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.1.9] - 2026-02-13 -### Docs +### Changed - Updated documentation ## [1.1.8] - 2026-02-13 @@ -136,34 +141,34 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.1.6] - 2026-02-12 -### Fixes +### Fixed - Added proper type hints for factory method ## [1.1.5] - 2026-02-12 -### Fixes +### Fixed - Fixed bedrock converse api ## [1.1.4] - 2026-02-12 -### Fixes +### Fixed - Fixed anthropic default vendor ## [1.1.3] - 2026-02-12 -### Fixes -- Fixes on openai langchain client on resposes_api -- Allow the flavor to be set up at requst time, not just when instantiating the llm +### Fixed +- Fixes on openai langchain client on responses_api +- Allow the flavor to be set up at request time, not just when instantiating the llm - Some fixes for the anthropic client ## [1.1.2] - 2026-02-12 ### Refactor -- Rename normalized client for better comaptibility with other packages +- Rename normalized client for better compatibility with other packages ## [1.1.1] - 2026-02-11 -### Fixes +### Fixed - Fix langchain fireworks client for async usage ## [1.1.0] - 2026-02-11 @@ -177,17 +182,17 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.0.13] - 2026-02-05 -### Fix +### Fixed - Bump version ## [1.0.12] - 2026-02-05 -### Fix +### 
Fixed - Added 295 as default llmgateway timeout to avoid problems on the backend side ## [1.0.11] - 2026-02-04 -### Type fix +### Fixed - Import TypedDict from typing_extension ## [1.0.10] - 2026-02-04 @@ -197,12 +202,12 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.0.9] - 2026-02-04 -### Fix +### Fixed - Fixed typing in core package, updated dependency ## [1.0.8] - 2026-02-04 -### Fix +### Fixed - Added py.typed to the package ## [1.0.7] - 2026-02-04 @@ -218,27 +223,27 @@ All notable changes to `uipath_langchain_client` will be documented in this file ## [1.0.5] - 2026-02-03 -### Bug Fix +### Fixed - Fixed retry logic on all clients ## [1.0.4] - 2026-02-03 -### Bug Fix +### Fixed - Fix some timout issues on langchain_openai from llmgw. ## [1.0.3] - 2026-02-02 -### Bug Fix +### Fixed - Added better dependencies for langchain-anthropic to include boto and vertex ## [1.0.2] - 2026-02-02 -### Bug Fix +### Fixed - Removed old fix on Gemini streaming and updated with a new cleaner one ## [1.0.1] - 2026-02-02 -### Bug Fix +### Fixed - Fixed Api Version on OpenAI Embeddings ## [1.0.0] - 2026-01-30 diff --git a/packages/uipath_langchain_client/demo.py b/packages/uipath_langchain_client/demo.py index be792d9..e3c887e 100644 --- a/packages/uipath_langchain_client/demo.py +++ b/packages/uipath_langchain_client/demo.py @@ -17,9 +17,7 @@ from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import tool from uipath_langchain_client import get_chat_model, get_embedding_model -from uipath_langchain_client.settings import get_default_client_settings - -from uipath.llm_client.settings.constants import RoutingMode +from uipath_langchain_client.settings import RoutingMode, get_default_client_settings def demo_basic_chat(): @@ -136,8 +134,31 @@ def calculate(expression: str) -> str: Args: expression: A mathematical expression to evaluate (e.g., "2 + 2"). 
""" + import ast + try: - result = eval(expression) + # Restrict to a safe subset: only literals and basic arithmetic operators. + # This prevents arbitrary code execution via eval(). + tree = ast.parse(expression, mode="eval") + allowed_node_types = ( + ast.Expression, + ast.BinOp, + ast.UnaryOp, + ast.Constant, + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.FloorDiv, + ast.Mod, + ast.Pow, + ast.USub, + ast.UAdd, + ) + for node in ast.walk(tree): + if not isinstance(node, allowed_node_types): + return "Error: unsupported operation in expression" + result = eval(compile(tree, "", "eval"), {"__builtins__": {}}) return str(result) except Exception as e: return f"Error: {e}" diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py b/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py index 292aefb..5149f6e 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py @@ -1,3 +1,3 @@ __title__ = "UiPath LangChain Client" __description__ = "A Python client for interacting with UiPath's LLM services via LangChain." 
-__version__ = "1.5.10" +__version__ = "1.6.0" diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py b/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py index 2999441..a79025e 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py @@ -25,7 +25,7 @@ import logging from abc import ABC -from collections.abc import AsyncIterator, Iterator, Mapping, Sequence +from collections.abc import AsyncGenerator, Generator, Mapping, Sequence from functools import cached_property from typing import Any, Literal @@ -189,7 +189,7 @@ def uipath_async_client(self) -> UiPathHttpxAsyncClient: def uipath_request( self, - method: str = "POST", + method: Literal["POST", "GET"] = "POST", url: URL | str = "/", *, request_body: dict[str, Any] | None = None, @@ -199,16 +199,18 @@ def uipath_request( """Make a synchronous HTTP request to the UiPath API. Args: - method: HTTP method (GET, POST, etc.). Defaults to "POST". + method: HTTP method (POST or GET). Defaults to "POST". url: Request URL path. Defaults to "/". request_body: JSON request body to send. + raise_status_error: If True, raises UiPathAPIError on non-2xx responses. **kwargs: Additional arguments passed to httpx.Client.request(). Returns: httpx.Response: The HTTP response from the API. Raises: - UiPathAPIError: On HTTP 4xx/5xx responses (raised by transport layer). + UiPathAPIError: On HTTP 4xx/5xx responses when raise_status_error is True, + or raised by the transport layer. """ response = self.uipath_sync_client.request(method, url, json=request_body, **kwargs) if raise_status_error: @@ -224,7 +226,22 @@ async def uipath_arequest( raise_status_error: bool = False, **kwargs: Any, ) -> Response: - """Make an asynchronous HTTP request to the UiPath API.""" + """Make an asynchronous HTTP request to the UiPath API. + + Args: + method: HTTP method (POST or GET). 
Defaults to "POST". + url: Request URL path. Defaults to "/". + request_body: JSON request body to send. + raise_status_error: If True, raises UiPathAPIError on non-2xx responses. + **kwargs: Additional arguments passed to httpx.AsyncClient.request(). + + Returns: + httpx.Response: The HTTP response from the API. + + Raises: + UiPathAPIError: On HTTP 4xx/5xx responses when raise_status_error is True, + or raised by the transport layer. + """ response = await self.uipath_async_client.request(method, url, json=request_body, **kwargs) if raise_status_error: response.raise_for_status() @@ -239,7 +256,7 @@ def uipath_stream( stream_type: Literal["text", "bytes", "lines", "raw"] = "lines", raise_status_error: bool = False, **kwargs: Any, - ) -> Iterator[str | bytes]: + ) -> Generator[str | bytes, None, None]: """Make a synchronous streaming HTTP request to the UiPath API. Args: @@ -251,6 +268,7 @@ def uipath_stream( - "bytes": Yield raw byte chunks - "lines": Yield complete lines (default, best for SSE) - "raw": Yield raw response data + raise_status_error: If True, raises UiPathAPIError on non-2xx responses. **kwargs: Additional arguments passed to httpx.Client.stream(). Yields: @@ -282,7 +300,7 @@ async def uipath_astream( stream_type: Literal["text", "bytes", "lines", "raw"] = "lines", raise_status_error: bool = False, **kwargs: Any, - ) -> AsyncIterator[str | bytes]: + ) -> AsyncGenerator[str | bytes, None]: """Make an asynchronous streaming HTTP request to the UiPath API. Args: @@ -294,6 +312,7 @@ async def uipath_astream( - "bytes": Yield raw byte chunks - "lines": Yield complete lines (default, best for SSE) - "raw": Yield raw response data + raise_status_error: If True, raises UiPathAPIError on non-2xx responses. **kwargs: Additional arguments passed to httpx.AsyncClient.stream(). 
Yields: @@ -393,7 +412,7 @@ def _stream( stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: + ) -> Generator[ChatGenerationChunk, None, None]: set_captured_response_headers({}) try: first = True @@ -413,7 +432,7 @@ def _uipath_stream( stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: + ) -> Generator[ChatGenerationChunk, None, None]: """Override in subclasses to provide the core (non-wrapped) stream logic.""" yield from super()._stream(messages, stop=stop, run_manager=run_manager, **kwargs) @@ -423,7 +442,7 @@ async def _astream( stop: list[str] | None = None, run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: + ) -> AsyncGenerator[ChatGenerationChunk, None]: set_captured_response_headers({}) try: first = True @@ -443,7 +462,7 @@ async def _uipath_astream( stop: list[str] | None = None, run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: + ) -> AsyncGenerator[ChatGenerationChunk, None]: """Override in subclasses to provide the core (non-wrapped) async stream logic.""" async for chunk in super()._astream(messages, stop=stop, run_manager=run_manager, **kwargs): yield chunk diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/callbacks.py b/packages/uipath_langchain_client/src/uipath_langchain_client/callbacks.py index af99ac1..431b49c 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/callbacks.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/callbacks.py @@ -1,6 +1,6 @@ """LangChain callbacks for dynamic per-request header injection.""" -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Any from langchain_core.callbacks import BaseCallbackHandler @@ -8,7 +8,7 @@ from 
uipath.llm_client.utils.headers import set_dynamic_request_headers -class UiPathDynamicHeadersCallback(BaseCallbackHandler): +class UiPathDynamicHeadersCallback(BaseCallbackHandler, ABC): """Base callback for injecting dynamic headers into each LLM gateway request. Extend this class and implement ``get_headers()`` to return the headers to diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/anthropic/chat_models.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/anthropic/chat_models.py index 52633df..26b87a2 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/anthropic/chat_models.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/anthropic/chat_models.py @@ -45,15 +45,17 @@ class UiPathChatAnthropic(UiPathBaseChatModel, ChatAnthropic): def setup_api_flavor_and_version(self) -> Self: self.api_config.vendor_type = self.vendor_type match self.vendor_type: + case VendorType.ANTHROPIC: + self.api_config.api_flavor = None + case VendorType.AZURE: + self.api_config.api_flavor = None case VendorType.VERTEXAI: self.api_config.api_flavor = ApiFlavor.ANTHROPIC_CLAUDE self.api_config.api_version = "v1beta1" case VendorType.AWSBEDROCK: self.api_config.api_flavor = ApiFlavor.INVOKE case _: - raise ValueError( - "anthropic and azure vendors are currently not supported by UiPath" - ) + raise ValueError(f"Unsupported vendor_type: {self.vendor_type}") return self # Override fields to avoid typing issues and fix stuff @@ -150,13 +152,13 @@ def _async_anthropic_client( raise ValueError("Anthropic models are currently not hosted on any other provider") @override - def _create(self, payload: dict) -> Any: + def _create(self, payload: dict[str, Any]) -> Any: if "betas" in payload: return self._anthropic_client.beta.messages.create(**payload) return self._anthropic_client.messages.create(**payload) @override - async def _acreate(self, payload: dict) -> Any: + async def 
_acreate(self, payload: dict[str, Any]) -> Any: if "betas" in payload: return await self._async_anthropic_client.beta.messages.create(**payload) return await self._async_anthropic_client.messages.create(**payload) diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/azure/chat_models.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/azure/chat_models.py index 3612a09..7af831f 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/azure/chat_models.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/azure/chat_models.py @@ -1,11 +1,11 @@ from typing import Self -from httpx import URL, Request +from httpx import Request from pydantic import Field, model_validator from uipath_langchain_client.base_client import UiPathBaseChatModel +from uipath_langchain_client.clients.openai.utils import fix_url_and_api_flavor_header from uipath_langchain_client.settings import ( - ApiFlavor, ApiType, RoutingMode, UiPathAPIConfig, @@ -42,24 +42,14 @@ class UiPathAzureAIChatCompletionsModel(UiPathBaseChatModel, AzureAIOpenAIApiCha def setup_uipath_client(self) -> Self: base_url = str(self.uipath_sync_client.base_url).rstrip("/") - def fix_url_and_api_flavor_header(request: Request): - url_suffix = str(request.url).split(base_url)[-1] - if "responses" in url_suffix: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.RESPONSES.value - else: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.CHAT_COMPLETIONS.value - request.url = URL(base_url) + def on_request(request: Request) -> None: + fix_url_and_api_flavor_header(base_url, request) - async def fix_url_and_api_flavor_header_async(request: Request): - url_suffix = str(request.url).split(base_url)[-1] - if "responses" in url_suffix: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.RESPONSES.value - else: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.CHAT_COMPLETIONS.value - 
request.url = URL(base_url) + async def on_request_async(request: Request) -> None: + fix_url_and_api_flavor_header(base_url, request) - self.uipath_sync_client.event_hooks["request"].append(fix_url_and_api_flavor_header) - self.uipath_async_client.event_hooks["request"].append(fix_url_and_api_flavor_header_async) + self.uipath_sync_client.event_hooks["request"].append(on_request) + self.uipath_async_client.event_hooks["request"].append(on_request_async) self.root_client = OpenAI( api_key="PLACEHOLDER", diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/bedrock/chat_models.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/bedrock/chat_models.py index f2349b0..8993522 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/bedrock/chat_models.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/bedrock/chat_models.py @@ -48,7 +48,7 @@ def _patched_format_data_content_block(block: dict) -> dict: ) from e -class UiPathChatBedrockConverse(UiPathBaseChatModel, ChatBedrockConverse): +class UiPathChatBedrockConverse(UiPathBaseChatModel, ChatBedrockConverse): # type: ignore[override] api_config: UiPathAPIConfig = UiPathAPIConfig( api_type=ApiType.COMPLETIONS, routing_mode=RoutingMode.PASSTHROUGH, @@ -77,7 +77,7 @@ def setup_uipath_client(self) -> Self: return self -class UiPathChatBedrock(UiPathBaseChatModel, ChatBedrock): +class UiPathChatBedrock(UiPathBaseChatModel, ChatBedrock): # type: ignore[override] api_config: UiPathAPIConfig = UiPathAPIConfig( api_type=ApiType.COMPLETIONS, routing_mode=RoutingMode.PASSTHROUGH, diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/bedrock/utils.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/bedrock/utils.py index a56f371..232c085 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/bedrock/utils.py +++ 
b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/bedrock/utils.py @@ -1,6 +1,7 @@ import base64 import json -from typing import Any, Iterator +from collections.abc import Generator +from typing import Any from httpx import Client @@ -50,7 +51,9 @@ def __init__(self, httpx_client: Client | None = None, region_name: str = "PLACE self.httpx_client = httpx_client self.meta = _MockClientMeta(region_name=region_name) - def _stream_generator(self, request_body: dict[str, Any]) -> Iterator[dict[str, Any]]: + def _stream_generator( + self, request_body: dict[str, Any] + ) -> Generator[dict[str, Any], None, None]: if self.httpx_client is None: raise ValueError("httpx_client is not set") with self.httpx_client.stream("POST", "/", json=_serialize_bytes(request_body)) as response: @@ -71,15 +74,19 @@ def invoke_model(self, **kwargs: Any) -> Any: return { "body": self.httpx_client.post( "/", - json=json.loads(kwargs.get("body", {})), + json=json.loads(kwargs.get("body", "{}")), ) } def invoke_model_with_response_stream(self, **kwargs: Any) -> Any: - return {"body": self._stream_generator(json.loads(kwargs.get("body", {})))} + return {"body": self._stream_generator(json.loads(kwargs.get("body", "{}")))} def converse( - self, *, messages: list[dict[str, Any]], system: str | None = None, **params: Any + self, + *, + messages: list[dict[str, Any]], + system: list[dict[str, Any]] | None = None, + **params: Any, ) -> Any: if self.httpx_client is None: raise ValueError("httpx_client is not set") @@ -95,7 +102,11 @@ def converse( ).json() def converse_stream( - self, *, messages: list[dict[str, Any]], system: str | None = None, **params: Any + self, + *, + messages: list[dict[str, Any]], + system: list[dict[str, Any]] | None = None, + **params: Any, ) -> Any: return { "stream": self._stream_generator( diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/fireworks/embeddings.py 
b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/fireworks/embeddings.py index b4af2bc..161b71d 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/fireworks/embeddings.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/fireworks/embeddings.py @@ -1,10 +1,9 @@ from typing import Self -from pydantic import model_validator +from pydantic import Field, model_validator from uipath_langchain_client.base_client import UiPathBaseEmbeddings from uipath_langchain_client.settings import ( - ApiFlavor, ApiType, RoutingMode, UiPathAPIConfig, @@ -26,11 +25,12 @@ class UiPathFireworksEmbeddings(UiPathBaseEmbeddings, FireworksEmbeddings): api_type=ApiType.EMBEDDINGS, routing_mode=RoutingMode.PASSTHROUGH, vendor_type=VendorType.OPENAI, - api_flavor=ApiFlavor.CHAT_COMPLETIONS, api_version="2025-03-01-preview", freeze_base_url=True, ) + model: str = Field(default="", alias="model_name") + @model_validator(mode="after") def setup_uipath_client(self) -> Self: self.client = OpenAI( @@ -48,7 +48,8 @@ def setup_uipath_client(self) -> Self: def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs.""" return [ - i.embedding for i in self.client.embeddings.create(input=texts, model=self.model).data + i.embedding + for i in self.client.embeddings.create(input=texts, model=self.model_name).data ] def embed_query(self, text: str) -> list[float]: @@ -59,7 +60,9 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs asynchronously.""" return [ i.embedding - for i in (await self.async_client.embeddings.create(input=texts, model=self.model)).data + for i in ( + await self.async_client.embeddings.create(input=texts, model=self.model_name) + ).data ] async def aembed_query(self, text: str) -> list[float]: diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py 
b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py index 582181c..47e325b 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/chat_models.py @@ -24,7 +24,7 @@ """ import json -from collections.abc import AsyncIterator, Callable, Iterator, Sequence +from collections.abc import AsyncGenerator, Callable, Generator, Sequence from typing import Any from langchain_core.callbacks import ( @@ -155,7 +155,6 @@ def _default_params(self) -> dict[str, Any]: } return { - "model": self.model_name, **{k: v for k, v in exclude_if_none.items() if v is not None}, **self.model_kwargs, } @@ -329,9 +328,7 @@ async def _uipath_agenerate( response = await self.uipath_arequest(request_body=request_body, raise_status_error=True) return self._postprocess_response(response.json()) - def _generate_chunk( - self, original_message: str, json_data: dict[str, Any] - ) -> ChatGenerationChunk: + def _generate_chunk(self, json_data: dict[str, Any]) -> ChatGenerationChunk: generation_info = { "id": json_data.get("id"), "created": json_data.get("created", ""), @@ -377,10 +374,10 @@ def _generate_chunk( ) return ChatGenerationChunk( - text=original_message, + text=content or "", generation_info=generation_info, message=AIMessageChunk( - content=content, + content=content or "", usage_metadata=usage_metadata, tool_call_chunks=tool_call_chunks, ), @@ -392,21 +389,22 @@ def _uipath_stream( stop: list[str] | None = None, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: + ) -> Generator[ChatGenerationChunk, None, None]: request_body = self._preprocess_request(messages, stop=stop, **kwargs) + request_body["stream"] = True for chunk in self.uipath_stream( request_body=request_body, stream_type="lines", raise_status_error=True ): chunk = str(chunk) if 
chunk.startswith("data:"): - chunk = chunk.split("data:")[1].strip() + chunk = chunk[len("data:") :].strip() try: json_data = json.loads(chunk) except json.JSONDecodeError: continue if "id" in json_data and not json_data["id"]: continue - yield self._generate_chunk(chunk, json_data) + yield self._generate_chunk(json_data) async def _uipath_astream( self, @@ -414,18 +412,19 @@ async def _uipath_astream( stop: list[str] | None = None, run_manager: AsyncCallbackManagerForLLMRun | None = None, **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: + ) -> AsyncGenerator[ChatGenerationChunk, None]: request_body = self._preprocess_request(messages, stop=stop, **kwargs) + request_body["stream"] = True async for chunk in self.uipath_astream( request_body=request_body, stream_type="lines", raise_status_error=True ): chunk = str(chunk) if chunk.startswith("data:"): - chunk = chunk.split("data:")[1].strip() + chunk = chunk[len("data:") :].strip() try: json_data = json.loads(chunk) except json.JSONDecodeError: continue if "id" in json_data and not json_data["id"]: continue - yield self._generate_chunk(chunk, json_data) + yield self._generate_chunk(json_data) diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/embeddings.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/embeddings.py index 86498a4..e77f1d6 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/embeddings.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/normalized/embeddings.py @@ -1,10 +1,8 @@ -from langchain_core.embeddings import Embeddings - from uipath_langchain_client.base_client import UiPathBaseEmbeddings from uipath_langchain_client.settings import ApiType, RoutingMode, UiPathAPIConfig -class UiPathEmbeddings(UiPathBaseEmbeddings, Embeddings): +class UiPathEmbeddings(UiPathBaseEmbeddings): """LangChain embeddings using the UiPath's normalized embeddings API. 
Provides a consistent interface for generating text embeddings across all @@ -18,14 +16,20 @@ class UiPathEmbeddings(UiPathBaseEmbeddings, Embeddings): ) def embed_documents(self, texts: list[str]) -> list[list[float]]: - response = self.uipath_request(request_body={"input": texts}) + response = self.uipath_request( + request_body={"input": texts}, + raise_status_error=True, + ) return [r["embedding"] for r in response.json()["data"]] def embed_query(self, text: str) -> list[float]: return self.embed_documents([text])[0] async def aembed_documents(self, texts: list[str]) -> list[list[float]]: - response = await self.uipath_arequest(request_body={"input": texts}) + response = await self.uipath_arequest( + request_body={"input": texts}, + raise_status_error=True, + ) return [r["embedding"] for r in response.json()["data"]] async def aembed_query(self, text: str) -> list[float]: diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/openai/chat_models.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/openai/chat_models.py index b448867..4278c1f 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/openai/chat_models.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/openai/chat_models.py @@ -1,12 +1,12 @@ from collections.abc import Awaitable, Callable from typing import Self -from httpx import URL, Request +from httpx import Request from pydantic import Field, SecretStr, model_validator from uipath_langchain_client.base_client import UiPathBaseChatModel +from uipath_langchain_client.clients.openai.utils import fix_url_and_api_flavor_header from uipath_langchain_client.settings import ( - ApiFlavor, ApiType, RoutingMode, UiPathAPIConfig, @@ -41,24 +41,14 @@ class UiPathChatOpenAI(UiPathBaseChatModel, ChatOpenAI): # type: ignore[overrid def setup_uipath_client(self) -> Self: base_url = str(self.uipath_sync_client.base_url).rstrip("/") - def 
fix_url_and_api_flavor_header(request: Request): - url_suffix = str(request.url).split(base_url)[-1] - if "responses" in url_suffix: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.RESPONSES.value - else: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.CHAT_COMPLETIONS.value - request.url = URL(base_url) - - async def fix_url_and_api_flavor_header_async(request: Request): - url_suffix = str(request.url).split(base_url)[-1] - if "responses" in url_suffix: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.RESPONSES.value - else: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.CHAT_COMPLETIONS.value - request.url = URL(base_url) - - self.uipath_sync_client.event_hooks["request"].append(fix_url_and_api_flavor_header) - self.uipath_async_client.event_hooks["request"].append(fix_url_and_api_flavor_header_async) + def on_request(request: Request) -> None: + fix_url_and_api_flavor_header(base_url, request) + + async def on_request_async(request: Request) -> None: + fix_url_and_api_flavor_header(base_url, request) + + self.uipath_sync_client.event_hooks["request"].append(on_request) + self.uipath_async_client.event_hooks["request"].append(on_request_async) self.root_client = OpenAI( api_key="PLACEHOLDER", @@ -95,24 +85,14 @@ class UiPathAzureChatOpenAI(UiPathBaseChatModel, AzureChatOpenAI): # type: igno def setup_uipath_client(self) -> Self: base_url = str(self.uipath_sync_client.base_url).rstrip("/") - def fix_url_and_api_flavor_header(request: Request): - url_suffix = str(request.url).split(base_url)[-1] - if "responses" in url_suffix: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.RESPONSES.value - else: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.CHAT_COMPLETIONS.value - request.url = URL(base_url) - - async def fix_url_and_api_flavor_header_async(request: Request): - url_suffix = str(request.url).split(base_url)[-1] - if "responses" in url_suffix: - 
request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.RESPONSES.value - else: - request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.CHAT_COMPLETIONS.value - request.url = URL(base_url) - - self.uipath_sync_client.event_hooks["request"].append(fix_url_and_api_flavor_header) - self.uipath_async_client.event_hooks["request"].append(fix_url_and_api_flavor_header_async) + def on_request(request: Request) -> None: + fix_url_and_api_flavor_header(base_url, request) + + async def on_request_async(request: Request) -> None: + fix_url_and_api_flavor_header(base_url, request) + + self.uipath_sync_client.event_hooks["request"].append(on_request) + self.uipath_async_client.event_hooks["request"].append(on_request_async) self.root_client = AzureOpenAI( azure_endpoint="PLACEHOLDER", diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/openai/utils.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/openai/utils.py new file mode 100644 index 0000000..03ecbbb --- /dev/null +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/openai/utils.py @@ -0,0 +1,26 @@ +"""Shared utilities for UiPath LangChain provider clients.""" + +from httpx import URL, Request + +from uipath_langchain_client.settings import ApiFlavor + + +def fix_url_and_api_flavor_header(base_url: str, request: Request) -> None: + """Detect API flavor from URL suffix and rewrite the URL to the base gateway URL. + + Inspects the outgoing request URL to determine whether it targets the + OpenAI *responses* or *chat completions* endpoint and sets the + ``X-UiPath-LlmGateway-ApiFlavor`` header accordingly. The request URL + is then collapsed back to *base_url* so that the gateway receives a + clean path. + + Args: + base_url: The UiPath gateway base URL to rewrite the request to. + request: The outgoing httpx request (mutated in place). 
+ """ + url_suffix = str(request.url).split(base_url)[-1] + if "responses" in url_suffix: + request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.RESPONSES.value + else: + request.headers["X-UiPath-LlmGateway-ApiFlavor"] = ApiFlavor.CHAT_COMPLETIONS.value + request.url = URL(base_url) diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/vertexai/chat_models.py b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/vertexai/chat_models.py index 1d3d242..d6f1db6 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/clients/vertexai/chat_models.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/clients/vertexai/chat_models.py @@ -41,7 +41,7 @@ def setup_uipath_client(self) -> Self: project_id="PLACEHOLDER", access_token="PLACEHOLDER", base_url=str(self.uipath_sync_client.base_url), - default_headers=self.uipath_sync_client.headers, + default_headers=dict(self.uipath_sync_client.headers), max_retries=0, # handled by the UiPath client http_client=self.uipath_sync_client, ) @@ -50,7 +50,7 @@ def setup_uipath_client(self) -> Self: project_id="PLACEHOLDER", access_token="PLACEHOLDER", base_url=str(self.uipath_async_client.base_url), - default_headers=self.uipath_async_client.headers, + default_headers=dict(self.uipath_async_client.headers), max_retries=0, # handled by the UiPath client http_client=self.uipath_async_client, ) diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py b/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py index 9a5804b..4e98d63 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py @@ -27,7 +27,7 @@ UiPathBaseEmbeddings, ) from uipath_langchain_client.settings import ( - _API_FLAVOR_TO_VENDOR_TYPE, + API_FLAVOR_TO_VENDOR_TYPE, ApiFlavor, RoutingMode, UiPathBaseSettings, @@ -73,7 +73,7 @@ def 
_get_model_info( if not matching_models: raise ValueError( - f"Model {model_name} not found in available models the available models are: {[m['modelName'] for m in available_models]}" + f"Model {model_name} not found. Available models are: {[m['modelName'] for m in available_models]}" ) return matching_models[0] @@ -143,7 +143,7 @@ def get_chat_model( discovered_vendor_type = model_info.get("vendor", None) discovered_api_flavor = model_info.get("apiFlavor", None) if discovered_vendor_type is None and discovered_api_flavor is not None: - discovered_vendor_type = _API_FLAVOR_TO_VENDOR_TYPE.get(discovered_api_flavor, None) + discovered_vendor_type = API_FLAVOR_TO_VENDOR_TYPE.get(discovered_api_flavor, None) if discovered_vendor_type is None: raise ValueError("No vendor type or api flavor found in model info") discovered_vendor_type = discovered_vendor_type.lower() @@ -298,7 +298,17 @@ def get_embedding_model( **model_kwargs, ) - discovered_vendor_type = model_info["vendor"].lower() + discovered_vendor_type = model_info.get("vendor") + if discovered_vendor_type is None: + discovered_api_flavor = model_info.get("apiFlavor") + if discovered_api_flavor is not None: + discovered_vendor_type = API_FLAVOR_TO_VENDOR_TYPE.get(discovered_api_flavor) + if discovered_vendor_type is None: + raise ValueError( + f"No vendor type found in model info for embedding model '{model_name}'. 
" + f"Model info returned: {model_info}" + ) + discovered_vendor_type = discovered_vendor_type.lower() match discovered_vendor_type: case VendorType.OPENAI: if is_uipath_owned: diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/settings.py b/packages/uipath_langchain_client/src/uipath_langchain_client/settings.py index e8fc87a..32c9a4b 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/settings.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/settings.py @@ -23,7 +23,7 @@ get_default_client_settings, ) from uipath.llm_client.settings.constants import ( - _API_FLAVOR_TO_VENDOR_TYPE, + API_FLAVOR_TO_VENDOR_TYPE, ApiFlavor, ApiType, RoutingMode, @@ -40,5 +40,5 @@ "RoutingMode", "ApiFlavor", "VendorType", - "_API_FLAVOR_TO_VENDOR_TYPE", + "API_FLAVOR_TO_VENDOR_TYPE", ] diff --git a/src/uipath/llm_client/__init__.py b/src/uipath/llm_client/__init__.py index 342ef49..b437cef 100644 --- a/src/uipath/llm_client/__init__.py +++ b/src/uipath/llm_client/__init__.py @@ -9,20 +9,22 @@ - uipath_llamaindex_client: LlamaIndex-compatible models Quick Start: - >>> from uipath.llm_client import UiPathBaseLLMClient, UiPathAPIConfig - >>> from uipath.llm_client.settings import get_default_client_settings + >>> from uipath.llm_client import UiPathHttpxClient + >>> from uipath.llm_client.settings import get_default_client_settings, UiPathAPIConfig + >>> from uipath.llm_client.settings.constants import ApiType, RoutingMode >>> >>> settings = get_default_client_settings() - >>> client = UiPathBaseLLMClient( - ... model="gpt-4o-2024-11-20", - ... api_config=UiPathAPIConfig( - ... api_type=ApiType.COMPLETIONS, - ... routing_mode=RoutingMode.PASSTHROUGH, - ... vendor_type="openai", - ... ), - ... settings=settings, + >>> api_config = UiPathAPIConfig( + ... api_type=ApiType.COMPLETIONS, + ... routing_mode=RoutingMode.PASSTHROUGH, + ... vendor_type="openai", + ... ) + >>> client = UiPathHttpxClient( + ... 
model_name="gpt-4o-2024-11-20", + ... api_config=api_config, + ... base_url=settings.build_base_url(model_name="gpt-4o-2024-11-20", api_config=api_config), + ... auth=settings.build_auth_pipeline(), ... ) - >>> response = client.uipath_request(request_body={...}) """ from uipath.llm_client.__version__ import __version__ diff --git a/src/uipath/llm_client/__version__.py b/src/uipath/llm_client/__version__.py index 7da9182..c5bf7a8 100644 --- a/src/uipath/llm_client/__version__.py +++ b/src/uipath/llm_client/__version__.py @@ -1,3 +1,3 @@ __title__ = "UiPath LLM Client" __description__ = "A Python client for interacting with UiPath's LLM services." -__version__ = "1.5.10" +__version__ = "1.6.0" diff --git a/src/uipath/llm_client/clients/anthropic/client.py b/src/uipath/llm_client/clients/anthropic/client.py index 1c5cdfc..44bb318 100644 --- a/src/uipath/llm_client/clients/anthropic/client.py +++ b/src/uipath/llm_client/clients/anthropic/client.py @@ -15,9 +15,9 @@ """ import logging -from typing import Any +from collections.abc import Mapping, Sequence -from uipath.llm_client.httpx_client import UiPathHttpxAsyncClient, UiPathHttpxClient +from uipath.llm_client.clients.utils import build_httpx_async_client, build_httpx_client from uipath.llm_client.settings import ( UiPathAPIConfig, UiPathBaseSettings, @@ -57,16 +57,16 @@ def _build_api_config(vendor_type: str | VendorType = VendorType.ANTHROPIC) -> U class UiPathAnthropic(Anthropic): """Anthropic client routed through UiPath LLM Gateway. - Wraps the standard Anthropic client to route requests through UiPath's - LLM Gateway while preserving the full Anthropic SDK interface. - Args: model_name: The Anthropic model name (e.g., "claude-3-5-sonnet-20241022"). byo_connection_id: Bring Your Own connection ID for custom deployments. client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. 
+ default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). retry_config: Custom retry configuration. logger: Logger instance for request/response logging. - **kwargs: Additional arguments passed to Anthropic client. """ def __init__( @@ -75,48 +75,45 @@ def __init__( model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - api_config = _build_api_config() - httpx_client = UiPathHttpxClient( - model_name=model_name, - byo_connection_id=byo_connection_id, - api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), - retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, - logger=logger, - auth=client_settings.build_auth_pipeline(), - ) super().__init__( api_key="PLACEHOLDER", max_retries=0, - http_client=httpx_client, - **kwargs, + http_client=build_httpx_client( + model_name=model_name, + byo_connection_id=byo_connection_id, + client_settings=client_settings, + api_config=_build_api_config(), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, + retry_config=retry_config, + logger=logger, + ), ) class UiPathAsyncAnthropic(AsyncAnthropic): """Async Anthropic client routed through UiPath LLM Gateway. 
- Wraps the standard AsyncAnthropic client to route requests through UiPath's - LLM Gateway while preserving the full Anthropic SDK interface. - Args: model_name: The Anthropic model name (e.g., "claude-3-5-sonnet-20241022"). byo_connection_id: Bring Your Own connection ID for custom deployments. client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). retry_config: Custom retry configuration. logger: Logger instance for request/response logging. - **kwargs: Additional arguments passed to AsyncAnthropic client. """ def __init__( @@ -125,48 +122,45 @@ def __init__( model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - api_config = _build_api_config() - httpx_client = UiPathHttpxAsyncClient( - model_name=model_name, - byo_connection_id=byo_connection_id, - api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), - retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, - logger=logger, - auth=client_settings.build_auth_pipeline(), - ) super().__init__( api_key="PLACEHOLDER", max_retries=0, - http_client=httpx_client, - **kwargs, + http_client=build_httpx_async_client( 
+ model_name=model_name, + byo_connection_id=byo_connection_id, + client_settings=client_settings, + api_config=_build_api_config(), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, + retry_config=retry_config, + logger=logger, + ), ) class UiPathAnthropicBedrock(AnthropicBedrock): """Anthropic Bedrock client routed through UiPath LLM Gateway. - Wraps the AnthropicBedrock client to route requests through UiPath's - LLM Gateway while preserving the full Anthropic SDK interface. - Args: model_name: The Anthropic model name. byo_connection_id: Bring Your Own connection ID for custom deployments. client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). retry_config: Custom retry configuration. logger: Logger instance for request/response logging. - **kwargs: Additional arguments passed to AnthropicBedrock client. 
""" def __init__( @@ -175,50 +169,47 @@ def __init__( model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - api_config = _build_api_config(vendor_type=VendorType.AWSBEDROCK) - httpx_client = UiPathHttpxClient( - model_name=model_name, - byo_connection_id=byo_connection_id, - api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), - retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, - logger=logger, - auth=client_settings.build_auth_pipeline(), - ) super().__init__( aws_access_key="PLACEHOLDER", aws_secret_key="PLACEHOLDER", aws_region="PLACEHOLDER", max_retries=0, - http_client=httpx_client, - **kwargs, + http_client=build_httpx_client( + model_name=model_name, + byo_connection_id=byo_connection_id, + client_settings=client_settings, + api_config=_build_api_config(vendor_type=VendorType.AWSBEDROCK), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, + retry_config=retry_config, + logger=logger, + ), ) class UiPathAsyncAnthropicBedrock(AsyncAnthropicBedrock): """Async Anthropic Bedrock client routed through UiPath LLM Gateway. - Wraps the AsyncAnthropicBedrock client to route requests through UiPath's - LLM Gateway while preserving the full Anthropic SDK interface. - Args: model_name: The Anthropic model name. 
byo_connection_id: Bring Your Own connection ID for custom deployments. client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). retry_config: Custom retry configuration. logger: Logger instance for request/response logging. - **kwargs: Additional arguments passed to AsyncAnthropicBedrock client. """ def __init__( @@ -227,50 +218,47 @@ def __init__( model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - api_config = _build_api_config(vendor_type=VendorType.AWSBEDROCK) - httpx_client = UiPathHttpxAsyncClient( - model_name=model_name, - byo_connection_id=byo_connection_id, - api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), - retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, - logger=logger, - auth=client_settings.build_auth_pipeline(), - ) super().__init__( aws_access_key="PLACEHOLDER", aws_secret_key="PLACEHOLDER", aws_region="PLACEHOLDER", max_retries=0, - http_client=httpx_client, - **kwargs, + http_client=build_httpx_async_client( + model_name=model_name, + byo_connection_id=byo_connection_id, + client_settings=client_settings, + 
api_config=_build_api_config(vendor_type=VendorType.AWSBEDROCK), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, + retry_config=retry_config, + logger=logger, + ), ) class UiPathAnthropicVertex(AnthropicVertex): """Anthropic Vertex client routed through UiPath LLM Gateway. - Wraps the AnthropicVertex client to route requests through UiPath's - LLM Gateway while preserving the full Anthropic SDK interface. - Args: model_name: The Anthropic model name. byo_connection_id: Bring Your Own connection ID for custom deployments. client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). retry_config: Custom retry configuration. logger: Logger instance for request/response logging. - **kwargs: Additional arguments passed to AnthropicVertex client. 
""" def __init__( @@ -279,50 +267,47 @@ def __init__( model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - api_config = _build_api_config(vendor_type=VendorType.VERTEXAI) - httpx_client = UiPathHttpxClient( - model_name=model_name, - byo_connection_id=byo_connection_id, - api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), - retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, - logger=logger, - auth=client_settings.build_auth_pipeline(), - ) super().__init__( region="PLACEHOLDER", project_id="PLACEHOLDER", access_token="PLACEHOLDER", max_retries=0, - http_client=httpx_client, - **kwargs, + http_client=build_httpx_client( + model_name=model_name, + byo_connection_id=byo_connection_id, + client_settings=client_settings, + api_config=_build_api_config(vendor_type=VendorType.VERTEXAI), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, + retry_config=retry_config, + logger=logger, + ), ) class UiPathAsyncAnthropicVertex(AsyncAnthropicVertex): """Async Anthropic Vertex client routed through UiPath LLM Gateway. - Wraps the AsyncAnthropicVertex client to route requests through UiPath's - LLM Gateway while preserving the full Anthropic SDK interface. - Args: model_name: The Anthropic model name. byo_connection_id: Bring Your Own connection ID for custom deployments. 
client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). retry_config: Custom retry configuration. logger: Logger instance for request/response logging. - **kwargs: Additional arguments passed to AsyncAnthropicVertex client. """ def __init__( @@ -331,50 +316,47 @@ def __init__( model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - api_config = _build_api_config(vendor_type=VendorType.VERTEXAI) - httpx_client = UiPathHttpxAsyncClient( - model_name=model_name, - byo_connection_id=byo_connection_id, - api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), - retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, - logger=logger, - auth=client_settings.build_auth_pipeline(), - ) super().__init__( region="PLACEHOLDER", project_id="PLACEHOLDER", access_token="PLACEHOLDER", max_retries=0, - http_client=httpx_client, - **kwargs, + http_client=build_httpx_async_client( + model_name=model_name, + byo_connection_id=byo_connection_id, + client_settings=client_settings, + api_config=_build_api_config(vendor_type=VendorType.VERTEXAI), + timeout=timeout, + 
max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, + retry_config=retry_config, + logger=logger, + ), ) class UiPathAnthropicFoundry(AnthropicFoundry): """Anthropic Foundry (Azure) client routed through UiPath LLM Gateway. - Wraps the AnthropicFoundry client to route requests through UiPath's - LLM Gateway while preserving the full Anthropic SDK interface. - Args: model_name: The Anthropic model name. byo_connection_id: Bring Your Own connection ID for custom deployments. client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). retry_config: Custom retry configuration. logger: Logger instance for request/response logging. - **kwargs: Additional arguments passed to AnthropicFoundry client. 
""" def __init__( @@ -383,48 +365,45 @@ def __init__( model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - api_config = _build_api_config(vendor_type=VendorType.AZURE) - httpx_client = UiPathHttpxClient( - model_name=model_name, - byo_connection_id=byo_connection_id, - api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), - retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, - logger=logger, - auth=client_settings.build_auth_pipeline(), - ) super().__init__( api_key="PLACEHOLDER", max_retries=0, - http_client=httpx_client, - **kwargs, + http_client=build_httpx_client( + model_name=model_name, + byo_connection_id=byo_connection_id, + client_settings=client_settings, + api_config=_build_api_config(vendor_type=VendorType.AZURE), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, + retry_config=retry_config, + logger=logger, + ), ) class UiPathAsyncAnthropicFoundry(AsyncAnthropicFoundry): """Async Anthropic Foundry (Azure) client routed through UiPath LLM Gateway. - Wraps the AsyncAnthropicFoundry client to route requests through UiPath's - LLM Gateway while preserving the full Anthropic SDK interface. - Args: model_name: The Anthropic model name. byo_connection_id: Bring Your Own connection ID for custom deployments. client_settings: UiPath client settings. 
Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). retry_config: Custom retry configuration. logger: Logger instance for request/response logging. - **kwargs: Additional arguments passed to AsyncAnthropicFoundry client. """ def __init__( @@ -433,30 +412,27 @@ def __init__( model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - api_config = _build_api_config(vendor_type=VendorType.AZURE) - httpx_client = UiPathHttpxAsyncClient( - model_name=model_name, - byo_connection_id=byo_connection_id, - api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), - retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, - logger=logger, - auth=client_settings.build_auth_pipeline(), - ) super().__init__( api_key="PLACEHOLDER", max_retries=0, - http_client=httpx_client, - **kwargs, + http_client=build_httpx_async_client( + model_name=model_name, + byo_connection_id=byo_connection_id, + client_settings=client_settings, + api_config=_build_api_config(vendor_type=VendorType.AZURE), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, + 
retry_config=retry_config, + logger=logger, + ), ) diff --git a/src/uipath/llm_client/clients/google/client.py b/src/uipath/llm_client/clients/google/client.py index 6e1ae88..7f3406c 100644 --- a/src/uipath/llm_client/clients/google/client.py +++ b/src/uipath/llm_client/clients/google/client.py @@ -1,7 +1,7 @@ import logging -from typing import Any +from collections.abc import Mapping, Sequence -from uipath.llm_client.httpx_client import UiPathHttpxAsyncClient, UiPathHttpxClient +from uipath.llm_client.clients.utils import build_httpx_async_client, build_httpx_client from uipath.llm_client.settings import ( UiPathAPIConfig, UiPathBaseSettings, @@ -21,15 +21,32 @@ class UiPathGoogle(Client): + """Google GenAI client routed through UiPath LLM Gateway. + + Args: + model_name: The Google model name (e.g., "gemini-2.5-flash"). + byo_connection_id: Bring Your Own connection ID for custom deployments. + client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). + retry_config: Custom retry configuration. + logger: Logger instance for request/response logging. 
+ """ + def __init__( self, *, model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() api_config = UiPathAPIConfig( @@ -40,35 +57,29 @@ def __init__( api_version="v1beta1", freeze_base_url=True, ) - httpx_client = UiPathHttpxClient( + httpx_client = build_httpx_client( model_name=model_name, byo_connection_id=byo_connection_id, + client_settings=client_settings, api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, logger=logger, - auth=client_settings.build_auth_pipeline(), ) - httpx_async_client = UiPathHttpxAsyncClient( + httpx_async_client = build_httpx_async_client( model_name=model_name, byo_connection_id=byo_connection_id, + client_settings=client_settings, api_config=api_config, - timeout=kwargs.pop("timeout", None), - max_retries=kwargs.pop("max_retries", None), + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + captured_headers=captured_headers, retry_config=retry_config, - base_url=client_settings.build_base_url(model_name=model_name, api_config=api_config), - headers={ - **kwargs.pop("default_headers", {}), - **client_settings.build_auth_headers(model_name=model_name, api_config=api_config), - }, logger=logger, - 
auth=client_settings.build_auth_pipeline(), ) super().__init__( api_key="PLACEHOLDER", diff --git a/src/uipath/llm_client/clients/openai/client.py b/src/uipath/llm_client/clients/openai/client.py index cb9b82d..b1c893b 100644 --- a/src/uipath/llm_client/clients/openai/client.py +++ b/src/uipath/llm_client/clients/openai/client.py @@ -1,10 +1,9 @@ import logging -from typing import Any +from collections.abc import Mapping, Sequence from uipath.llm_client.clients.openai.utils import OpenAIRequestHandler -from uipath.llm_client.httpx_client import UiPathHttpxAsyncClient, UiPathHttpxClient -from uipath.llm_client.settings import get_default_client_settings -from uipath.llm_client.settings.base import UiPathBaseSettings +from uipath.llm_client.clients.utils import build_httpx_async_client, build_httpx_client +from uipath.llm_client.settings import UiPathBaseSettings, get_default_client_settings from uipath.llm_client.utils.retry import RetryConfig try: @@ -17,29 +16,47 @@ class UiPathOpenAI(OpenAI): + """OpenAI client routed through UiPath LLM Gateway. + + Wraps the standard OpenAI client to route requests through UiPath's + LLM Gateway while preserving the full OpenAI SDK interface. + + Args: + model_name: The OpenAI model name (e.g., "gpt-4o-2024-11-20"). + byo_connection_id: Bring Your Own connection ID for custom deployments. + client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). + retry_config: Custom retry configuration. + logger: Logger instance for request/response logging. 
+ """ + def __init__( self, *, model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - timeout = kwargs.pop("timeout", None) - max_retries = kwargs.pop("max_retries", None) - default_headers = kwargs.pop("default_headers", None) - httpx_client = UiPathHttpxClient( + httpx_client = build_httpx_client( model_name=model_name, byo_connection_id=byo_connection_id, + client_settings=client_settings, timeout=timeout, max_retries=max_retries, - headers=default_headers, + default_headers=default_headers, + captured_headers=captured_headers, retry_config=retry_config, logger=logger, - auth=client_settings.build_auth_pipeline(), event_hooks={ "request": [ OpenAIRequestHandler( @@ -57,29 +74,47 @@ def __init__( class UiPathAsyncOpenAI(AsyncOpenAI): + """Async OpenAI client routed through UiPath LLM Gateway. + + Wraps the standard AsyncOpenAI client to route requests through UiPath's + LLM Gateway while preserving the full OpenAI SDK interface. + + Args: + model_name: The OpenAI model name (e.g., "gpt-4o-2024-11-20"). + byo_connection_id: Bring Your Own connection ID for custom deployments. + client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). + retry_config: Custom retry configuration. + logger: Logger instance for request/response logging. 
+ """ + def __init__( self, *, model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - timeout = kwargs.pop("timeout", None) - max_retries = kwargs.pop("max_retries", None) - default_headers = kwargs.pop("default_headers", None) - httpx_client = UiPathHttpxAsyncClient( + httpx_client = build_httpx_async_client( model_name=model_name, byo_connection_id=byo_connection_id, + client_settings=client_settings, timeout=timeout, max_retries=max_retries, - headers=default_headers, + default_headers=default_headers, + captured_headers=captured_headers, retry_config=retry_config, logger=logger, - auth=client_settings.build_auth_pipeline(), event_hooks={ "request": [ OpenAIRequestHandler( @@ -97,29 +132,47 @@ def __init__( class UiPathAzureOpenAI(AzureOpenAI): + """Azure OpenAI client routed through UiPath LLM Gateway. + + Wraps the AzureOpenAI client to route requests through UiPath's + LLM Gateway while preserving the full Azure OpenAI SDK interface. + + Args: + model_name: The model name (e.g., "gpt-4o-2024-11-20"). + byo_connection_id: Bring Your Own connection ID for custom deployments. + client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). + retry_config: Custom retry configuration. + logger: Logger instance for request/response logging. 
+ """ + def __init__( self, *, model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - timeout = kwargs.pop("timeout", None) - max_retries = kwargs.pop("max_retries", None) - default_headers = kwargs.pop("default_headers", None) - httpx_client = UiPathHttpxClient( + httpx_client = build_httpx_client( model_name=model_name, byo_connection_id=byo_connection_id, + client_settings=client_settings, timeout=timeout, max_retries=max_retries, - headers=default_headers, + default_headers=default_headers, + captured_headers=captured_headers, retry_config=retry_config, logger=logger, - auth=client_settings.build_auth_pipeline(), event_hooks={ "request": [ OpenAIRequestHandler( @@ -138,29 +191,47 @@ def __init__( class UiPathAsyncAzureOpenAI(AsyncAzureOpenAI): + """Async Azure OpenAI client routed through UiPath LLM Gateway. + + Wraps the AsyncAzureOpenAI client to route requests through UiPath's + LLM Gateway while preserving the full Azure OpenAI SDK interface. + + Args: + model_name: The model name (e.g., "gpt-4o-2024-11-20"). + byo_connection_id: Bring Your Own connection ID for custom deployments. + client_settings: UiPath client settings. Defaults to environment-based settings. + timeout: Client-side request timeout in seconds. + max_retries: Maximum retry attempts for failed requests. + default_headers: Additional headers to include in requests. + captured_headers: Response header prefixes to capture (case-insensitive). + retry_config: Custom retry configuration. + logger: Logger instance for request/response logging. 
+ """ + def __init__( self, *, model_name: str, byo_connection_id: str | None = None, client_settings: UiPathBaseSettings | None = None, + timeout: float | None = None, + max_retries: int | None = None, + default_headers: Mapping[str, str] | None = None, + captured_headers: Sequence[str] = ("x-uipath-",), retry_config: RetryConfig | None = None, logger: logging.Logger | None = None, - **kwargs: Any, ): client_settings = client_settings or get_default_client_settings() - timeout = kwargs.pop("timeout", None) - max_retries = kwargs.pop("max_retries", None) - default_headers = kwargs.pop("default_headers", None) - httpx_client = UiPathHttpxAsyncClient( + httpx_client = build_httpx_async_client( model_name=model_name, byo_connection_id=byo_connection_id, + client_settings=client_settings, timeout=timeout, max_retries=max_retries, - headers=default_headers, + default_headers=default_headers, + captured_headers=captured_headers, retry_config=retry_config, logger=logger, - auth=client_settings.build_auth_pipeline(), event_hooks={ "request": [ OpenAIRequestHandler( diff --git a/src/uipath/llm_client/clients/openai/utils.py b/src/uipath/llm_client/clients/openai/utils.py index 40c0473..5674b6f 100644 --- a/src/uipath/llm_client/clients/openai/utils.py +++ b/src/uipath/llm_client/clients/openai/utils.py @@ -1,8 +1,8 @@ from httpx import URL, Request -from uipath.llm_client.httpx_client import build_routing_headers from uipath.llm_client.settings.base import UiPathAPIConfig, UiPathBaseSettings from uipath.llm_client.settings.constants import ApiFlavor, ApiType, RoutingMode, VendorType +from uipath.llm_client.utils.headers import build_routing_headers class OpenAIRequestHandler: @@ -22,55 +22,35 @@ def __init__( freeze_base_url=False, ) + def _apply_routing(self, request: Request, api_config: UiPathAPIConfig) -> None: + """Apply UiPath routing headers and URL rewriting to the request.""" + request.headers.update( + build_routing_headers( + model_name=self.model_name, + 
byo_connection_id=self.byo_connection_id, + api_config=api_config, + ) + ) + request.url = URL( + self.client_settings.build_base_url(model_name=self.model_name, api_config=api_config) + ) + def fix_url_and_headers(self, request: Request): if request.url.path.endswith("/completions"): api_config = self.base_api_config.model_copy( update={"api_flavor": ApiFlavor.CHAT_COMPLETIONS, "api_type": ApiType.COMPLETIONS} ) - request.headers.update( - build_routing_headers( - model_name=self.model_name, - byo_connection_id=self.byo_connection_id, - api_config=api_config, - ) - ) - request.url = URL( - self.client_settings.build_base_url( - model_name=self.model_name, api_config=api_config - ) - ) + self._apply_routing(request, api_config) elif request.url.path.endswith("/responses"): api_config = self.base_api_config.model_copy( update={"api_flavor": ApiFlavor.RESPONSES, "api_type": ApiType.COMPLETIONS} ) - request.headers.update( - build_routing_headers( - model_name=self.model_name, - byo_connection_id=self.byo_connection_id, - api_config=api_config, - ) - ) - request.url = URL( - self.client_settings.build_base_url( - model_name=self.model_name, api_config=api_config - ) - ) + self._apply_routing(request, api_config) elif request.url.path.endswith("/embeddings"): api_config = self.base_api_config.model_copy(update={"api_type": ApiType.EMBEDDINGS}) - request.headers.update( - build_routing_headers( - model_name=self.model_name, - byo_connection_id=self.byo_connection_id, - api_config=api_config, - ) - ) - request.url = URL( - self.client_settings.build_base_url( - model_name=self.model_name, api_config=api_config - ) - ) + self._apply_routing(request, api_config) else: - raise ValueError(f"Unsupported API endpoint: {request.url.path}") + raise ValueError(f"Unrecognized API endpoint '{request.url.path}'") async def fix_url_and_headers_async(self, request: Request): self.fix_url_and_headers(request) diff --git a/src/uipath/llm_client/clients/utils.py 
b/src/uipath/llm_client/clients/utils.py new file mode 100644 index 0000000..639668d --- /dev/null +++ b/src/uipath/llm_client/clients/utils.py @@ -0,0 +1,107 @@ +"""Shared utilities for building UiPath-configured httpx clients.""" + +import logging +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +from uipath.llm_client.httpx_client import UiPathHttpxAsyncClient, UiPathHttpxClient +from uipath.llm_client.settings.base import UiPathAPIConfig, UiPathBaseSettings +from uipath.llm_client.utils.retry import RetryConfig + + +def build_httpx_client( + *, + model_name: str, + byo_connection_id: str | None, + client_settings: UiPathBaseSettings, + timeout: float | None, + max_retries: int | None, + default_headers: Mapping[str, str] | None, + captured_headers: Sequence[str], + retry_config: RetryConfig | None, + logger: logging.Logger | None, + api_config: UiPathAPIConfig | None = None, + event_hooks: dict[str, list[Callable[..., Any]]] | None = None, +) -> UiPathHttpxClient: + """Build a sync UiPath httpx client with auth, routing headers, and retry. + + When *api_config* is provided the base URL and routing/auth headers are + derived from *client_settings*. When it is ``None`` (e.g. for OpenAI + clients that resolve routing per-request via event hooks) those are + omitted and only the auth pipeline, default headers and retry transport + are configured. 
+ """ + headers: dict[str, str] = {**(default_headers or {})} + kwargs: dict[str, Any] = {} + + if api_config is not None: + headers.update( + client_settings.build_auth_headers(model_name=model_name, api_config=api_config) + ) + kwargs["base_url"] = client_settings.build_base_url( + model_name=model_name, api_config=api_config + ) + + if event_hooks is not None: + kwargs["event_hooks"] = event_hooks + + return UiPathHttpxClient( + model_name=model_name, + byo_connection_id=byo_connection_id, + api_config=api_config, + timeout=timeout, + max_retries=max_retries, + retry_config=retry_config, + headers=headers, + captured_headers=captured_headers, + logger=logger, + auth=client_settings.build_auth_pipeline(), + **kwargs, + ) + + +def build_httpx_async_client( + *, + model_name: str, + byo_connection_id: str | None, + client_settings: UiPathBaseSettings, + timeout: float | None, + max_retries: int | None, + default_headers: Mapping[str, str] | None, + captured_headers: Sequence[str], + retry_config: RetryConfig | None, + logger: logging.Logger | None, + api_config: UiPathAPIConfig | None = None, + event_hooks: dict[str, list[Callable[..., Any]]] | None = None, +) -> UiPathHttpxAsyncClient: + """Build an async UiPath httpx client with auth, routing headers, and retry. + + See :func:`build_httpx_client` for parameter details. 
+ """ + headers: dict[str, str] = {**(default_headers or {})} + kwargs: dict[str, Any] = {} + + if api_config is not None: + headers.update( + client_settings.build_auth_headers(model_name=model_name, api_config=api_config) + ) + kwargs["base_url"] = client_settings.build_base_url( + model_name=model_name, api_config=api_config + ) + + if event_hooks is not None: + kwargs["event_hooks"] = event_hooks + + return UiPathHttpxAsyncClient( + model_name=model_name, + byo_connection_id=byo_connection_id, + api_config=api_config, + timeout=timeout, + max_retries=max_retries, + retry_config=retry_config, + headers=headers, + captured_headers=captured_headers, + logger=logger, + auth=client_settings.build_auth_pipeline(), + **kwargs, + ) diff --git a/src/uipath/llm_client/httpx_client.py b/src/uipath/llm_client/httpx_client.py index 1250465..83d2360 100644 --- a/src/uipath/llm_client/httpx_client.py +++ b/src/uipath/llm_client/httpx_client.py @@ -21,7 +21,7 @@ """ import logging -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from typing import Any from httpx import ( @@ -72,7 +72,7 @@ class UiPathHttpxClient(Client): """ _streaming_header: str = "X-UiPath-Streaming-Enabled" - _default_headers: Mapping[str, str] = { + _default_headers: dict[str, str] = { "X-UiPath-LLMGateway-TimeoutSeconds": "295", # server side timeout, default is 10, maximum is 300 # "X-UiPath-LLMGateway-AllowFull4xxResponse": "true", # allow full 4xx responses (default is false) — removed from default to avoid PII leakage in logs } @@ -100,7 +100,8 @@ def __init__( captured_headers: Case-insensitive header name prefixes to capture from responses. Captured headers are stored in a ContextVar and can be retrieved with get_captured_response_headers(). Defaults to ("x-uipath-",). - max_retries: Maximum retry attempts for failed requests. Defaults to 1. + max_retries: Maximum retry attempts for failed requests. Defaults to 0 + (retries disabled). 
Set to a positive integer to enable retries. retry_config: Custom retry configuration (backoff, retryable status codes). logger: Logger instance for request/response logging. **kwargs: Additional arguments passed to httpx.Client (e.g., base_url, @@ -114,7 +115,7 @@ def __init__( # Extract httpx.Client params that we need to modify headers: HeaderTypes | None = kwargs.pop("headers", None) transport: BaseTransport | None = kwargs.pop("transport", None) - event_hooks: Mapping[str, list[Callable[..., Any]]] | None = kwargs.pop("event_hooks", None) + event_hooks: dict[str, list[Callable[..., Any]]] | None = kwargs.pop("event_hooks", None) # Merge headers: default -> api_config -> user provided merged_headers = Headers(self._default_headers) @@ -131,7 +132,7 @@ def __init__( # Setup retry transport if not provided if transport is None: transport = RetryableHTTPTransport( - max_retries=max_retries or 0, + max_retries=max_retries if max_retries is not None else 0, retry_config=retry_config, logger=logger, ) @@ -147,8 +148,8 @@ def __init__( "request": [], "response": [], } - event_hooks["request"].append(logging_config.log_request_duration) - event_hooks["response"].append(logging_config.log_response_duration) + event_hooks.setdefault("request", []).append(logging_config.log_request_duration) + event_hooks.setdefault("response", []).append(logging_config.log_response_duration) event_hooks["response"].append(logging_config.log_error) # setup ssl context @@ -205,7 +206,7 @@ class UiPathHttpxAsyncClient(AsyncClient): """ _streaming_header: str = "X-UiPath-Streaming-Enabled" - _default_headers: Mapping[str, str] = { + _default_headers: dict[str, str] = { "X-UiPath-LLMGateway-TimeoutSeconds": "295", # server side timeout, default is 10, maximum is 300 # "X-UiPath-LLMGateway-AllowFull4xxResponse": "true", # allow full 4xx responses (default is false) — removed from default to avoid PII leakage in logs } @@ -233,7 +234,8 @@ def __init__( captured_headers: Case-insensitive header 
name prefixes to capture from responses. Captured headers are stored in a ContextVar and can be retrieved with get_captured_response_headers(). Defaults to ("x-uipath-",). - max_retries: Maximum retry attempts for failed requests. Defaults to 1. + max_retries: Maximum retry attempts for failed requests. Defaults to 0 + (retries disabled). Set to a positive integer to enable retries. retry_config: Custom retry configuration (backoff, retryable status codes). logger: Logger instance for request/response logging. **kwargs: Additional arguments passed to httpx.AsyncClient (e.g., base_url, @@ -247,7 +249,7 @@ def __init__( # Extract httpx.AsyncClient params that we need to modify headers: HeaderTypes | None = kwargs.pop("headers", None) transport: AsyncBaseTransport | None = kwargs.pop("transport", None) - event_hooks: Mapping[str, list[Callable[..., Any]]] | None = kwargs.pop("event_hooks", None) + event_hooks: dict[str, list[Callable[..., Any]]] | None = kwargs.pop("event_hooks", None) # Merge headers: default -> api_config -> user provided merged_headers = Headers(self._default_headers) @@ -264,7 +266,7 @@ def __init__( # Setup retry transport if not provided if transport is None: transport = RetryableAsyncHTTPTransport( - max_retries=max_retries or 0, + max_retries=max_retries if max_retries is not None else 0, retry_config=retry_config, logger=logger, ) @@ -280,8 +282,8 @@ def __init__( "request": [], "response": [], } - event_hooks["request"].append(logging_config.alog_request_duration) - event_hooks["response"].append(logging_config.alog_response_duration) + event_hooks.setdefault("request", []).append(logging_config.alog_request_duration) + event_hooks.setdefault("response", []).append(logging_config.alog_response_duration) event_hooks["response"].append(logging_config.alog_error) # setup ssl context diff --git a/src/uipath/llm_client/settings/base.py b/src/uipath/llm_client/settings/base.py index f6c2326..c2c974c 100644 --- 
a/src/uipath/llm_client/settings/base.py +++ b/src/uipath/llm_client/settings/base.py @@ -2,7 +2,7 @@ Base Settings Module for UiPath LLM Client This module defines the abstract base classes and data models for UiPath API settings. -Concrete implementations are provided in the `agenthub` and `llmgateway` submodules. +Concrete implementations are provided in the `platform` and `llmgateway` submodules. """ from abc import ABC, abstractmethod diff --git a/src/uipath/llm_client/settings/constants.py b/src/uipath/llm_client/settings/constants.py index 27232ca..f8acf52 100644 --- a/src/uipath/llm_client/settings/constants.py +++ b/src/uipath/llm_client/settings/constants.py @@ -28,7 +28,7 @@ class ApiFlavor(StrEnum): ANTHROPIC_CLAUDE = "anthropic-claude" -_API_FLAVOR_TO_VENDOR_TYPE: dict[ApiFlavor, VendorType] = { +API_FLAVOR_TO_VENDOR_TYPE: dict[ApiFlavor, VendorType] = { ApiFlavor.CHAT_COMPLETIONS: VendorType.OPENAI, ApiFlavor.RESPONSES: VendorType.OPENAI, ApiFlavor.GENERATE_CONTENT: VendorType.VERTEXAI, diff --git a/src/uipath/llm_client/settings/llmgateway/auth.py b/src/uipath/llm_client/settings/llmgateway/auth.py index aa8dc37..75728e4 100644 --- a/src/uipath/llm_client/settings/llmgateway/auth.py +++ b/src/uipath/llm_client/settings/llmgateway/auth.py @@ -5,18 +5,29 @@ from uipath.llm_client.settings.llmgateway.settings import LLMGatewayBaseSettings from uipath.llm_client.settings.llmgateway.utils import LLMGatewayEndpoints from uipath.llm_client.settings.utils import SingletonMeta +from uipath.llm_client.utils.ssl_config import get_httpx_ssl_client_kwargs class LLMGatewayS2SAuth(Auth, metaclass=SingletonMeta): """Bearer authentication handler with automatic token refresh. - Singleton class that reuses the same token across all requests to minimize - token generation overhead. Automatically refreshes the token on 401 responses. 
+ Singleton keyed by (base_url, client_id, client_secret) so that clients sharing the + same credentials reuse one token while different credentials get separate + instances. Automatically refreshes the token on 401 responses. Does not raise errors on token retrieval failures — the request is sent without a valid token and the downstream client handles the error response. """ + @classmethod + def _singleton_cache_key(cls, settings: LLMGatewayBaseSettings) -> tuple: + """Derive a cache key from the credentials so different settings get different instances.""" + return ( + settings.base_url, + settings.client_id.get_secret_value() if settings.client_id else None, + settings.client_secret.get_secret_value() if settings.client_secret else None, + ) + def __init__( self, settings: LLMGatewayBaseSettings, @@ -51,7 +62,7 @@ def get_llmgw_token( grant_type="client_credentials", ) try: - with Client() as http_client: + with Client(**get_httpx_ssl_client_kwargs()) as http_client: response = http_client.post(url_get_token, data=token_credentials) if response.is_error: return None diff --git a/src/uipath/llm_client/settings/llmgateway/settings.py b/src/uipath/llm_client/settings/llmgateway/settings.py index 1076668..8c0b491 100644 --- a/src/uipath/llm_client/settings/llmgateway/settings.py +++ b/src/uipath/llm_client/settings/llmgateway/settings.py @@ -7,8 +7,10 @@ from typing_extensions import override from uipath.llm_client.settings.base import UiPathAPIConfig, UiPathBaseSettings +from uipath.llm_client.settings.constants import ApiType, RoutingMode from uipath.llm_client.settings.llmgateway.utils import LLMGatewayEndpoints from uipath.llm_client.utils.exceptions import UiPathAPIError +from uipath.llm_client.utils.ssl_config import get_httpx_ssl_client_kwargs class LLMGatewayBaseSettings(UiPathBaseSettings): @@ -73,13 +75,19 @@ def build_base_url( model_name: str | None = None, api_config: UiPathAPIConfig | None = None, ) -> str: + if api_config is None: + raise ValueError("api_config 
is required for LLMGatewaySettings.build_base_url") + if api_config.routing_mode is None: + raise ValueError( + "api_config.routing_mode is required for LLMGatewaySettings.build_base_url" + ) base_url = f"{self.base_url}/{self.org_id}/{self.tenant_id}" - if api_config is not None and api_config.routing_mode == "normalized": - url = f"{base_url}/{LLMGatewayEndpoints.NORMALIZED_ENDPOINT.value.format(api_type='chat/completions' if api_config.api_type == 'completions' else 'embeddings')}" - else: - if api_config is None: - raise ValueError("api_config is required for passthrough routing_mode") + if api_config.routing_mode == RoutingMode.NORMALIZED: + url = f"{base_url}/{LLMGatewayEndpoints.NORMALIZED_ENDPOINT.value.format(api_type='chat/completions' if api_config.api_type == ApiType.COMPLETIONS else 'embeddings')}" + elif api_config.routing_mode == RoutingMode.PASSTHROUGH: url = f"{base_url}/{LLMGatewayEndpoints.PASSTHROUGH_ENDPOINT.value.format(vendor=api_config.vendor_type, model=model_name, api_type=api_config.api_type)}" + else: + raise ValueError(f"Unsupported routing_mode: {api_config.routing_mode}") return url @override @@ -108,7 +116,11 @@ def build_auth_headers( @override def get_available_models(self) -> list[dict[str, Any]]: discovery_url = f"{self.base_url}/{self.org_id}/{self.tenant_id}/{LLMGatewayEndpoints.DISCOVERY_ENDPOINT.value}" - with Client(auth=self.build_auth_pipeline(), headers=self.build_auth_headers()) as client: + with Client( + auth=self.build_auth_pipeline(), + headers=self.build_auth_headers(), + **get_httpx_ssl_client_kwargs(), + ) as client: response = client.get(discovery_url) if response.is_error: raise UiPathAPIError.from_response(response) @@ -116,17 +128,23 @@ def get_available_models(self) -> list[dict[str, Any]]: @override def validate_byo_model(self, model_info: dict[str, Any]) -> None: + """Validate that the model is a BYOM model. 
+ + Note: This method may mutate ``self.operation_code`` as a side effect + when no operation code was explicitly configured but the model provides + available codes. + """ byom_details = model_info.get("byomDetails", {}) operation_codes = byom_details.get("availableOperationCodes", []) if self.operation_code and self.operation_code not in operation_codes: raise ValueError( - f"The operation code {self.operation_code} is not allowed for the model {model_info['modelName']}" + f"The operation code {self.operation_code} is not allowed for the model {model_info.get('modelName', 'unknown')}" ) if not self.operation_code and len(operation_codes) > 0: if len(operation_codes) > 1: logging.warning( "Multiple operation codes are allowed for the model %s, but no operation code was provided, picking the first one available: %s", - model_info["modelName"], + model_info.get("modelName", "unknown"), operation_codes[0], ) self.operation_code = operation_codes[0] diff --git a/src/uipath/llm_client/settings/platform/auth.py b/src/uipath/llm_client/settings/platform/auth.py index fe1482f..b6d35fb 100644 --- a/src/uipath/llm_client/settings/platform/auth.py +++ b/src/uipath/llm_client/settings/platform/auth.py @@ -12,10 +12,22 @@ class PlatformAuth(Auth, metaclass=SingletonMeta): """Bearer authentication handler with automatic token refresh. - Singleton class that stores access_token and refresh_token directly, - reusing them across all requests. Automatically refreshes on 401 responses. + Singleton keyed by (base_url, organization_id, tenant_id, access_token) + so that clients sharing the same credentials reuse one token while + different credentials get separate instances. Automatically refreshes + on 401 responses. 
""" + @classmethod + def _singleton_cache_key(cls, settings: PlatformBaseSettings) -> tuple: + """Derive a cache key from the credentials so different settings get different instances.""" + return ( + settings.base_url, + settings.organization_id, + settings.tenant_id, + settings.access_token.get_secret_value() if settings.access_token else None, + ) + def __init__( self, settings: PlatformBaseSettings, diff --git a/src/uipath/llm_client/settings/platform/settings.py b/src/uipath/llm_client/settings/platform/settings.py index bbcc4d8..a20abdc 100644 --- a/src/uipath/llm_client/settings/platform/settings.py +++ b/src/uipath/llm_client/settings/platform/settings.py @@ -5,9 +5,11 @@ from pydantic import Field, SecretStr, model_validator from typing_extensions import override +from uipath.platform import UiPath from uipath.platform.common import EndpointManager from uipath.llm_client.settings.base import UiPathAPIConfig, UiPathBaseSettings +from uipath.llm_client.settings.constants import ApiType, RoutingMode from uipath.llm_client.settings.platform.utils import is_token_expired, parse_access_token @@ -72,7 +74,7 @@ def validate_environment(self) -> Self: ) parsed_token_data = parse_access_token(access_token) - self.client_id = parsed_token_data.get("client_id", None) + self.client_id = parsed_token_data.get("client_id") return self @staticmethod @@ -97,19 +99,33 @@ def build_base_url( api_config: UiPathAPIConfig | None = None, ) -> str: """Build the base URL for API requests.""" - assert model_name is not None - assert api_config is not None - if api_config.routing_mode == "normalized" and api_config.api_type == "completions": + if model_name is None: + raise ValueError("model_name is required for PlatformBaseSettings.build_base_url") + if api_config is None: + raise ValueError("api_config is required for PlatformBaseSettings.build_base_url") + if ( + api_config.routing_mode == RoutingMode.NORMALIZED + and api_config.api_type == ApiType.COMPLETIONS + ): url = 
f"{self.base_url}/{EndpointManager.get_normalized_endpoint()}" - elif api_config.routing_mode == "normalized" and api_config.api_type == "embeddings": + elif ( + api_config.routing_mode == RoutingMode.NORMALIZED + and api_config.api_type == ApiType.EMBEDDINGS + ): raise ValueError( "Normalized embeddings are not supported on UiPath Platform (AgentHub/Orchestrator). " "Use passthrough routing mode for embeddings instead." ) - elif api_config.routing_mode == "passthrough" and api_config.api_type == "completions": + elif ( + api_config.routing_mode == RoutingMode.PASSTHROUGH + and api_config.api_type == ApiType.COMPLETIONS + ): endpoint = EndpointManager.get_vendor_endpoint() url = f"{self.base_url}/{self._format_endpoint(endpoint, model=model_name, vendor=api_config.vendor_type, api_version=api_config.api_version)}" - elif api_config.routing_mode == "passthrough" and api_config.api_type == "embeddings": + elif ( + api_config.routing_mode == RoutingMode.PASSTHROUGH + and api_config.api_type == ApiType.EMBEDDINGS + ): if api_config.vendor_type is not None and api_config.vendor_type != "openai": raise ValueError( f"Platform embeddings endpoint only supports OpenAI-compatible models, " @@ -148,7 +164,6 @@ def build_auth_headers( @override def get_available_models(self) -> list[dict[str, Any]]: - from uipath.platform import UiPath models = UiPath().agenthub.get_available_llm_models( headers=dict(self.build_auth_headers()), diff --git a/src/uipath/llm_client/settings/platform/utils.py b/src/uipath/llm_client/settings/platform/utils.py index 559250b..1daaa0e 100644 --- a/src/uipath/llm_client/settings/platform/utils.py +++ b/src/uipath/llm_client/settings/platform/utils.py @@ -1,16 +1,42 @@ import base64 import json import time +from typing import Any -def parse_access_token(access_token: str): +def parse_access_token(access_token: str) -> dict[str, Any]: + """Parse a JWT access token and return the payload as a dict. 
+ + Args: + access_token: A JWT token string (header.payload.signature). + + Returns: + The decoded payload as a dictionary. + + Raises: + ValueError: If the token is malformed or cannot be decoded. + """ token_parts = access_token.split(".") if len(token_parts) < 2: - raise Exception("Invalid access token") - payload = base64.urlsafe_b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)) - return json.loads(payload) + raise ValueError("Invalid access token: expected JWT with at least 2 dot-separated parts") + try: + payload = base64.urlsafe_b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)) + return json.loads(payload) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Invalid access token: failed to decode payload: {e}") from e def is_token_expired(token: str) -> bool: + """Check whether a JWT access token has expired. + + Args: + token: A JWT token string. + + Returns: + True if the token is expired, False if it is still valid or has no ``exp`` claim. + """ token_data = parse_access_token(token) - return token_data["exp"] < time.time() + exp = token_data.get("exp") + if exp is None: + return False + return exp < time.time() diff --git a/src/uipath/llm_client/settings/utils.py b/src/uipath/llm_client/settings/utils.py index 3892d4c..4596e5d 100644 --- a/src/uipath/llm_client/settings/utils.py +++ b/src/uipath/llm_client/settings/utils.py @@ -3,15 +3,33 @@ class SingletonMeta(type): - """Metaclass for creating singleton classes. Used to keep global configs shared between instances.""" + """Metaclass for creating singleton classes keyed by (class, cache_key). - _instances: dict[type, Any] = {} + Classes using this metaclass can define a ``_singleton_cache_key`` classmethod + that derives a hashable key from the constructor arguments. When the same + key is seen again the cached instance is returned and ``__init__`` is + **not** re-invoked. 
+ + If the class does not define ``_singleton_cache_key``, the class itself is + used as the sole key (original singleton-per-class behaviour). + + Used to share access-tokens / auth state between multiple HTTP clients that + are configured with the same credentials. + """ + + _instances: dict[tuple[type, Any], Any] = {} _lock: threading.Lock = threading.Lock() def __call__(cls, *args: Any, **kwargs: Any) -> Any: - if cls not in SingletonMeta._instances: + key_fn = getattr(cls, "_singleton_cache_key", None) + if key_fn is not None: + cache_key = (cls, key_fn(*args, **kwargs)) + else: + cache_key = (cls, None) + + if cache_key not in SingletonMeta._instances: with SingletonMeta._lock: - if cls not in SingletonMeta._instances: + if cache_key not in SingletonMeta._instances: instance = super().__call__(*args, **kwargs) - SingletonMeta._instances[cls] = instance - return SingletonMeta._instances[cls] + SingletonMeta._instances[cache_key] = instance + return SingletonMeta._instances[cache_key] diff --git a/src/uipath/llm_client/utils/exceptions.py b/src/uipath/llm_client/utils/exceptions.py index 4c7d311..3e13057 100644 --- a/src/uipath/llm_client/utils/exceptions.py +++ b/src/uipath/llm_client/utils/exceptions.py @@ -26,6 +26,7 @@ """ from json import JSONDecodeError +from typing import Literal from httpx import HTTPStatusError, Request, Response @@ -96,43 +97,43 @@ def from_response(cls, response: Response, request: Request | None = None) -> "U class UiPathBadRequestError(UiPathAPIError): """HTTP 400 Bad Request error.""" - status_code: int = 400 + status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathAuthenticationError(UiPathAPIError): """HTTP 401 Unauthorized error.""" - status_code: int = 401 + status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathPermissionDeniedError(UiPathAPIError): """HTTP 403 Forbidden error.""" - status_code: int = 403 + status_code: Literal[403] = 403 # 
pyright: ignore[reportIncompatibleVariableOverride] class UiPathNotFoundError(UiPathAPIError): """HTTP 404 Not Found error.""" - status_code: int = 404 + status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathConflictError(UiPathAPIError): """HTTP 409 Conflict error.""" - status_code: int = 409 + status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathRequestTooLargeError(UiPathAPIError): """HTTP 413 Payload Too Large error.""" - status_code: int = 413 + status_code: Literal[413] = 413 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathUnprocessableEntityError(UiPathAPIError): """HTTP 422 Unprocessable Entity error.""" - status_code: int = 422 + status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathRateLimitError(UiPathAPIError): @@ -142,7 +143,7 @@ class UiPathRateLimitError(UiPathAPIError): retry_after: Seconds to wait before retrying (from Retry-After header), or None. 
""" - status_code: int = 429 + status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride] def __init__( self, @@ -178,9 +179,9 @@ def _parse_retry_after(response: Response) -> float | None: from datetime import datetime, timezone # Check both header variants (case-insensitive in httpx) - retry_after_value = response.headers.get("retry-after") or response.headers.get( - "x-retry-after" - ) + retry_after_value = response.headers.get("retry-after") + if retry_after_value is None: + retry_after_value = response.headers.get("x-retry-after") if retry_after_value is None: return None @@ -207,25 +208,25 @@ def _parse_retry_after(response: Response) -> float | None: class UiPathInternalServerError(UiPathAPIError): """HTTP 500 Internal Server Error.""" - status_code: int = 500 + status_code: Literal[500] = 500 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathServiceUnavailableError(UiPathAPIError): """HTTP 503 Service Unavailable error.""" - status_code: int = 503 + status_code: Literal[503] = 503 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathGatewayTimeoutError(UiPathAPIError): """HTTP 504 Gateway Timeout error.""" - status_code: int = 504 + status_code: Literal[504] = 504 # pyright: ignore[reportIncompatibleVariableOverride] class UiPathTooManyRequestsError(UiPathAPIError): """HTTP 529 Too Many Requests (Anthropic overload) error.""" - status_code: int = 529 + status_code: Literal[529] = 529 # pyright: ignore[reportIncompatibleVariableOverride] _STATUS_CODE_TO_EXCEPTION: dict[int, type[UiPathAPIError]] = { diff --git a/src/uipath/llm_client/utils/headers.py b/src/uipath/llm_client/utils/headers.py index 6a5bea2..8bb7fdd 100644 --- a/src/uipath/llm_client/utils/headers.py +++ b/src/uipath/llm_client/utils/headers.py @@ -4,13 +4,14 @@ from httpx import Headers from uipath.llm_client.settings.base import UiPathAPIConfig +from uipath.llm_client.settings.constants import ApiType, RoutingMode 
-_CAPTURED_RESPONSE_HEADERS: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar( - "_captured_response_headers", default={} +_CAPTURED_RESPONSE_HEADERS: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar( + "_captured_response_headers", default=None ) -_DYNAMIC_REQUEST_HEADERS: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar( - "_dynamic_request_headers", default={} +_DYNAMIC_REQUEST_HEADERS: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar( + "_dynamic_request_headers", default=None ) @@ -20,10 +21,12 @@ def get_captured_response_headers() -> dict[str, str]: Returns an empty dict if no headers have been captured or if called outside a capture scope. """ - return dict(_CAPTURED_RESPONSE_HEADERS.get()) + return dict(_CAPTURED_RESPONSE_HEADERS.get() or {}) -def set_captured_response_headers(headers: dict[str, str]) -> contextvars.Token[dict[str, str]]: +def set_captured_response_headers( + headers: dict[str, str], +) -> contextvars.Token[dict[str, str] | None]: """Set captured response headers for the current context.""" return _CAPTURED_RESPONSE_HEADERS.set(headers) @@ -33,10 +36,12 @@ def get_dynamic_request_headers() -> dict[str, str]: Returns an empty dict if no dynamic headers have been set in this context. 
""" - return dict(_DYNAMIC_REQUEST_HEADERS.get()) + return dict(_DYNAMIC_REQUEST_HEADERS.get() or {}) -def set_dynamic_request_headers(headers: dict[str, str]) -> contextvars.Token[dict[str, str]]: +def set_dynamic_request_headers( + headers: dict[str, str], +) -> contextvars.Token[dict[str, str] | None]: """Set headers to be injected into the next outgoing request.""" return _DYNAMIC_REQUEST_HEADERS.set(headers) @@ -74,9 +79,12 @@ def build_routing_headers( """ headers: dict[str, str] = {} if api_config is not None: - if api_config.routing_mode == "normalized" and model_name is not None: + if api_config.routing_mode == RoutingMode.NORMALIZED and model_name is not None: headers["X-UiPath-LlmGateway-NormalizedApi-ModelName"] = model_name - elif api_config.routing_mode == "passthrough" and api_config.api_type == "completions": + elif ( + api_config.routing_mode == RoutingMode.PASSTHROUGH + and api_config.api_type == ApiType.COMPLETIONS + ): if api_config.api_flavor is not None: headers["X-UiPath-LlmGateway-ApiFlavor"] = api_config.api_flavor if api_config.api_version is not None: diff --git a/src/uipath/llm_client/utils/retry.py b/src/uipath/llm_client/utils/retry.py index 941dfbd..32c89a3 100644 --- a/src/uipath/llm_client/utils/retry.py +++ b/src/uipath/llm_client/utils/retry.py @@ -248,6 +248,9 @@ def handle_request(self, request: Request) -> Response: The httpx Response. Returns error responses after retries are exhausted instead of raising exceptions. """ + if self.retryer is None: + return super().handle_request(request) + parent_handle = super().handle_request def _send() -> Response: @@ -257,10 +260,7 @@ def _send() -> Response: return response try: - if self.retryer is not None: - return self.retryer(_send) - else: - return _send() + return self.retryer(_send) except UiPathAPIError as e: return e.response @@ -312,6 +312,9 @@ async def handle_async_request(self, request: Request) -> Response: The httpx Response. 
Returns error responses after retries are exhausted instead of raising exceptions. """ + if self.retryer is None: + return await super().handle_async_request(request) + parent_handle = super().handle_async_request async def _send() -> Response: @@ -321,10 +324,7 @@ async def _send() -> Response: return response try: - if self.retryer is not None: - return await self.retryer(_send) - else: - return await _send() + return await self.retryer(_send) except UiPathAPIError as e: return e.response diff --git a/src/uipath/llm_client/utils/ssl_config.py b/src/uipath/llm_client/utils/ssl_config.py index 31d3c61..4dd1c35 100644 --- a/src/uipath/llm_client/utils/ssl_config.py +++ b/src/uipath/llm_client/utils/ssl_config.py @@ -3,7 +3,7 @@ from typing import Any -def expand_path(path): +def expand_path(path: str | None) -> str | None: """Expand environment variables and user home directory in path.""" if not path: return path @@ -14,24 +14,40 @@ def expand_path(path): return path -def create_ssl_context(): +def create_ssl_context() -> ssl.SSLContext: + """Create an SSL context using system certificates. + + Tries ``truststore`` first for native system certificate support. + Falls back to ``certifi`` for bundled Mozilla CA certificates. + + Raises: + ImportError: If neither ``truststore`` nor ``certifi`` is installed. + """ # Try truststore first (system certificates) try: import truststore return truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) except ImportError: - # Fallback to manual certificate configuration + pass + + # Fallback to manual certificate configuration + try: import certifi + except ImportError: + raise ImportError( + "SSL certificate support requires either 'truststore' or 'certifi'. 
" + "Install one with: pip install truststore or pip install certifi" + ) - ssl_cert_file = expand_path(os.environ.get("SSL_CERT_FILE")) - requests_ca_bundle = expand_path(os.environ.get("REQUESTS_CA_BUNDLE")) - ssl_cert_dir = expand_path(os.environ.get("SSL_CERT_DIR")) + ssl_cert_file = expand_path(os.environ.get("SSL_CERT_FILE")) + requests_ca_bundle = expand_path(os.environ.get("REQUESTS_CA_BUNDLE")) + ssl_cert_dir = expand_path(os.environ.get("SSL_CERT_DIR")) - return ssl.create_default_context( - cafile=ssl_cert_file or requests_ca_bundle or certifi.where(), - capath=ssl_cert_dir, - ) + return ssl.create_default_context( + cafile=ssl_cert_file or requests_ca_bundle or certifi.where(), + capath=ssl_cert_dir, + ) def get_httpx_ssl_client_kwargs() -> dict[str, Any]: diff --git a/tests/conftest.py b/tests/conftest.py index b334edf..0058c79 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,6 +56,9 @@ def pytest_recording_configure(config, vcr): vcr.register_persister(SQLitePersister) +# Only "llmgw" is parameterized because Platform (agenthub) requires `uipath auth` +# credentials that are not available in CI. Platform-specific logic is tested +# via mocked settings in test_base_client.py. 
@pytest.fixture(scope="session", params=["llmgw"]) def client_settings(request: pytest.FixtureRequest) -> UiPathBaseSettings: match request.param: diff --git a/tests/core/test_anthropic_client.py b/tests/core/test_anthropic_client.py index e69de29..b39b7bc 100644 --- a/tests/core/test_anthropic_client.py +++ b/tests/core/test_anthropic_client.py @@ -0,0 +1,304 @@ +"""Tests for the Anthropic client module.""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from uipath.llm_client.clients.anthropic.client import _build_api_config +from uipath.llm_client.settings.constants import ApiType, RoutingMode, VendorType + +MODULE = "uipath.llm_client.clients.anthropic.client" + + +# ============================================================================ +# _build_api_config tests +# ============================================================================ + + +class TestBuildApiConfig: + def test_default_returns_anthropic_passthrough(self): + config = _build_api_config() + assert config.vendor_type == VendorType.ANTHROPIC + assert config.routing_mode == RoutingMode.PASSTHROUGH + assert config.api_type == ApiType.COMPLETIONS + assert config.freeze_base_url is True + + def test_awsbedrock_vendor(self): + config = _build_api_config(vendor_type=VendorType.AWSBEDROCK) + assert config.vendor_type == VendorType.AWSBEDROCK + + def test_vertexai_vendor(self): + config = _build_api_config(vendor_type=VendorType.VERTEXAI) + assert config.vendor_type == VendorType.VERTEXAI + + def test_azure_vendor(self): + config = _build_api_config(vendor_type=VendorType.AZURE) + assert config.vendor_type == VendorType.AZURE + + +# ============================================================================ +# Client initialization tests +# ============================================================================ + + +def _make_sync_httpx_mock(): + mock = MagicMock(spec=httpx.Client) + mock.timeout = httpx.Timeout(None) + mock.headers = httpx.Headers() + return 
mock + + +def _make_async_httpx_mock(): + mock = MagicMock(spec=httpx.AsyncClient) + mock.timeout = httpx.Timeout(None) + mock.headers = httpx.Headers() + return mock + + +@pytest.fixture +def mock_settings(): + return MagicMock() + + +@pytest.fixture +def _patch_client_deps(mock_settings): + with ( + patch( + f"{MODULE}.build_httpx_client", + side_effect=lambda **kw: _make_sync_httpx_mock(), + ) as sync_mock, + patch( + f"{MODULE}.build_httpx_async_client", + side_effect=lambda **kw: _make_async_httpx_mock(), + ) as async_mock, + patch(f"{MODULE}.get_default_client_settings", return_value=mock_settings), + ): + yield sync_mock, async_mock + + +@pytest.fixture +def _foundry_env(monkeypatch): + monkeypatch.setenv("ANTHROPIC_FOUNDRY_RESOURCE", "test-resource") + + +@pytest.mark.usefixtures("_patch_client_deps") +class TestUiPathAnthropic: + def test_passes_api_key_and_max_retries(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropic + + sync_mock, _ = _patch_client_deps + client = UiPathAnthropic(model_name="claude-3-5-sonnet") + assert client.api_key == "PLACEHOLDER" + assert client.max_retries == 0 + sync_mock.assert_called_once() + + def test_passes_model_name_and_byo_connection_id(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropic + + sync_mock, _ = _patch_client_deps + UiPathAnthropic(model_name="claude-3-5-sonnet", byo_connection_id="conn-123") + kwargs = sync_mock.call_args.kwargs + assert kwargs["model_name"] == "claude-3-5-sonnet" + assert kwargs["byo_connection_id"] == "conn-123" + + +@pytest.mark.usefixtures("_patch_client_deps") +class TestUiPathAsyncAnthropic: + def test_passes_api_key_and_max_retries(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropic + + _, async_mock = _patch_client_deps + client = UiPathAsyncAnthropic(model_name="claude-3-5-sonnet") + assert client.api_key == "PLACEHOLDER" + assert 
client.max_retries == 0 + async_mock.assert_called_once() + + def test_passes_model_name_and_byo_connection_id(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropic + + _, async_mock = _patch_client_deps + UiPathAsyncAnthropic(model_name="claude-3-5-sonnet", byo_connection_id="conn-123") + kwargs = async_mock.call_args.kwargs + assert kwargs["model_name"] == "claude-3-5-sonnet" + assert kwargs["byo_connection_id"] == "conn-123" + + +@pytest.mark.usefixtures("_patch_client_deps") +class TestUiPathAnthropicBedrock: + def test_passes_aws_placeholders_and_max_retries(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropicBedrock + + sync_mock, _ = _patch_client_deps + client = UiPathAnthropicBedrock(model_name="claude-3-5-sonnet") + assert client.aws_access_key == "PLACEHOLDER" + assert client.aws_secret_key == "PLACEHOLDER" + assert client.aws_region == "PLACEHOLDER" + assert client.max_retries == 0 + sync_mock.assert_called_once() + + def test_uses_awsbedrock_vendor(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropicBedrock + + sync_mock, _ = _patch_client_deps + UiPathAnthropicBedrock(model_name="claude-3-5-sonnet") + api_config = sync_mock.call_args.kwargs["api_config"] + assert api_config.vendor_type == VendorType.AWSBEDROCK + + def test_passes_model_name_and_byo_connection_id(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropicBedrock + + sync_mock, _ = _patch_client_deps + UiPathAnthropicBedrock(model_name="claude-3-5-sonnet", byo_connection_id="conn-123") + kwargs = sync_mock.call_args.kwargs + assert kwargs["model_name"] == "claude-3-5-sonnet" + assert kwargs["byo_connection_id"] == "conn-123" + + +@pytest.mark.usefixtures("_patch_client_deps") +class TestUiPathAsyncAnthropicBedrock: + def test_passes_aws_placeholders_and_max_retries(self, _patch_client_deps): + from 
uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicBedrock + + _, async_mock = _patch_client_deps + client = UiPathAsyncAnthropicBedrock(model_name="claude-3-5-sonnet") + assert client.aws_access_key == "PLACEHOLDER" + assert client.aws_secret_key == "PLACEHOLDER" + assert client.aws_region == "PLACEHOLDER" + assert client.max_retries == 0 + async_mock.assert_called_once() + + def test_uses_awsbedrock_vendor(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicBedrock + + _, async_mock = _patch_client_deps + UiPathAsyncAnthropicBedrock(model_name="claude-3-5-sonnet") + api_config = async_mock.call_args.kwargs["api_config"] + assert api_config.vendor_type == VendorType.AWSBEDROCK + + def test_passes_model_name_and_byo_connection_id(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicBedrock + + _, async_mock = _patch_client_deps + UiPathAsyncAnthropicBedrock(model_name="claude-3-5-sonnet", byo_connection_id="conn-123") + kwargs = async_mock.call_args.kwargs + assert kwargs["model_name"] == "claude-3-5-sonnet" + assert kwargs["byo_connection_id"] == "conn-123" + + +@pytest.mark.usefixtures("_patch_client_deps") +class TestUiPathAnthropicVertex: + def test_passes_vertex_placeholders_and_max_retries(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropicVertex + + sync_mock, _ = _patch_client_deps + client = UiPathAnthropicVertex(model_name="claude-3-5-sonnet") + assert client.region == "PLACEHOLDER" + assert client.project_id == "PLACEHOLDER" + assert client.max_retries == 0 + sync_mock.assert_called_once() + + def test_uses_vertexai_vendor(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropicVertex + + sync_mock, _ = _patch_client_deps + UiPathAnthropicVertex(model_name="claude-3-5-sonnet") + api_config = sync_mock.call_args.kwargs["api_config"] + assert 
api_config.vendor_type == VendorType.VERTEXAI + + def test_passes_model_name_and_byo_connection_id(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropicVertex + + sync_mock, _ = _patch_client_deps + UiPathAnthropicVertex(model_name="claude-3-5-sonnet", byo_connection_id="conn-123") + kwargs = sync_mock.call_args.kwargs + assert kwargs["model_name"] == "claude-3-5-sonnet" + assert kwargs["byo_connection_id"] == "conn-123" + + +@pytest.mark.usefixtures("_patch_client_deps") +class TestUiPathAsyncAnthropicVertex: + def test_passes_vertex_placeholders_and_max_retries(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicVertex + + _, async_mock = _patch_client_deps + client = UiPathAsyncAnthropicVertex(model_name="claude-3-5-sonnet") + assert client.region == "PLACEHOLDER" + assert client.project_id == "PLACEHOLDER" + assert client.max_retries == 0 + async_mock.assert_called_once() + + def test_uses_vertexai_vendor(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicVertex + + _, async_mock = _patch_client_deps + UiPathAsyncAnthropicVertex(model_name="claude-3-5-sonnet") + api_config = async_mock.call_args.kwargs["api_config"] + assert api_config.vendor_type == VendorType.VERTEXAI + + def test_passes_model_name_and_byo_connection_id(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicVertex + + _, async_mock = _patch_client_deps + UiPathAsyncAnthropicVertex(model_name="claude-3-5-sonnet", byo_connection_id="conn-123") + kwargs = async_mock.call_args.kwargs + assert kwargs["model_name"] == "claude-3-5-sonnet" + assert kwargs["byo_connection_id"] == "conn-123" + + +@pytest.mark.usefixtures("_patch_client_deps", "_foundry_env") +class TestUiPathAnthropicFoundry: + def test_passes_api_key_and_max_retries(self, _patch_client_deps): + from 
uipath.llm_client.clients.anthropic.client import UiPathAnthropicFoundry + + sync_mock, _ = _patch_client_deps + client = UiPathAnthropicFoundry(model_name="claude-3-5-sonnet") + assert client.api_key == "PLACEHOLDER" + assert client.max_retries == 0 + sync_mock.assert_called_once() + + def test_uses_azure_vendor(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropicFoundry + + sync_mock, _ = _patch_client_deps + UiPathAnthropicFoundry(model_name="claude-3-5-sonnet") + api_config = sync_mock.call_args.kwargs["api_config"] + assert api_config.vendor_type == VendorType.AZURE + + def test_passes_model_name_and_byo_connection_id(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAnthropicFoundry + + sync_mock, _ = _patch_client_deps + UiPathAnthropicFoundry(model_name="claude-3-5-sonnet", byo_connection_id="conn-123") + kwargs = sync_mock.call_args.kwargs + assert kwargs["model_name"] == "claude-3-5-sonnet" + assert kwargs["byo_connection_id"] == "conn-123" + + +@pytest.mark.usefixtures("_patch_client_deps", "_foundry_env") +class TestUiPathAsyncAnthropicFoundry: + def test_passes_api_key_and_max_retries(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicFoundry + + _, async_mock = _patch_client_deps + client = UiPathAsyncAnthropicFoundry(model_name="claude-3-5-sonnet") + assert client.api_key == "PLACEHOLDER" + assert client.max_retries == 0 + async_mock.assert_called_once() + + def test_uses_azure_vendor(self, _patch_client_deps): + from uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicFoundry + + _, async_mock = _patch_client_deps + UiPathAsyncAnthropicFoundry(model_name="claude-3-5-sonnet") + api_config = async_mock.call_args.kwargs["api_config"] + assert api_config.vendor_type == VendorType.AZURE + + def test_passes_model_name_and_byo_connection_id(self, _patch_client_deps): + from 
uipath.llm_client.clients.anthropic.client import UiPathAsyncAnthropicFoundry + + _, async_mock = _patch_client_deps + UiPathAsyncAnthropicFoundry(model_name="claude-3-5-sonnet", byo_connection_id="conn-123") + kwargs = async_mock.call_args.kwargs + assert kwargs["model_name"] == "claude-3-5-sonnet" + assert kwargs["byo_connection_id"] == "conn-123" diff --git a/tests/core/test_base_client.py b/tests/core/test_base_client.py index d1c830f..76bf1d1 100644 --- a/tests/core/test_base_client.py +++ b/tests/core/test_base_client.py @@ -458,8 +458,8 @@ def test_auth_flow_refreshes_on_401(self, llmgw_s2s_env_vars): except StopIteration: pass - def test_auth_singleton_reuses_instance(self, llmgw_env_vars): - """Test that LLMGatewayS2SAuth is a singleton.""" + def test_auth_singleton_reuses_instance_for_same_settings(self, llmgw_env_vars): + """Test that LLMGatewayS2SAuth reuses the same instance for identical settings.""" from uipath.llm_client.settings.llmgateway.auth import LLMGatewayS2SAuth with patch.dict(os.environ, llmgw_env_vars, clear=True): @@ -468,6 +468,20 @@ def test_auth_singleton_reuses_instance(self, llmgw_env_vars): auth2 = LLMGatewayS2SAuth(settings=settings) assert auth1 is auth2 + def test_auth_creates_separate_instances_for_different_settings(self, llmgw_env_vars): + """Test that LLMGatewayS2SAuth creates separate instances for different credentials.""" + from uipath.llm_client.settings.llmgateway.auth import LLMGatewayS2SAuth + + env1 = {**llmgw_env_vars, "LLMGW_CLIENT_ID": "id-a", "LLMGW_CLIENT_SECRET": "secret-a"} + env2 = {**llmgw_env_vars, "LLMGW_CLIENT_ID": "id-b", "LLMGW_CLIENT_SECRET": "secret-b"} + with patch.dict(os.environ, env1, clear=True): + settings1 = LLMGatewaySettings() + with patch.dict(os.environ, env2, clear=True): + settings2 = LLMGatewaySettings() + auth1 = LLMGatewayS2SAuth(settings=settings1) + auth2 = LLMGatewayS2SAuth(settings=settings2) + assert auth1 is not auth2 + # 
============================================================================ # Test PlatformSettings @@ -637,17 +651,17 @@ def test_build_base_url_normalized_embeddings_raises( def test_build_base_url_requires_model_name( self, platform_env_vars, mock_platform_auth, normalized_api_config ): - """Test build_base_url asserts model_name is not None.""" + """Test build_base_url raises ValueError when model_name is None.""" with patch.dict(os.environ, platform_env_vars, clear=True): settings = PlatformSettings() - with pytest.raises(AssertionError): + with pytest.raises(ValueError, match="model_name is required"): settings.build_base_url(model_name=None, api_config=normalized_api_config) def test_build_base_url_requires_api_config(self, platform_env_vars, mock_platform_auth): - """Test build_base_url asserts api_config is not None.""" + """Test build_base_url raises ValueError when api_config is None.""" with patch.dict(os.environ, platform_env_vars, clear=True): settings = PlatformSettings() - with pytest.raises(AssertionError): + with pytest.raises(ValueError, match="api_config is required"): settings.build_base_url(model_name="gpt-4o", api_config=None) def test_build_auth_headers_empty_when_no_optional(self, platform_env_vars, mock_platform_auth): @@ -715,12 +729,11 @@ class TestPlatformAuthRefresh: @pytest.fixture(autouse=True) def clear_auth_singleton(self): """Clear PlatformAuth singleton before each test.""" - from uipath.llm_client.settings.platform.auth import PlatformAuth from uipath.llm_client.settings.utils import SingletonMeta - SingletonMeta._instances.pop(PlatformAuth, None) + SingletonMeta._instances.clear() yield - SingletonMeta._instances.pop(PlatformAuth, None) + SingletonMeta._instances.clear() def test_auth_flow_adds_bearer_token(self, platform_env_vars, mock_platform_auth): """Test auth_flow adds Authorization header.""" @@ -765,8 +778,10 @@ def test_auth_flow_refreshes_on_401(self, platform_env_vars, mock_platform_auth) except StopIteration: pass 
- def test_auth_singleton_reuses_instance(self, platform_env_vars, mock_platform_auth): - """Test that PlatformAuth is a singleton.""" + def test_auth_singleton_reuses_instance_for_same_settings( + self, platform_env_vars, mock_platform_auth + ): + """Test that PlatformAuth reuses the same instance for identical settings.""" from uipath.llm_client.settings.platform.auth import PlatformAuth with patch.dict(os.environ, platform_env_vars, clear=True): @@ -775,6 +790,22 @@ def test_auth_singleton_reuses_instance(self, platform_env_vars, mock_platform_a auth2 = PlatformAuth(settings=settings) assert auth1 is auth2 + def test_auth_creates_separate_instances_for_different_settings( + self, platform_env_vars, mock_platform_auth + ): + """Test that PlatformAuth creates separate instances for different credentials.""" + from uipath.llm_client.settings.platform.auth import PlatformAuth + + env1 = {**platform_env_vars, "UIPATH_ACCESS_TOKEN": "token-x"} + env2 = {**platform_env_vars, "UIPATH_ACCESS_TOKEN": "token-y"} + with patch.dict(os.environ, env1, clear=True): + settings1 = PlatformSettings() + with patch.dict(os.environ, env2, clear=True): + settings2 = PlatformSettings() + auth1 = PlatformAuth(settings=settings1) + auth2 = PlatformAuth(settings=settings2) + assert auth1 is not auth2 + # ============================================================================ # Test Retry Logic @@ -1131,7 +1162,7 @@ class TestSingletonMeta: """Tests for SingletonMeta metaclass.""" def test_singleton_creates_single_instance(self): - """Test singleton creates only one instance.""" + """Test singleton creates only one instance when no cache key is defined.""" class TestSingleton(metaclass=SingletonMeta): def __init__(self, value: int): @@ -1157,6 +1188,41 @@ class SingletonB(metaclass=SingletonMeta): assert a is not b + def test_keyed_singleton_same_key_reuses_instance(self): + """Test that same cache key returns the same instance.""" + + class KeyedSingleton(metaclass=SingletonMeta): + 
def __init__(self, key: str, value: int): + self.key = key + self.value = value + + @classmethod + def _singleton_cache_key(cls, key: str, value: int) -> tuple: + return (key,) + + a = KeyedSingleton("k1", 10) + b = KeyedSingleton("k1", 20) + assert a is b + assert a.value == 10 # First value retained + + def test_keyed_singleton_different_key_creates_new_instance(self): + """Test that different cache keys create separate instances.""" + + class KeyedSingleton2(metaclass=SingletonMeta): + def __init__(self, key: str, value: int): + self.key = key + self.value = value + + @classmethod + def _singleton_cache_key(cls, key: str, value: int) -> tuple: + return (key,) + + a = KeyedSingleton2("k1", 10) + b = KeyedSingleton2("k2", 20) + assert a is not b + assert a.value == 10 + assert b.value == 20 + # ============================================================================ # Test Exception String Representations and Body Parsing @@ -1555,6 +1621,7 @@ def test_expand_path_tilde(self): from uipath.llm_client.utils.ssl_config import expand_path result = expand_path("~/test") + assert result is not None assert "~" not in result assert result.endswith("/test") @@ -2098,3 +2165,176 @@ def test_enum_is_str_subclass(self): assert isinstance(RoutingMode.PASSTHROUGH, str) assert isinstance(VendorType.OPENAI, str) assert isinstance(ApiFlavor.CHAT_COMPLETIONS, str) + + +# ============================================================================ +# Test RateLimitError retry-after Parsing +# ============================================================================ + + +class TestRateLimitRetryAfterParsing: + """Tests for UiPathRateLimitError._parse_retry_after.""" + + def _make_429_response(self, headers=None): + mock_resp = MagicMock(spec=Response) + mock_resp.status_code = 429 + mock_resp.reason_phrase = "Too Many Requests" + mock_resp.request = MagicMock(spec=Request) + mock_resp.headers = Headers(headers or {}) + mock_resp.json.return_value = {} + return mock_resp + + 
def test_parses_integer_seconds_from_retry_after(self): + """retry-after header with integer seconds is parsed correctly.""" + resp = self._make_429_response(headers={"retry-after": "120"}) + result = UiPathRateLimitError._parse_retry_after(resp) + assert result == 120.0 + + def test_parses_float_seconds_from_retry_after(self): + """retry-after header with float seconds is parsed correctly.""" + resp = self._make_429_response(headers={"retry-after": "2.5"}) + result = UiPathRateLimitError._parse_retry_after(resp) + assert result == 2.5 + + def test_parses_x_retry_after_as_fallback(self): + """x-retry-after header is used when retry-after is absent.""" + resp = self._make_429_response(headers={"x-retry-after": "30"}) + result = UiPathRateLimitError._parse_retry_after(resp) + assert result == 30.0 + + def test_parses_http_date_format(self): + """retry-after with HTTP-date format returns positive delay.""" + from datetime import datetime, timedelta, timezone + + future = datetime.now(timezone.utc) + timedelta(seconds=60) + date_str = future.strftime("%a, %d %b %Y %H:%M:%S GMT") + resp = self._make_429_response(headers={"retry-after": date_str}) + result = UiPathRateLimitError._parse_retry_after(resp) + assert result is not None + assert result > 0 + + def test_returns_none_when_no_header_present(self): + """Returns None when neither retry-after nor x-retry-after is set.""" + resp = self._make_429_response(headers={}) + result = UiPathRateLimitError._parse_retry_after(resp) + assert result is None + + def test_returns_none_for_unparseable_value(self): + """Returns None for values that are neither numbers nor valid dates.""" + resp = self._make_429_response(headers={"retry-after": "not-valid"}) + result = UiPathRateLimitError._parse_retry_after(resp) + assert result is None + + def test_retry_after_prefers_standard_header(self): + """retry-after takes precedence over x-retry-after.""" + resp = self._make_429_response(headers={"retry-after": "10", "x-retry-after": "99"}) 
+ result = UiPathRateLimitError._parse_retry_after(resp) + assert result == 10.0 + + def test_retry_after_property_on_exception(self): + """retry_after property is set from the response header.""" + resp = self._make_429_response(headers={"retry-after": "42"}) + exc = UiPathRateLimitError( + "rate limited", + request=resp.request, + response=resp, + ) + assert exc.retry_after == 42.0 + + +# ============================================================================ +# Test patch_raise_for_status +# ============================================================================ + + +class TestPatchRaiseForStatus: + """Tests for patch_raise_for_status utility.""" + + def test_patched_response_raises_uipath_error_on_error_status(self): + """Patched response raises UiPathAPIError subclass on HTTP error.""" + from httpx import HTTPStatusError + + mock_resp = MagicMock(spec=Response) + mock_resp.status_code = 404 + mock_resp.reason_phrase = "Not Found" + mock_resp.json.return_value = {"error": "not found"} + mock_resp.request = MagicMock(spec=Request) + mock_resp.headers = {} + original = MagicMock( + side_effect=HTTPStatusError("err", request=mock_resp.request, response=mock_resp) + ) + mock_resp.raise_for_status = original + + patched = patch_raise_for_status(mock_resp) + with pytest.raises(UiPathAPIError) as exc_info: + patched.raise_for_status() + assert isinstance(exc_info.value, UiPathNotFoundError) + + def test_patched_response_returns_response_on_success(self): + """Patched response returns the response object on 2xx status.""" + mock_resp = MagicMock(spec=Response) + mock_resp.status_code = 200 + original = MagicMock(return_value=mock_resp) + mock_resp.raise_for_status = original + + patched = patch_raise_for_status(mock_resp) + result = patched.raise_for_status() + assert result is mock_resp + + def test_patched_replaces_original_method(self): + """The raise_for_status method is replaced, not wrapped additively.""" + mock_resp = MagicMock(spec=Response) + 
mock_resp.status_code = 200 + original = MagicMock(return_value=mock_resp) + mock_resp.raise_for_status = original + + patched = patch_raise_for_status(mock_resp) + assert patched.raise_for_status is not original + + +# ============================================================================ +# Test LLMGateway Singleton Cache Key +# ============================================================================ + + +class TestLLMGatewaySingletonCacheKey: + """Tests for LLMGatewayS2SAuth._singleton_cache_key including base_url.""" + + def test_different_base_urls_produce_different_cache_keys(self, llmgw_env_vars): + """Different base_urls with same credentials produce different cache keys.""" + from uipath.llm_client.settings.llmgateway.auth import LLMGatewayS2SAuth + + env1 = {**llmgw_env_vars, "LLMGW_URL": "https://alpha.uipath.com"} + env2 = {**llmgw_env_vars, "LLMGW_URL": "https://beta.uipath.com"} + + with patch.dict(os.environ, env1, clear=True): + settings1 = LLMGatewaySettings() + with patch.dict(os.environ, env2, clear=True): + settings2 = LLMGatewaySettings() + + key1 = LLMGatewayS2SAuth._singleton_cache_key(settings1) + key2 = LLMGatewayS2SAuth._singleton_cache_key(settings2) + assert key1 != key2 + + def test_same_base_url_and_credentials_produce_same_cache_key(self, llmgw_env_vars): + """Same base_url and credentials produce identical cache keys.""" + from uipath.llm_client.settings.llmgateway.auth import LLMGatewayS2SAuth + + with patch.dict(os.environ, llmgw_env_vars, clear=True): + settings1 = LLMGatewaySettings() + with patch.dict(os.environ, llmgw_env_vars, clear=True): + settings2 = LLMGatewaySettings() + + key1 = LLMGatewayS2SAuth._singleton_cache_key(settings1) + key2 = LLMGatewayS2SAuth._singleton_cache_key(settings2) + assert key1 == key2 + + def test_cache_key_includes_base_url(self, llmgw_env_vars): + """The cache key tuple contains the base_url as its first element.""" + from uipath.llm_client.settings.llmgateway.auth import 
LLMGatewayS2SAuth + + with patch.dict(os.environ, llmgw_env_vars, clear=True): + settings = LLMGatewaySettings() + + key = LLMGatewayS2SAuth._singleton_cache_key(settings) + assert key[0] == settings.base_url diff --git a/tests/core/test_google_client.py b/tests/core/test_google_client.py index e69de29..9f8a118 100644 --- a/tests/core/test_google_client.py +++ b/tests/core/test_google_client.py @@ -0,0 +1,169 @@ +"""Tests for UiPathGoogle client initialization and API config.""" + +from unittest.mock import MagicMock, PropertyMock, patch + +import httpx + +from uipath.llm_client.settings import UiPathAPIConfig +from uipath.llm_client.settings.constants import ApiFlavor, ApiType, RoutingMode, VendorType + +# ============================================================================ +# Test Google API Config +# ============================================================================ + + +class TestBuildApiConfig: + """Tests for the API config built inside UiPathGoogle.__init__.""" + + def test_default_api_config_fields(self): + """The api_config created in __init__ has the expected constant values.""" + api_config = UiPathAPIConfig( + api_type=ApiType.COMPLETIONS, + routing_mode=RoutingMode.PASSTHROUGH, + vendor_type=VendorType.VERTEXAI, + api_flavor=ApiFlavor.GENERATE_CONTENT, + api_version="v1beta1", + freeze_base_url=True, + ) + assert api_config.api_type == "completions" + assert api_config.routing_mode == "passthrough" + assert api_config.vendor_type == "vertexai" + assert api_config.api_flavor == "generate-content" + assert api_config.api_version == "v1beta1" + assert api_config.freeze_base_url is True + + +# ============================================================================ +# Test UiPathGoogle Initialization +# ============================================================================ + + +def _make_mock_sync_client(): + """Create a mock sync httpx client that passes pydantic validation.""" + client = MagicMock(spec=httpx.Client) + 
type(client).base_url = PropertyMock(return_value=httpx.URL("https://example.com/base")) + client.headers = httpx.Headers({"Authorization": "Bearer tok"}) + return client + + +def _make_mock_async_client(): + """Create a mock async httpx client that passes pydantic validation.""" + return MagicMock(spec=httpx.AsyncClient) + + +class TestUiPathGoogleInit: + """Tests for UiPathGoogle client construction with mocked dependencies.""" + + @patch("uipath.llm_client.clients.google.client.build_httpx_async_client") + @patch("uipath.llm_client.clients.google.client.build_httpx_client") + @patch("uipath.llm_client.clients.google.client.get_default_client_settings") + def test_placeholder_api_key( + self, + mock_get_settings, + mock_build_sync, + mock_build_async, + ): + """UiPathGoogle passes api_key='PLACEHOLDER' to the parent Client.""" + from google.genai.client import Client + + mock_settings = MagicMock() + mock_get_settings.return_value = mock_settings + mock_build_sync.return_value = _make_mock_sync_client() + mock_build_async.return_value = _make_mock_async_client() + + with patch.object(Client, "__init__", return_value=None) as mock_init: + from uipath.llm_client.clients.google.client import UiPathGoogle + + UiPathGoogle(model_name="gemini-2.5-flash") + + mock_init.assert_called_once() + call_kwargs = mock_init.call_args[1] + assert call_kwargs["api_key"] == "PLACEHOLDER" + + @patch("uipath.llm_client.clients.google.client.build_httpx_async_client") + @patch("uipath.llm_client.clients.google.client.build_httpx_client") + @patch("uipath.llm_client.clients.google.client.get_default_client_settings") + def test_httpx_clients_passed_via_http_options( + self, + mock_get_settings, + mock_build_sync, + mock_build_async, + ): + """UiPathGoogle passes both httpx clients into HttpOptions.""" + from google.genai.client import Client + + mock_settings = MagicMock() + mock_get_settings.return_value = mock_settings + + sync_client = _make_mock_sync_client() + async_client = 
_make_mock_async_client() + mock_build_sync.return_value = sync_client + mock_build_async.return_value = async_client + + with patch.object(Client, "__init__", return_value=None) as mock_init: + from uipath.llm_client.clients.google.client import UiPathGoogle + + UiPathGoogle(model_name="gemini-2.5-flash") + + call_kwargs = mock_init.call_args[1] + http_options = call_kwargs["http_options"] + assert http_options.httpx_client is sync_client + assert http_options.httpx_async_client is async_client + assert str(http_options.base_url) == "https://example.com/base" + assert http_options.retry_options is None + + @patch("uipath.llm_client.clients.google.client.build_httpx_async_client") + @patch("uipath.llm_client.clients.google.client.build_httpx_client") + @patch("uipath.llm_client.clients.google.client.get_default_client_settings") + def test_uses_provided_client_settings( + self, + mock_get_settings, + mock_build_sync, + mock_build_async, + ): + """When client_settings is provided, get_default_client_settings is not called.""" + from google.genai.client import Client + + custom_settings = MagicMock() + mock_build_sync.return_value = _make_mock_sync_client() + mock_build_async.return_value = _make_mock_async_client() + + with patch.object(Client, "__init__", return_value=None): + from uipath.llm_client.clients.google.client import UiPathGoogle + + UiPathGoogle(model_name="gemini-2.5-flash", client_settings=custom_settings) + + mock_get_settings.assert_not_called() + assert mock_build_sync.call_args[1]["client_settings"] is custom_settings + assert mock_build_async.call_args[1]["client_settings"] is custom_settings + + @patch("uipath.llm_client.clients.google.client.build_httpx_async_client") + @patch("uipath.llm_client.clients.google.client.build_httpx_client") + @patch("uipath.llm_client.clients.google.client.get_default_client_settings") + def test_api_config_forwarded_to_builders( + self, + mock_get_settings, + mock_build_sync, + mock_build_async, + ): + """The 
internally-built api_config is forwarded to both httpx client builders.""" + from google.genai.client import Client + + mock_get_settings.return_value = MagicMock() + mock_build_sync.return_value = _make_mock_sync_client() + mock_build_async.return_value = _make_mock_async_client() + + with patch.object(Client, "__init__", return_value=None): + from uipath.llm_client.clients.google.client import UiPathGoogle + + UiPathGoogle(model_name="gemini-2.5-flash") + + sync_config = mock_build_sync.call_args[1]["api_config"] + async_config = mock_build_async.call_args[1]["api_config"] + + for cfg in (sync_config, async_config): + assert cfg.api_type == "completions" + assert cfg.routing_mode == "passthrough" + assert cfg.vendor_type == "vertexai" + assert cfg.api_flavor == "generate-content" + assert cfg.freeze_base_url is True diff --git a/tests/core/test_openai_client.py b/tests/core/test_openai_client.py index e69de29..04d4e3c 100644 --- a/tests/core/test_openai_client.py +++ b/tests/core/test_openai_client.py @@ -0,0 +1,461 @@ +"""Tests for the OpenAI client module. + +This module tests: +1. OpenAIRequestHandler.fix_url_and_headers (sync and async) routing logic +2. 
Client initialization for UiPathOpenAI, UiPathAsyncOpenAI, UiPathAzureOpenAI, + UiPathAsyncAzureOpenAI +""" + +from unittest.mock import MagicMock, patch + +import pytest +from httpx import Request + +from uipath.llm_client.clients.openai.utils import OpenAIRequestHandler +from uipath.llm_client.settings.base import UiPathAPIConfig +from uipath.llm_client.settings.constants import ( + ApiFlavor, + ApiType, + RoutingMode, + VendorType, +) + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def mock_settings(): + settings = MagicMock() + settings.build_base_url.return_value = "https://gateway.uipath.com/llm/v1" + settings.build_auth_headers.return_value = {"Authorization": "Bearer test-token"} + settings.build_auth_pipeline.return_value = None + return settings + + +@pytest.fixture +def handler(mock_settings): + return OpenAIRequestHandler( + model_name="gpt-4o", + client_settings=mock_settings, + byo_connection_id=None, + ) + + +@pytest.fixture +def handler_with_byo(mock_settings): + return OpenAIRequestHandler( + model_name="gpt-4o", + client_settings=mock_settings, + byo_connection_id="my-connection-id", + ) + + +def _make_request(path: str) -> Request: + return Request("POST", f"https://api.openai.com{path}") + + +# ============================================================================ +# OpenAIRequestHandler.__init__ +# ============================================================================ + + +class TestOpenAIRequestHandlerInit: + def test_stores_model_name(self, handler): + assert handler.model_name == "gpt-4o" + + def test_stores_byo_connection_id(self, handler_with_byo): + assert handler_with_byo.byo_connection_id == "my-connection-id" + + def test_base_api_config_defaults(self, handler): + cfg = handler.base_api_config + assert cfg.routing_mode == RoutingMode.PASSTHROUGH + assert cfg.vendor_type == 
VendorType.OPENAI + assert cfg.api_version == "2025-03-01-preview" + assert cfg.freeze_base_url is False + assert cfg.api_type is None + assert cfg.api_flavor is None + + +# ============================================================================ +# fix_url_and_headers — completions endpoint +# ============================================================================ + + +class TestFixUrlCompletions: + def test_sets_chat_completions_flavor(self, handler): + request = _make_request("/v1/chat/completions") + handler.fix_url_and_headers(request) + + call_args = handler.client_settings.build_base_url.call_args + api_config: UiPathAPIConfig = call_args.kwargs["api_config"] + assert api_config.api_flavor == ApiFlavor.CHAT_COMPLETIONS + assert api_config.api_type == ApiType.COMPLETIONS + + def test_preserves_base_config_fields(self, handler): + request = _make_request("/v1/chat/completions") + handler.fix_url_and_headers(request) + + call_args = handler.client_settings.build_base_url.call_args + api_config: UiPathAPIConfig = call_args.kwargs["api_config"] + assert api_config.routing_mode == RoutingMode.PASSTHROUGH + assert api_config.vendor_type == VendorType.OPENAI + assert api_config.api_version == "2025-03-01-preview" + + def test_rewrites_url(self, handler): + request = _make_request("/v1/chat/completions") + handler.fix_url_and_headers(request) + + assert str(request.url) == "https://gateway.uipath.com/llm/v1" + + def test_injects_routing_headers(self, handler): + request = _make_request("/v1/chat/completions") + handler.fix_url_and_headers(request) + + assert "X-UiPath-LlmGateway-ApiFlavor" in request.headers + assert request.headers["X-UiPath-LlmGateway-ApiFlavor"] == ApiFlavor.CHAT_COMPLETIONS + + def test_injects_api_version_header(self, handler): + request = _make_request("/v1/chat/completions") + handler.fix_url_and_headers(request) + + assert request.headers.get("X-UiPath-LlmGateway-ApiVersion") == "2025-03-01-preview" + + +# 
============================================================================ +# fix_url_and_headers — responses endpoint +# ============================================================================ + + +class TestFixUrlResponses: + def test_sets_responses_flavor(self, handler): + request = _make_request("/v1/responses") + handler.fix_url_and_headers(request) + + call_args = handler.client_settings.build_base_url.call_args + api_config: UiPathAPIConfig = call_args.kwargs["api_config"] + assert api_config.api_flavor == ApiFlavor.RESPONSES + assert api_config.api_type == ApiType.COMPLETIONS + + def test_rewrites_url(self, handler): + request = _make_request("/v1/responses") + handler.fix_url_and_headers(request) + + assert str(request.url) == "https://gateway.uipath.com/llm/v1" + + def test_injects_responses_flavor_header(self, handler): + request = _make_request("/v1/responses") + handler.fix_url_and_headers(request) + + assert request.headers["X-UiPath-LlmGateway-ApiFlavor"] == ApiFlavor.RESPONSES + + +# ============================================================================ +# fix_url_and_headers — embeddings endpoint +# ============================================================================ + + +class TestFixUrlEmbeddings: + def test_sets_embeddings_api_type(self, handler): + request = _make_request("/v1/embeddings") + handler.fix_url_and_headers(request) + + call_args = handler.client_settings.build_base_url.call_args + api_config: UiPathAPIConfig = call_args.kwargs["api_config"] + assert api_config.api_type == ApiType.EMBEDDINGS + + def test_no_api_flavor_for_embeddings(self, handler): + request = _make_request("/v1/embeddings") + handler.fix_url_and_headers(request) + + call_args = handler.client_settings.build_base_url.call_args + api_config: UiPathAPIConfig = call_args.kwargs["api_config"] + assert api_config.api_flavor is None + + def test_rewrites_url(self, handler): + request = _make_request("/v1/embeddings") + 
handler.fix_url_and_headers(request) + + assert str(request.url) == "https://gateway.uipath.com/llm/v1" + + def test_no_flavor_header_for_embeddings(self, handler): + request = _make_request("/v1/embeddings") + handler.fix_url_and_headers(request) + + assert "X-UiPath-LlmGateway-ApiFlavor" not in request.headers + + +# ============================================================================ +# fix_url_and_headers — unrecognized endpoint +# ============================================================================ + + +class TestFixUrlUnrecognized: + def test_raises_value_error(self, handler): + request = _make_request("/v1/models") + with pytest.raises(ValueError, match="Unrecognized API endpoint"): + handler.fix_url_and_headers(request) + + def test_does_not_call_build_base_url(self, handler): + request = _make_request("/v1/models") + with pytest.raises(ValueError): + handler.fix_url_and_headers(request) + + handler.client_settings.build_base_url.assert_not_called() + + +# ============================================================================ +# fix_url_and_headers — BYO connection ID +# ============================================================================ + + +class TestFixUrlByoConnectionId: + def test_byo_header_injected_on_completions(self, handler_with_byo): + request = _make_request("/v1/chat/completions") + handler_with_byo.fix_url_and_headers(request) + + assert request.headers["X-UiPath-LlmGateway-ByoIsConnectionId"] == "my-connection-id" + + def test_byo_header_injected_on_embeddings(self, handler_with_byo): + request = _make_request("/v1/embeddings") + handler_with_byo.fix_url_and_headers(request) + + assert request.headers["X-UiPath-LlmGateway-ByoIsConnectionId"] == "my-connection-id" + + def test_no_byo_header_when_none(self, handler): + request = _make_request("/v1/chat/completions") + handler.fix_url_and_headers(request) + + assert "X-UiPath-LlmGateway-ByoIsConnectionId" not in request.headers + + +# 
============================================================================ +# fix_url_and_headers — does not mutate base_api_config +# ============================================================================ + + +class TestFixUrlDoesNotMutateBase: + def test_base_config_unchanged_after_completions(self, handler): + request = _make_request("/v1/chat/completions") + handler.fix_url_and_headers(request) + + assert handler.base_api_config.api_type is None + assert handler.base_api_config.api_flavor is None + + def test_base_config_unchanged_after_embeddings(self, handler): + request = _make_request("/v1/embeddings") + handler.fix_url_and_headers(request) + + assert handler.base_api_config.api_type is None + assert handler.base_api_config.api_flavor is None + + +# ============================================================================ +# fix_url_and_headers_async +# ============================================================================ + + +class TestFixUrlAndHeadersAsync: + @pytest.mark.asyncio + async def test_async_delegates_to_sync(self, handler): + request = _make_request("/v1/chat/completions") + await handler.fix_url_and_headers_async(request) + + assert str(request.url) == "https://gateway.uipath.com/llm/v1" + call_args = handler.client_settings.build_base_url.call_args + api_config: UiPathAPIConfig = call_args.kwargs["api_config"] + assert api_config.api_flavor == ApiFlavor.CHAT_COMPLETIONS + + @pytest.mark.asyncio + async def test_async_responses_endpoint(self, handler): + request = _make_request("/v1/responses") + await handler.fix_url_and_headers_async(request) + + call_args = handler.client_settings.build_base_url.call_args + api_config: UiPathAPIConfig = call_args.kwargs["api_config"] + assert api_config.api_flavor == ApiFlavor.RESPONSES + + +# ============================================================================ +# Client initialization +# ============================================================================ + 
+_CLIENT_MODULE = "uipath.llm_client.clients.openai.client" + + +def _mock_httpx_sync_client(): + import httpx + + client = MagicMock(spec=httpx.Client) + client.base_url = "https://gateway.uipath.com/llm/v1" + client.headers = httpx.Headers() + client._transport = MagicMock() + client._base_url = httpx.URL("https://gateway.uipath.com/llm/v1") + return client + + +def _mock_httpx_async_client(): + import httpx + + client = MagicMock(spec=httpx.AsyncClient) + client.base_url = "https://gateway.uipath.com/llm/v1" + client.headers = httpx.Headers() + client._transport = MagicMock() + client._base_url = httpx.URL("https://gateway.uipath.com/llm/v1") + return client + + +class TestUiPathOpenAIInit: + @patch(f"{_CLIENT_MODULE}.build_httpx_client", return_value=_mock_httpx_sync_client()) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def test_uses_default_settings_when_none(self, mock_get_settings, mock_build): + mock_settings = MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathOpenAI + + UiPathOpenAI(model_name="gpt-4o") + + mock_get_settings.assert_called_once() + + @patch(f"{_CLIENT_MODULE}.build_httpx_client", return_value=_mock_httpx_sync_client()) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def test_passes_event_hooks_with_fix_url(self, mock_get_settings, mock_build): + mock_settings = MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathOpenAI + + UiPathOpenAI(model_name="gpt-4o") + + call_kwargs = mock_build.call_args.kwargs + assert "event_hooks" in call_kwargs + hooks = call_kwargs["event_hooks"] + assert "request" in hooks + assert len(hooks["request"]) == 1 + # The hook should be a bound method of OpenAIRequestHandler + hook_fn = hooks["request"][0] + assert 
hook_fn.__func__.__name__ == "fix_url_and_headers" + + @patch(f"{_CLIENT_MODULE}.build_httpx_client", return_value=_mock_httpx_sync_client()) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def test_sets_max_retries_zero(self, mock_get_settings, mock_build): + mock_settings = MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathOpenAI + + client = UiPathOpenAI(model_name="gpt-4o") + + assert client.max_retries == 0 + + +class TestUiPathAsyncOpenAIInit: + @patch( + f"{_CLIENT_MODULE}.build_httpx_async_client", + return_value=_mock_httpx_async_client(), + ) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def test_passes_async_event_hooks(self, mock_get_settings, mock_build): + mock_settings = MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathAsyncOpenAI + + UiPathAsyncOpenAI(model_name="gpt-4o") + + call_kwargs = mock_build.call_args.kwargs + hooks = call_kwargs["event_hooks"] + hook_fn = hooks["request"][0] + assert hook_fn.__func__.__name__ == "fix_url_and_headers_async" + + @patch( + f"{_CLIENT_MODULE}.build_httpx_async_client", + return_value=_mock_httpx_async_client(), + ) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def test_sets_max_retries_zero(self, mock_get_settings, mock_build): + mock_settings = MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathAsyncOpenAI + + client = UiPathAsyncOpenAI(model_name="gpt-4o") + + assert client.max_retries == 0 + + +class TestUiPathAzureOpenAIInit: + @patch(f"{_CLIENT_MODULE}.build_httpx_client", return_value=_mock_httpx_sync_client()) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def 
test_passes_placeholder_values(self, mock_get_settings, mock_build): + mock_settings = MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathAzureOpenAI + + client = UiPathAzureOpenAI(model_name="gpt-4o") + + # AzureOpenAI stores these internally; we verify the client was created + # successfully with PLACEHOLDER values (no real Azure config needed) + assert client.max_retries == 0 + + @patch(f"{_CLIENT_MODULE}.build_httpx_client", return_value=_mock_httpx_sync_client()) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def test_uses_sync_event_hook(self, mock_get_settings, mock_build): + mock_settings = MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathAzureOpenAI + + UiPathAzureOpenAI(model_name="gpt-4o") + + call_kwargs = mock_build.call_args.kwargs + hooks = call_kwargs["event_hooks"] + hook_fn = hooks["request"][0] + assert hook_fn.__func__.__name__ == "fix_url_and_headers" + + +class TestUiPathAsyncAzureOpenAIInit: + @patch( + f"{_CLIENT_MODULE}.build_httpx_async_client", + return_value=_mock_httpx_async_client(), + ) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def test_sets_max_retries_zero(self, mock_get_settings, mock_build): + mock_settings = MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathAsyncAzureOpenAI + + client = UiPathAsyncAzureOpenAI(model_name="gpt-4o") + + assert client.max_retries == 0 + + @patch( + f"{_CLIENT_MODULE}.build_httpx_async_client", + return_value=_mock_httpx_async_client(), + ) + @patch(f"{_CLIENT_MODULE}.get_default_client_settings") + def test_uses_async_event_hook(self, mock_get_settings, mock_build): + mock_settings = 
MagicMock() + mock_settings.build_auth_pipeline.return_value = None + mock_get_settings.return_value = mock_settings + + from uipath.llm_client.clients.openai.client import UiPathAsyncAzureOpenAI + + UiPathAsyncAzureOpenAI(model_name="gpt-4o") + + call_kwargs = mock_build.call_args.kwargs + hooks = call_kwargs["event_hooks"] + hook_fn = hooks["request"][0] + assert hook_fn.__func__.__name__ == "fix_url_and_headers_async"