2020 >>> embeddings = get_embedding_model(model_name="text-embedding-3-large", client_settings=settings)
2121"""
2222
23- from typing import Any , Literal , cast
23+ from typing import Any , Literal
2424
2525from uipath_langchain_client .base_client import (
2626 UiPathBaseChatModel ,
@@ -88,9 +88,7 @@ def get_chat_model(
8888 ValueError: If the model is not found in available models or vendor is not supported
8989 """
9090 client_settings = client_settings or get_default_client_settings ()
91-
9291 model_info = _get_model_info (model_name , client_settings , byo_connection_id )
93-
9492 is_uipath_owned = model_info .get ("modelSubscriptionType" ) == "UiPathOwned"
9593 if not is_uipath_owned :
9694 client_settings .validate_byo_model (model_info )
@@ -100,14 +98,11 @@ def get_chat_model(
10098 UiPathChat ,
10199 )
102100
103- return cast (
104- UiPathBaseChatModel ,
105- UiPathChat (
106- model = model_name ,
107- settings = client_settings ,
108- byo_connection_id = byo_connection_id ,
109- ** model_kwargs ,
110- ),
101+ return UiPathChat (
102+ model = model_name ,
103+ settings = client_settings ,
104+ byo_connection_id = byo_connection_id ,
105+ ** model_kwargs ,
111106 )
112107
113108 vendor_type = model_info ["vendor" ].lower ()
@@ -118,27 +113,22 @@ def get_chat_model(
118113 UiPathAzureChatOpenAI ,
119114 )
120115
121- return cast (
122- UiPathBaseChatModel ,
123- UiPathAzureChatOpenAI (
124- model = model_name ,
125- settings = client_settings ,
126- ** model_kwargs ,
127- ),
116+ return UiPathAzureChatOpenAI (
117+ model = model_name ,
118+ settings = client_settings ,
119+ byo_connection_id = byo_connection_id ,
120+ ** model_kwargs ,
128121 )
129122 else :
130123 from uipath_langchain_client .clients .openai .chat_models import (
131124 UiPathChatOpenAI ,
132125 )
133126
134- return cast (
135- UiPathBaseChatModel ,
136- UiPathChatOpenAI (
137- model = model_name ,
138- settings = client_settings ,
139- byo_connection_id = byo_connection_id ,
140- ** model_kwargs ,
141- ),
127+ return UiPathChatOpenAI (
128+ model = model_name ,
129+ settings = client_settings ,
130+ byo_connection_id = byo_connection_id ,
131+ ** model_kwargs ,
142132 )
143133 case "vertexai" :
144134 if is_uipath_owned :
@@ -147,27 +137,23 @@ def get_chat_model(
147137 UiPathChatAnthropic ,
148138 )
149139
150- return cast (
151- UiPathBaseChatModel ,
152- UiPathChatAnthropic (
153- model = model_name ,
154- settings = client_settings ,
155- vendor_type = vendor_type ,
156- ** model_kwargs ,
157- ),
140+ return UiPathChatAnthropic (
141+ model = model_name ,
142+ settings = client_settings ,
143+ vendor_type = vendor_type ,
144+ byo_connection_id = byo_connection_id ,
145+ ** model_kwargs ,
158146 )
159147 elif "gemini" in model_name :
160148 from uipath_langchain_client .clients .google .chat_models import (
161149 UiPathChatGoogleGenerativeAI ,
162150 )
163151
164- return cast (
165- UiPathBaseChatModel ,
166- UiPathChatGoogleGenerativeAI (
167- model = model_name ,
168- settings = client_settings ,
169- ** model_kwargs ,
170- ),
152+ return UiPathChatGoogleGenerativeAI (
153+ model = model_name ,
154+ settings = client_settings ,
155+ byo_connection_id = byo_connection_id ,
156+ ** model_kwargs ,
171157 )
172158 else :
173159 raise ValueError (
@@ -178,14 +164,11 @@ def get_chat_model(
178164 UiPathChatGoogleGenerativeAI ,
179165 )
180166
181- return cast (
182- UiPathBaseChatModel ,
183- UiPathChatGoogleGenerativeAI (
184- model = model_name ,
185- settings = client_settings ,
186- byo_connection_id = byo_connection_id ,
187- ** model_kwargs ,
188- ),
167+ return UiPathChatGoogleGenerativeAI (
168+ model = model_name ,
169+ settings = client_settings ,
170+ byo_connection_id = byo_connection_id ,
171+ ** model_kwargs ,
189172 )
190173 case "awsbedrock" :
191174 if is_uipath_owned :
@@ -194,44 +177,39 @@ def get_chat_model(
194177 UiPathChatAnthropic ,
195178 )
196179
197- return cast (
198- UiPathBaseChatModel ,
199- UiPathChatAnthropic (
200- model = model_name ,
201- settings = client_settings ,
202- vendor_type = vendor_type ,
203- ** model_kwargs ,
204- ),
180+ return UiPathChatAnthropic (
181+ model = model_name ,
182+ settings = client_settings ,
183+ vendor_type = vendor_type ,
184+ byo_connection_id = byo_connection_id ,
185+ ** model_kwargs ,
205186 )
206187 else :
207188 from uipath_langchain_client .clients .bedrock .chat_models import (
208189 UiPathChatBedrock ,
209190 )
210191
211- return cast (
212- UiPathBaseChatModel ,
213- UiPathChatBedrock (
214- model = model_name ,
215- settings = client_settings ,
216- ** model_kwargs ,
217- ),
192+ return UiPathChatBedrock (
193+ model = model_name ,
194+ settings = client_settings ,
195+ byo_connection_id = byo_connection_id ,
196+ ** model_kwargs ,
218197 )
198+
219199 else :
220200 from uipath_langchain_client .clients .bedrock .chat_models import (
221201 UiPathChatBedrockConverse ,
222202 )
223203
224- return cast (
225- UiPathBaseChatModel ,
226- UiPathChatBedrockConverse (
227- model = model_name ,
228- settings = client_settings ,
229- ** model_kwargs ,
230- ),
204+ return UiPathChatBedrockConverse (
205+ model = model_name ,
206+ settings = client_settings ,
207+ byo_connection_id = byo_connection_id ,
208+ ** model_kwargs ,
231209 )
232210 case _:
233211 raise ValueError (
234- f"Invalid vendor type: { vendor_type } , we don't currently have clients that support that api type "
212+ f"Invalid vendor type: { vendor_type } , we don't currently have clients that support this vendor "
235213 )
236214
237215
@@ -266,80 +244,71 @@ def get_embedding_model(
266244 """
267245 client_settings = client_settings or get_default_client_settings ()
268246 model_info = _get_model_info (model_name , client_settings , byo_connection_id )
247+ is_uipath_owned = model_info .get ("modelSubscriptionType" ) == "UiPathOwned"
248+ if not is_uipath_owned :
249+ client_settings .validate_byo_model (model_info )
269250
270251 if client_type == "normalized" :
271252 from uipath_langchain_client .clients .normalized .embeddings import (
272253 UiPathEmbeddings ,
273254 )
274255
275- return cast (
276- UiPathBaseEmbeddings ,
277- UiPathEmbeddings (
278- model = model_name ,
279- settings = client_settings ,
280- byo_connection_id = byo_connection_id ,
281- ** model_kwargs ,
282- ),
256+ return UiPathEmbeddings (
257+ model = model_name ,
258+ settings = client_settings ,
259+ byo_connection_id = byo_connection_id ,
260+ ** model_kwargs ,
283261 )
284262
285263 vendor_type = model_info ["vendor" ].lower ()
286- is_uipath_owned = model_info .get ("modelSubscriptionType" ) == "UiPathOwned"
287264 match vendor_type :
288265 case "openai" :
289266 if is_uipath_owned :
290267 from uipath_langchain_client .clients .openai .embeddings import (
291268 UiPathAzureOpenAIEmbeddings ,
292269 )
293270
294- return cast (
295- UiPathBaseEmbeddings ,
296- UiPathAzureOpenAIEmbeddings (
297- model = model_name , settings = client_settings , ** model_kwargs
298- ) ,
271+ return UiPathAzureOpenAIEmbeddings (
272+ model = model_name ,
273+ settings = client_settings ,
274+ byo_connection_id = byo_connection_id ,
275+ ** model_kwargs ,
299276 )
300277 else :
301278 from uipath_langchain_client .clients .openai .embeddings import (
302279 UiPathOpenAIEmbeddings ,
303280 )
304281
305- return cast (
306- UiPathBaseEmbeddings ,
307- UiPathOpenAIEmbeddings (
308- model = model_name ,
309- settings = client_settings ,
310- byo_connection_id = byo_connection_id ,
311- ** model_kwargs ,
312- ),
282+ return UiPathOpenAIEmbeddings (
283+ model = model_name ,
284+ settings = client_settings ,
285+ byo_connection_id = byo_connection_id ,
286+ ** model_kwargs ,
313287 )
288+
314289 case "vertexai" :
315290 from uipath_langchain_client .clients .google .embeddings import (
316291 UiPathGoogleGenerativeAIEmbeddings ,
317292 )
318293
319- return cast (
320- UiPathBaseEmbeddings ,
321- UiPathGoogleGenerativeAIEmbeddings (
322- model = model_name ,
323- settings = client_settings ,
324- byo_connection_id = byo_connection_id ,
325- ** model_kwargs ,
326- ),
294+ return UiPathGoogleGenerativeAIEmbeddings (
295+ model = model_name ,
296+ settings = client_settings ,
297+ byo_connection_id = byo_connection_id ,
298+ ** model_kwargs ,
327299 )
328300 case "awsbedrock" :
329301 from uipath_langchain_client .clients .bedrock .embeddings import (
330302 UiPathBedrockEmbeddings ,
331303 )
332304
333- return cast (
334- UiPathBaseEmbeddings ,
335- UiPathBedrockEmbeddings (
336- model = model_name ,
337- settings = client_settings ,
338- byo_connection_id = byo_connection_id ,
339- ** model_kwargs ,
340- ),
305+ return UiPathBedrockEmbeddings (
306+ model = model_name ,
307+ settings = client_settings ,
308+ byo_connection_id = byo_connection_id ,
309+ ** model_kwargs ,
341310 )
342311 case _:
343312 raise ValueError (
344- f"We don't currently have clients that support this provider: { vendor_type } "
313+ f"Invalid vendor type: { vendor_type } , we don't currently have clients that support this vendor "
345314 )
0 commit comments