3030 API_FLAVOR_TO_VENDOR_TYPE ,
3131 BYOM_TO_ROUTING_FLAVOR ,
3232 ApiFlavor ,
33+ ModelFamily ,
3334 RoutingMode ,
3435 UiPathBaseSettings ,
3536 VendorType ,
3637 get_default_client_settings ,
3738)
3839
3940
40- def _get_model_info (
41- model_name : str ,
42- * ,
43- client_settings : UiPathBaseSettings ,
44- byo_connection_id : str | None = None ,
45- vendor_type : VendorType | str | None = None ,
46- ) -> dict [str , Any ]:
47- available_models = client_settings .get_available_models ()
48-
49- matching_models = [m for m in available_models if m ["modelName" ].lower () == model_name .lower ()]
50-
51- if vendor_type is not None :
52- matching_models = [
53- m for m in matching_models if m .get ("vendor" , "" ).lower () == str (vendor_type ).lower ()
54- ]
55-
56- if byo_connection_id :
57- matching_models = [
58- m
59- for m in matching_models
60- if (byom_details := m .get ("byomDetails" ))
61- and byom_details .get ("integrationServiceConnectionId" , "" ).lower ()
62- == byo_connection_id .lower ()
63- ]
64-
65- if not byo_connection_id and len (matching_models ) > 1 :
66- matching_models = [
67- m
68- for m in matching_models
69- if (
70- (m .get ("modelSubscriptionType" , "" ) == "UiPathOwned" )
71- or (m .get ("byomDetails" ) is None )
72- )
73- ]
74-
75- if not matching_models :
76- raise ValueError (
77- f"Model { model_name } not found. Available models are: { [m ['modelName' ] for m in available_models ]} "
78- )
79-
80- return matching_models [0 ]
81-
82-
8341def get_chat_model (
8442 model_name : str ,
8543 * ,
@@ -120,18 +78,12 @@ def get_chat_model(
12078 ValueError: If the model is not found in available models or vendor is not supported.
12179 """
12280 client_settings = client_settings or get_default_client_settings ()
123- model_info = _get_model_info (
81+ model_info = client_settings . get_model_info (
12482 model_name ,
125- client_settings = client_settings ,
12683 byo_connection_id = byo_connection_id ,
12784 vendor_type = vendor_type ,
12885 )
12986 model_family = model_info .get ("modelFamily" , None )
130- if model_family is not None :
131- model_family = model_family .lower ()
132- is_uipath_owned = model_info .get ("modelSubscriptionType" ) == "UiPathOwned"
133- if not is_uipath_owned :
134- client_settings .validate_byo_model (model_info )
13587
13688 if custom_class is not None :
13789 return custom_class (
@@ -171,7 +123,7 @@ def get_chat_model(
171123
172124 match discovered_vendor_type :
173125 case VendorType .OPENAI :
174- if is_uipath_owned :
126+ if model_family == ModelFamily . OPENAI :
175127 from uipath_langchain_client .clients .openai .chat_models import (
176128 UiPathAzureChatOpenAI ,
177129 )
@@ -196,7 +148,7 @@ def get_chat_model(
196148 ** model_kwargs ,
197149 )
198150 case VendorType .VERTEXAI :
199- if model_family == "anthropicclaude" :
151+ if model_family == ModelFamily . ANTHROPIC_CLAUDE :
200152 from uipath_langchain_client .clients .anthropic .chat_models import (
201153 UiPathChatAnthropic ,
202154 )
@@ -220,7 +172,7 @@ def get_chat_model(
220172 ** model_kwargs ,
221173 )
222174 case VendorType .AWSBEDROCK :
223- if model_family == "anthropicclaude" and api_flavor is None :
175+ if model_family == ModelFamily . ANTHROPIC_CLAUDE and api_flavor is None :
224176 from uipath_langchain_client .clients .bedrock .chat_models import (
225177 UiPathChatAnthropicBedrock ,
226178 )
@@ -300,15 +252,12 @@ def get_embedding_model(
300252 >>> vectors = embeddings.embed_documents(["Hello world"])
301253 """
302254 client_settings = client_settings or get_default_client_settings ()
303- model_info = _get_model_info (
255+ model_info = client_settings . get_model_info (
304256 model_name ,
305- client_settings = client_settings ,
306257 byo_connection_id = byo_connection_id ,
307258 vendor_type = vendor_type ,
308259 )
309- is_uipath_owned = model_info .get ("modelSubscriptionType" ) == "UiPathOwned"
310- if not is_uipath_owned :
311- client_settings .validate_byo_model (model_info )
260+ model_family = model_info .get ("modelFamily" , None )
312261
313262 if custom_class is not None :
314263 return custom_class (
@@ -342,7 +291,7 @@ def get_embedding_model(
342291 discovered_vendor_type = discovered_vendor_type .lower ()
343292 match discovered_vendor_type :
344293 case VendorType .OPENAI :
345- if is_uipath_owned :
294+ if model_family == ModelFamily . OPENAI :
346295 from uipath_langchain_client .clients .openai .embeddings import (
347296 UiPathAzureOpenAIEmbeddings ,
348297 )
0 commit comments