diff --git a/packages/uipath_langchain_client/CHANGELOG.md b/packages/uipath_langchain_client/CHANGELOG.md index d95aee8..fd373e7 100644 --- a/packages/uipath_langchain_client/CHANGELOG.md +++ b/packages/uipath_langchain_client/CHANGELOG.md @@ -2,6 +2,11 @@ All notable changes to `uipath_langchain_client` will be documented in this file. +## [1.1.6] - 2026-02-12 + +### Fixes +- Added proper type hints for factory method + ## [1.1.5] - 2026-02-12 ### Fixes diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py b/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py index b1a30a3..3dcd6fe 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py @@ -1,3 +1,3 @@ __title__ = "UiPath LangChain Client" __description__ = "A Python client for interacting with UiPath's LLM services via LangChain." -__version__ = "1.1.5" +__version__ = "1.1.6" diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py b/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py index f874e08..26de7dd 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/base_client.py @@ -24,11 +24,14 @@ """ import logging +from abc import ABC from collections.abc import AsyncIterator, Iterator, Mapping from functools import cached_property from typing import Any, Literal from httpx import URL, Response +from langchain_core.embeddings import Embeddings +from langchain_core.language_models.chat_models import BaseChatModel from pydantic import AliasChoices, BaseModel, ConfigDict, Field from uipath_langchain_client.settings import ( @@ -40,7 +43,7 @@ from uipath_llm_client.utils.retry import RetryConfig -class UiPathBaseLLMClient(BaseModel): +class UiPathBaseLLMClient(BaseModel, ABC): """Base HTTP client for interacting with UiPath's LLM services. Provides the underlying HTTP transport layer with support for: @@ -50,7 +53,6 @@ class UiPathBaseLLMClient(BaseModel): - Request/response logging This class is typically used as a mixin with framework-specific chat models - (e.g., LangChain, LlamaIndex) to provide UiPath connectivity. Attributes: model_name: Name of the LLM model to use (aliased as "model") @@ -278,3 +280,11 @@ async def uipath_astream( case "raw": async for chunk in response.aiter_raw(): yield chunk + + +class UiPathBaseChatModel(UiPathBaseLLMClient, BaseChatModel): + pass + + +class UiPathBaseEmbeddings(UiPathBaseLLMClient, Embeddings): + pass diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py b/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py index 7f2d84f..eff63f7 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py @@ -20,11 +20,12 @@ >>> embeddings = get_embedding_model(model_name="text-embedding-3-large", client_settings=settings) """ -from typing import Any, Literal - -from langchain_core.embeddings import Embeddings -from langchain_core.language_models.chat_models import BaseChatModel +from typing import Any, Literal, cast +from uipath_langchain_client.base_client import ( + UiPathBaseChatModel, + UiPathBaseEmbeddings, +) from uipath_langchain_client.settings import UiPathBaseSettings, get_default_client_settings @@ -70,7 +71,7 @@ def get_chat_model( client_settings: UiPathBaseSettings | None = None, client_type: Literal["passthrough", "normalized"] = "passthrough", **model_kwargs: Any, -) -> BaseChatModel: +) -> UiPathBaseChatModel: """Factory function to create the appropriate LangChain chat model for a given model name. Automatically detects the model vendor and returns the correct LangChain model class. @@ -99,11 +100,14 @@ def get_chat_model( UiPathChat, ) - return UiPathChat( - model=model_name, - settings=client_settings, - byo_connection_id=byo_connection_id, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathChat( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ), ) vendor_type = model_info["vendor"].lower() @@ -114,21 +118,27 @@ def get_chat_model( UiPathAzureChatOpenAI, ) - return UiPathAzureChatOpenAI( - model=model_name, - settings=client_settings, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathAzureChatOpenAI( + model=model_name, + settings=client_settings, + **model_kwargs, + ), ) else: from uipath_langchain_client.clients.openai.chat_models import ( UiPathChatOpenAI, ) - return UiPathChatOpenAI( - model=model_name, - settings=client_settings, - byo_connection_id=byo_connection_id, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathChatOpenAI( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ), ) case "vertexai": if is_uipath_owned: @@ -137,20 +147,26 @@ def get_chat_model( UiPathChatAnthropicVertex, ) - return UiPathChatAnthropicVertex( - model=model_name, - settings=client_settings, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathChatAnthropicVertex( + model=model_name, + settings=client_settings, + **model_kwargs, + ), ) elif "gemini" in model_name: from uipath_langchain_client.clients.google.chat_models import ( UiPathChatGoogleGenerativeAI, ) - return UiPathChatGoogleGenerativeAI( - model=model_name, - settings=client_settings, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathChatGoogleGenerativeAI( + model=model_name, + settings=client_settings, + **model_kwargs, + ), ) else: raise ValueError( @@ -161,11 +177,14 @@ def get_chat_model( UiPathChatGoogleGenerativeAI, ) - return UiPathChatGoogleGenerativeAI( - model=model_name, - settings=client_settings, - byo_connection_id=byo_connection_id, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathChatGoogleGenerativeAI( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ), ) case "awsbedrock": if is_uipath_owned: @@ -174,31 +193,40 @@ def get_chat_model( UiPathChatAnthropic, ) - return UiPathChatAnthropic( - model=model_name, - settings=client_settings, - vendor_type=vendor_type, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathChatAnthropic( + model=model_name, + settings=client_settings, + vendor_type=vendor_type, + **model_kwargs, + ), ) else: from uipath_langchain_client.clients.bedrock.chat_models import ( UiPathChatBedrock, ) - return UiPathChatBedrock( - model=model_name, - settings=client_settings, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathChatBedrock( + model=model_name, + settings=client_settings, + **model_kwargs, + ), ) else: from uipath_langchain_client.clients.bedrock.chat_models import ( UiPathChatBedrockConverse, ) - return UiPathChatBedrockConverse( - model=model_name, - settings=client_settings, - **model_kwargs, + return cast( + UiPathBaseChatModel, + UiPathChatBedrockConverse( + model=model_name, + settings=client_settings, + **model_kwargs, + ), ) case _: raise ValueError( @@ -212,7 +240,7 @@ def get_embedding_model( client_settings: UiPathBaseSettings | None = None, client_type: Literal["passthrough", "normalized"] = "passthrough", **model_kwargs: Any, -) -> Embeddings: +) -> UiPathBaseEmbeddings: """Factory function to create the appropriate LangChain embeddings model. Automatically detects the model vendor and returns the correct LangChain embeddings class. @@ -243,11 +271,14 @@ def get_embedding_model( UiPathEmbeddings, ) - return UiPathEmbeddings( - model=model_name, - settings=client_settings, - byo_connection_id=byo_connection_id, - **model_kwargs, + return cast( + UiPathBaseEmbeddings, + UiPathEmbeddings( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ), ) vendor_type = model_info["vendor"].lower() @@ -259,41 +290,53 @@ def get_embedding_model( UiPathAzureOpenAIEmbeddings, ) - return UiPathAzureOpenAIEmbeddings( - model=model_name, settings=client_settings, **model_kwargs + return cast( + UiPathBaseEmbeddings, + UiPathAzureOpenAIEmbeddings( + model=model_name, settings=client_settings, **model_kwargs + ), ) else: from uipath_langchain_client.clients.openai.embeddings import ( UiPathOpenAIEmbeddings, ) - return UiPathOpenAIEmbeddings( - model=model_name, - settings=client_settings, - byo_connection_id=byo_connection_id, - **model_kwargs, + return cast( + UiPathBaseEmbeddings, + UiPathOpenAIEmbeddings( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ), ) case "vertexai": from uipath_langchain_client.clients.google.embeddings import ( UiPathGoogleGenerativeAIEmbeddings, ) - return UiPathGoogleGenerativeAIEmbeddings( - model=model_name, - settings=client_settings, - byo_connection_id=byo_connection_id, - **model_kwargs, + return cast( + UiPathBaseEmbeddings, + UiPathGoogleGenerativeAIEmbeddings( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ), ) case "awsbedrock": from uipath_langchain_client.clients.bedrock.embeddings import ( UiPathBedrockEmbeddings, ) - return UiPathBedrockEmbeddings( - model=model_name, - settings=client_settings, - byo_connection_id=byo_connection_id, - **model_kwargs, + return cast( + UiPathBaseEmbeddings, + UiPathBedrockEmbeddings( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ), ) case _: raise ValueError(