11"""@private"""
22
33import re
4- from typing import Any , Dict , List , Literal , Optional
4+ from typing import Any , Dict , List , Literal , Optional , cast
55
66# NOTE ON DEPENDENCIES:
77# - since Jan 2024, there is https://pypi.org/project/langchain-openai/ which is a separate package and imports openai models.
1212def _extract_model_name (
1313 serialized : Optional [Dict [str , Any ]],
1414 ** kwargs : Any ,
15- ):
15+ ) -> Optional [ str ] :
1616 """Extracts the model name from the serialized or kwargs object. This is used to get the model names for Langfuse."""
1717 # In this function we return on the first match, so the order of operations is important
1818
@@ -39,39 +39,54 @@ def _extract_model_name(
3939
4040 for model_name , keys , select_from in models_by_id :
4141 model = _extract_model_by_path_for_id (
42- model_name , serialized , kwargs , keys , select_from
42+ model_name ,
43+ serialized ,
44+ kwargs ,
45+ keys ,
46+ cast (Literal ["serialized" , "kwargs" ], select_from ),
4347 )
4448 if model :
4549 return model
4650
4751 # Second, we match AzureOpenAI as we need to extract the model name, deployment version and deployment name
48- if serialized .get ("id" )[- 1 ] == "AzureOpenAI" :
49- if kwargs .get ("invocation_params" ).get ("model" ):
50- return kwargs .get ("invocation_params" ).get ("model" )
51-
52- if kwargs .get ("invocation_params" ).get ("model_name" ):
53- return kwargs .get ("invocation_params" ).get ("model_name" )
54-
55- deployment_name = None
56- deployment_version = None
57-
58- if serialized .get ("kwargs" ).get ("openai_api_version" ):
59- deployment_version = serialized .get ("kwargs" ).get ("deployment_version" )
60-
61- if serialized .get ("kwargs" ).get ("deployment_name" ):
62- deployment_name = serialized .get ("kwargs" ).get ("deployment_name" )
63-
64- if not isinstance (deployment_name , str ):
65- return None
66-
67- if not isinstance (deployment_version , str ):
68- return deployment_name
69-
70- return (
71- deployment_name + "-" + deployment_version
72- if deployment_version not in deployment_name
73- else deployment_name
74- )
52+ if serialized :
53+ serialized_id = serialized .get ("id" )
54+ if (
55+ serialized_id
56+ and isinstance (serialized_id , list )
57+ and len (serialized_id ) > 0
58+ and serialized_id [- 1 ] == "AzureOpenAI"
59+ ):
60+ invocation_params = kwargs .get ("invocation_params" )
61+ if invocation_params and isinstance (invocation_params , dict ):
62+ if invocation_params .get ("model" ):
63+ return str (invocation_params .get ("model" ))
64+
65+ if invocation_params .get ("model_name" ):
66+ return str (invocation_params .get ("model_name" ))
67+
68+ deployment_name = None
69+ deployment_version = None
70+
71+ serialized_kwargs = serialized .get ("kwargs" )
72+ if serialized_kwargs and isinstance (serialized_kwargs , dict ):
73+ if serialized_kwargs .get ("openai_api_version" ):
74+ deployment_version = serialized_kwargs .get ("deployment_version" )
75+
76+ if serialized_kwargs .get ("deployment_name" ):
77+ deployment_name = serialized_kwargs .get ("deployment_name" )
78+
79+ if not isinstance (deployment_name , str ):
80+ return None
81+
82+ if not isinstance (deployment_version , str ):
83+ return deployment_name
84+
85+ return (
86+ deployment_name + "-" + deployment_version
87+ if deployment_version not in deployment_name
88+ else deployment_name
89+ )
7590
7691 # Third, for some models, we are unable to extract the model by a path in an object. Langfuse provides us with a string representation of the model objects
7792 # We use regex to extract the model from the repr string
@@ -111,7 +126,9 @@ def _extract_model_name(
111126 ]
112127 for select in ["kwargs" , "serialized" ]:
113128 for path in random_paths :
114- model = _extract_model_by_path (serialized , kwargs , path , select )
129+ model = _extract_model_by_path (
130+ serialized , kwargs , path , cast (Literal ["serialized" , "kwargs" ], select )
131+ )
115132 if model :
116133 return model
117134
@@ -123,13 +140,20 @@ def _extract_model_from_repr_by_pattern(
123140 serialized : Optional [Dict [str , Any ]],
124141 pattern : str ,
125142 default : Optional [str ] = None ,
126- ):
143+ ) -> Optional [ str ] :
127144 if serialized is None :
128145 return None
129146
130- if serialized .get ("id" )[- 1 ] == id :
131- if serialized .get ("repr" ):
132- extracted = _extract_model_with_regex (pattern , serialized .get ("repr" ))
147+ serialized_id = serialized .get ("id" )
148+ if (
149+ serialized_id
150+ and isinstance (serialized_id , list )
151+ and len (serialized_id ) > 0
152+ and serialized_id [- 1 ] == id
153+ ):
154+ repr_str = serialized .get ("repr" )
155+ if repr_str and isinstance (repr_str , str ):
156+ extracted = _extract_model_with_regex (pattern , repr_str )
133157 return extracted if extracted else default if default else None
134158
135159 return None
@@ -145,15 +169,24 @@ def _extract_model_with_regex(pattern: str, text: str):
145169def _extract_model_by_path_for_id (
146170 id : str ,
147171 serialized : Optional [Dict [str , Any ]],
148- kwargs : dict ,
172+ kwargs : Dict [ str , Any ] ,
149173 keys : List [str ],
150174 select_from : Literal ["serialized" , "kwargs" ],
151- ):
175+ ) -> Optional [ str ] :
152176 if serialized is None and select_from == "serialized" :
153177 return None
154178
155- if serialized .get ("id" )[- 1 ] == id :
156- return _extract_model_by_path (serialized , kwargs , keys , select_from )
179+ if serialized :
180+ serialized_id = serialized .get ("id" )
181+ if (
182+ serialized_id
183+ and isinstance (serialized_id , list )
184+ and len (serialized_id ) > 0
185+ and serialized_id [- 1 ] == id
186+ ):
187+ return _extract_model_by_path (serialized , kwargs , keys , select_from )
188+
189+ return None
157190
158191
159192def _extract_model_by_path (
@@ -168,7 +201,10 @@ def _extract_model_by_path(
168201 current_obj = kwargs if select_from == "kwargs" else serialized
169202
170203 for key in keys :
171- current_obj = current_obj .get (key )
204+ if current_obj and isinstance (current_obj , dict ):
205+ current_obj = current_obj .get (key )
206+ else :
207+ return None
172208 if not current_obj :
173209 return None
174210
0 commit comments