2020 >>> embeddings = get_embedding_model(model_name="text-embedding-3-large", client_settings=settings)
2121"""
2222
23- from typing import Any , Literal
24-
25- from langchain_core .embeddings import Embeddings
26- from langchain_core .language_models .chat_models import BaseChatModel
23+ from typing import Any , Literal , cast
2724
25+ from uipath_langchain_client .base_client import (
26+ UiPathBaseChatModel ,
27+ UiPathBaseEmbeddings ,
28+ )
2829from uipath_langchain_client .settings import UiPathBaseSettings , get_default_client_settings
2930
3031
@@ -70,7 +71,7 @@ def get_chat_model(
7071 client_settings : UiPathBaseSettings | None = None ,
7172 client_type : Literal ["passthrough" , "normalized" ] = "passthrough" ,
7273 ** model_kwargs : Any ,
73- ) -> BaseChatModel :
74+ ) -> UiPathBaseChatModel :
7475 """Factory function to create the appropriate LangChain chat model for a given model name.
7576
7677 Automatically detects the model vendor and returns the correct LangChain model class.
@@ -99,11 +100,14 @@ def get_chat_model(
99100 UiPathChat ,
100101 )
101102
102- return UiPathChat (
103- model = model_name ,
104- settings = client_settings ,
105- byo_connection_id = byo_connection_id ,
106- ** model_kwargs ,
103+ return cast (
104+ UiPathBaseChatModel ,
105+ UiPathChat (
106+ model = model_name ,
107+ settings = client_settings ,
108+ byo_connection_id = byo_connection_id ,
109+ ** model_kwargs ,
110+ ),
107111 )
108112
109113 vendor_type = model_info ["vendor" ].lower ()
@@ -114,21 +118,27 @@ def get_chat_model(
114118 UiPathAzureChatOpenAI ,
115119 )
116120
117- return UiPathAzureChatOpenAI (
118- model = model_name ,
119- settings = client_settings ,
120- ** model_kwargs ,
121+ return cast (
122+ UiPathBaseChatModel ,
123+ UiPathAzureChatOpenAI (
124+ model = model_name ,
125+ settings = client_settings ,
126+ ** model_kwargs ,
127+ ),
121128 )
122129 else :
123130 from uipath_langchain_client .clients .openai .chat_models import (
124131 UiPathChatOpenAI ,
125132 )
126133
127- return UiPathChatOpenAI (
128- model = model_name ,
129- settings = client_settings ,
130- byo_connection_id = byo_connection_id ,
131- ** model_kwargs ,
134+ return cast (
135+ UiPathBaseChatModel ,
136+ UiPathChatOpenAI (
137+ model = model_name ,
138+ settings = client_settings ,
139+ byo_connection_id = byo_connection_id ,
140+ ** model_kwargs ,
141+ ),
132142 )
133143 case "vertexai" :
134144 if is_uipath_owned :
@@ -137,20 +147,26 @@ def get_chat_model(
137147 UiPathChatAnthropicVertex ,
138148 )
139149
140- return UiPathChatAnthropicVertex (
141- model = model_name ,
142- settings = client_settings ,
143- ** model_kwargs ,
150+ return cast (
151+ UiPathBaseChatModel ,
152+ UiPathChatAnthropicVertex (
153+ model = model_name ,
154+ settings = client_settings ,
155+ ** model_kwargs ,
156+ ),
144157 )
145158 elif "gemini" in model_name :
146159 from uipath_langchain_client .clients .google .chat_models import (
147160 UiPathChatGoogleGenerativeAI ,
148161 )
149162
150- return UiPathChatGoogleGenerativeAI (
151- model = model_name ,
152- settings = client_settings ,
153- ** model_kwargs ,
163+ return cast (
164+ UiPathBaseChatModel ,
165+ UiPathChatGoogleGenerativeAI (
166+ model = model_name ,
167+ settings = client_settings ,
168+ ** model_kwargs ,
169+ ),
154170 )
155171 else :
156172 raise ValueError (
@@ -161,11 +177,14 @@ def get_chat_model(
161177 UiPathChatGoogleGenerativeAI ,
162178 )
163179
164- return UiPathChatGoogleGenerativeAI (
165- model = model_name ,
166- settings = client_settings ,
167- byo_connection_id = byo_connection_id ,
168- ** model_kwargs ,
180+ return cast (
181+ UiPathBaseChatModel ,
182+ UiPathChatGoogleGenerativeAI (
183+ model = model_name ,
184+ settings = client_settings ,
185+ byo_connection_id = byo_connection_id ,
186+ ** model_kwargs ,
187+ ),
169188 )
170189 case "awsbedrock" :
171190 if is_uipath_owned :
@@ -174,31 +193,40 @@ def get_chat_model(
174193 UiPathChatAnthropic ,
175194 )
176195
177- return UiPathChatAnthropic (
178- model = model_name ,
179- settings = client_settings ,
180- vendor_type = vendor_type ,
181- ** model_kwargs ,
196+ return cast (
197+ UiPathBaseChatModel ,
198+ UiPathChatAnthropic (
199+ model = model_name ,
200+ settings = client_settings ,
201+ vendor_type = vendor_type ,
202+ ** model_kwargs ,
203+ ),
182204 )
183205 else :
184206 from uipath_langchain_client .clients .bedrock .chat_models import (
185207 UiPathChatBedrock ,
186208 )
187209
188- return UiPathChatBedrock (
189- model = model_name ,
190- settings = client_settings ,
191- ** model_kwargs ,
210+ return cast (
211+ UiPathBaseChatModel ,
212+ UiPathChatBedrock (
213+ model = model_name ,
214+ settings = client_settings ,
215+ ** model_kwargs ,
216+ ),
192217 )
193218 else :
194219 from uipath_langchain_client .clients .bedrock .chat_models import (
195220 UiPathChatBedrockConverse ,
196221 )
197222
198- return UiPathChatBedrockConverse (
199- model = model_name ,
200- settings = client_settings ,
201- ** model_kwargs ,
223+ return cast (
224+ UiPathBaseChatModel ,
225+ UiPathChatBedrockConverse (
226+ model = model_name ,
227+ settings = client_settings ,
228+ ** model_kwargs ,
229+ ),
202230 )
203231 case _:
204232 raise ValueError (
@@ -212,7 +240,7 @@ def get_embedding_model(
212240 client_settings : UiPathBaseSettings | None = None ,
213241 client_type : Literal ["passthrough" , "normalized" ] = "passthrough" ,
214242 ** model_kwargs : Any ,
215- ) -> Embeddings :
243+ ) -> UiPathBaseEmbeddings :
216244 """Factory function to create the appropriate LangChain embeddings model.
217245
218246 Automatically detects the model vendor and returns the correct LangChain embeddings class.
@@ -243,11 +271,14 @@ def get_embedding_model(
243271 UiPathEmbeddings ,
244272 )
245273
246- return UiPathEmbeddings (
247- model = model_name ,
248- settings = client_settings ,
249- byo_connection_id = byo_connection_id ,
250- ** model_kwargs ,
274+ return cast (
275+ UiPathBaseEmbeddings ,
276+ UiPathEmbeddings (
277+ model = model_name ,
278+ settings = client_settings ,
279+ byo_connection_id = byo_connection_id ,
280+ ** model_kwargs ,
281+ ),
251282 )
252283
253284 vendor_type = model_info ["vendor" ].lower ()
@@ -259,41 +290,53 @@ def get_embedding_model(
259290 UiPathAzureOpenAIEmbeddings ,
260291 )
261292
262- return UiPathAzureOpenAIEmbeddings (
263- model = model_name , settings = client_settings , ** model_kwargs
293+ return cast (
294+ UiPathBaseEmbeddings ,
295+ UiPathAzureOpenAIEmbeddings (
296+ model = model_name , settings = client_settings , ** model_kwargs
297+ ),
264298 )
265299 else :
266300 from uipath_langchain_client .clients .openai .embeddings import (
267301 UiPathOpenAIEmbeddings ,
268302 )
269303
270- return UiPathOpenAIEmbeddings (
271- model = model_name ,
272- settings = client_settings ,
273- byo_connection_id = byo_connection_id ,
274- ** model_kwargs ,
304+ return cast (
305+ UiPathBaseEmbeddings ,
306+ UiPathOpenAIEmbeddings (
307+ model = model_name ,
308+ settings = client_settings ,
309+ byo_connection_id = byo_connection_id ,
310+ ** model_kwargs ,
311+ ),
275312 )
276313 case "vertexai" :
277314 from uipath_langchain_client .clients .google .embeddings import (
278315 UiPathGoogleGenerativeAIEmbeddings ,
279316 )
280317
281- return UiPathGoogleGenerativeAIEmbeddings (
282- model = model_name ,
283- settings = client_settings ,
284- byo_connection_id = byo_connection_id ,
285- ** model_kwargs ,
318+ return cast (
319+ UiPathBaseEmbeddings ,
320+ UiPathGoogleGenerativeAIEmbeddings (
321+ model = model_name ,
322+ settings = client_settings ,
323+ byo_connection_id = byo_connection_id ,
324+ ** model_kwargs ,
325+ ),
286326 )
287327 case "awsbedrock" :
288328 from uipath_langchain_client .clients .bedrock .embeddings import (
289329 UiPathBedrockEmbeddings ,
290330 )
291331
292- return UiPathBedrockEmbeddings (
293- model = model_name ,
294- settings = client_settings ,
295- byo_connection_id = byo_connection_id ,
296- ** model_kwargs ,
332+ return cast (
333+ UiPathBaseEmbeddings ,
334+ UiPathBedrockEmbeddings (
335+ model = model_name ,
336+ settings = client_settings ,
337+ byo_connection_id = byo_connection_id ,
338+ ** model_kwargs ,
339+ ),
297340 )
298341 case _:
299342 raise ValueError (
0 commit comments