@@ -111,6 +111,25 @@ def has_column(
111111 name = column if isinstance (column , str ) else column .name
112112 return name in self .column_names (table , dialect = dialect , normalize = normalize )
113113
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """
    Look up the return type of a user-defined function.

    This default implementation does not consult any mapping: every argument
    is ignored and an UNKNOWN type is returned.

    Args:
        udf: the UDF expression or string.
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        A DataType built from "unknown".
    """
    unknown_type = exp.DataType.build("unknown")
    return unknown_type
132+
114133 @property
115134 @abc .abstractmethod
116135 def supported_table_args (self ) -> t .Tuple [str , ...]:
@@ -128,11 +147,18 @@ class AbstractMappingSchema:
def __init__(
    self,
    mapping: t.Optional[t.Dict] = None,
    udf_mapping: t.Optional[t.Dict] = None,
) -> None:
    """
    Store the table and UDF mappings and index each of them in a trie.

    Args:
        mapping: nested table mapping; a falsy value is replaced by an empty dict.
        udf_mapping: nested UDF mapping; a falsy value is replaced by an empty dict.
    """
    self.mapping = mapping or {}
    # Each flattened key path is stored reversed in the trie.
    table_paths = flatten_schema(self.mapping, depth=self.depth())
    self.mapping_trie = new_trie(tuple(reversed(path)) for path in table_paths)

    self.udf_mapping = udf_mapping or {}
    # Same indexing scheme as above, applied to the UDF mapping.
    udf_paths = flatten_schema(self.udf_mapping, depth=self.udf_depth())
    self.udf_trie = new_trie(tuple(reversed(path)) for path in udf_paths)

    self._supported_table_args: t.Tuple[str, ...] = tuple()
137163
138164 @property
@@ -142,6 +168,9 @@ def empty(self) -> bool:
def depth(self) -> int:
    """Return the nesting depth of the table mapping, as computed by `dict_depth`."""
    mapping_depth = dict_depth(self.mapping)
    return mapping_depth
144170
def udf_depth(self) -> int:
    """Return the nesting depth of the UDF mapping, as computed by `dict_depth`."""
    mapping_depth = dict_depth(self.udf_mapping)
    return mapping_depth
173+
145174 @property
146175 def supported_table_args (self ) -> t .Tuple [str , ...]:
147176 if not self ._supported_table_args and self .mapping :
@@ -157,7 +186,39 @@ def supported_table_args(self) -> t.Tuple[str, ...]:
157186 return self ._supported_table_args
158187
def table_parts(self, table: exp.Table) -> t.List[str]:
    """Return the names of the table's parts, last part first."""
    names = [part.name for part in table.parts]
    names.reverse()
    return names
190+
def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
    """
    Return the name parts of a UDF reference, last part first.

    Args:
        udf: the UDF expression.

    Returns:
        The names of every operand of the enclosing Dot chain when the UDF's
        parent is a Dot, otherwise just the UDF's own name; reversed in both cases.
    """
    # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
    enclosing = udf.parent
    if isinstance(enclosing, exp.Dot):
        names = [node.name for node in enclosing.flatten()]
    else:
        names = [udf.name]
    names.reverse()
    return names
196+
def _find_in_trie(
    self,
    parts: t.List[str],
    trie: t.Dict,
    raise_on_missing: bool,
) -> t.Optional[t.List[str]]:
    """
    Resolve a possibly partial list of parts against a trie.

    Args:
        parts: the parts to look up; passed to `in_trie` as the key. The list
            is never modified, a new list is returned instead.
        trie: the trie to search.
        raise_on_missing: whether to raise when `parts` is a prefix that
            matches more than one entry.

    Returns:
        A new list holding the resolved parts: `parts` itself on a full match,
        or `parts` followed by the single remaining path when `parts` is an
        unambiguous prefix. None if the lookup failed, or if the prefix is
        ambiguous and `raise_on_missing` is False.

    Raises:
        SchemaError: if the prefix is ambiguous and `raise_on_missing` is True.
    """
    value, subtrie = in_trie(trie, parts)

    if value == TrieResult.FAILED:
        return None

    if value != TrieResult.PREFIX:
        return list(parts)

    # `parts` only reaches an inner node: see how many entries live below it.
    possibilities = flatten_schema(subtrie)

    if len(possibilities) == 1:
        # Build a fresh list rather than extending the caller's argument in place.
        return [*parts, *possibilities[0]]

    if raise_on_missing:
        joined_parts = ".".join(parts)
        message = ", ".join(".".join(p) for p in possibilities)
        raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

    return None
161222
162223 def find (
163224 self , table : exp .Table , raise_on_missing : bool = True , ensure_data_types : bool = False
@@ -174,23 +235,35 @@ def find(
174235 The schema of the target table.
175236 """
176237 parts = self .table_parts (table )[0 : len (self .supported_table_args )]
177- value , trie = in_trie ( self .mapping_trie , parts )
238+ resolved_parts = self . _find_in_trie ( parts , self .mapping_trie , raise_on_missing )
178239
179- if value == TrieResult . FAILED :
240+ if resolved_parts is None :
180241 return None
181242
182- if value == TrieResult .PREFIX :
183- possibilities = flatten_schema (trie )
243+ return self .nested_get (resolved_parts , raise_on_missing = raise_on_missing )
184244
185- if len ( possibilities ) == 1 :
186- parts . extend ( possibilities [ 0 ])
187- else :
188- message = ", " . join ( "." . join ( parts ) for parts in possibilities )
189- if raise_on_missing :
190- raise SchemaError ( f"Ambiguous mapping for { table } : { message } ." )
191- return None
def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
    """
    Look up the entry stored for a UDF in the UDF mapping.

    Args:
        udf: the target UDF expression.
        raise_on_missing: forwarded to the trie lookup and to `nested_get`.

    Returns:
        The value stored for the UDF, or None if the trie lookup yields nothing.
    """
    all_parts = self.udf_parts(udf)
    # Only as many parts as the UDF mapping is deep take part in the lookup.
    lookup = all_parts[: self.udf_depth()]
    matched = self._find_in_trie(lookup, self.udf_trie, raise_on_missing)

    if matched is None:
        return None

    path = zip(matched, reversed(matched))
    return nested_get(self.udf_mapping, *path, raise_on_missing=raise_on_missing)
194267
195268 def nested_get (
196269 self , parts : t .Sequence [str ], d : t .Optional [t .Dict ] = None , raise_on_missing = True
@@ -227,15 +300,20 @@ def __init__(
227300 visible : t .Optional [t .Dict ] = None ,
228301 dialect : DialectType = None ,
229302 normalize : bool = True ,
303+ udf_mapping : t .Optional [t .Dict ] = None ,
230304 ) -> None :
231305 self .visible = {} if visible is None else visible
232306 self .normalize = normalize
233307 self ._dialect = Dialect .get_or_raise (dialect )
234308 self ._type_mapping_cache : t .Dict [str , exp .DataType ] = {}
235309 self ._depth = 0
236310 schema = {} if schema is None else schema
311+ udf_mapping = {} if udf_mapping is None else udf_mapping
237312
238- super ().__init__ (self ._normalize (schema ) if self .normalize else schema )
313+ super ().__init__ (
314+ self ._normalize (schema ) if self .normalize else schema ,
315+ self ._normalize_udfs (udf_mapping ) if self .normalize else udf_mapping ,
316+ )
239317
240318 @property
241319 def dialect (self ) -> Dialect :
@@ -249,6 +327,7 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
249327 visible = mapping_schema .visible ,
250328 dialect = mapping_schema .dialect ,
251329 normalize = mapping_schema .normalize ,
330+ udf_mapping = mapping_schema .udf_mapping ,
252331 )
253332
254333 def find (
@@ -272,6 +351,7 @@ def copy(self, **kwargs) -> MappingSchema:
272351 "visible" : self .visible .copy (),
273352 "dialect" : self .dialect ,
274353 "normalize" : self .normalize ,
354+ "udf_mapping" : self .udf_mapping .copy (),
275355 ** kwargs ,
276356 }
277357 )
@@ -360,6 +440,43 @@ def get_column_type(
360440
361441 return exp .DataType .build ("unknown" )
362442
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """
    Resolve the return type registered for a UDF in the UDF mapping.

    Args:
        udf: the UDF expression or string (e.g., "db.my_func()").
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        The registered DataType, a DataType converted from a registered string,
        or UNKNOWN when nothing usable is found.
    """
    normalized = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
    lookup = normalized[: self.udf_depth()]
    matched = self._find_in_trie(lookup, self.udf_trie, raise_on_missing=False)

    if matched is not None:
        declared = nested_get(
            self.udf_mapping,
            *zip(matched, reversed(matched)),
            raise_on_missing=False,
        )

        if isinstance(declared, exp.DataType):
            return declared
        if isinstance(declared, str):
            # String entries are converted on demand.
            return self._to_data_type(declared, dialect=dialect)

    return exp.DataType.build("unknown")
479+
363480 def has_column (
364481 self ,
365482 table : exp .Table | str ,
@@ -414,6 +531,55 @@ def _normalize(self, schema: t.Dict) -> t.Dict:
414531
415532 return normalized_mapping
416533
def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
    """
    Build a copy of the UDF mapping whose keys are all normalized.

    Args:
        udfs: the UDF mapping to normalize.

    Returns:
        A new nested mapping with the same leaf values stored under normalized keys.
    """
    result: t.Dict = {}
    depth = dict_depth(udfs)

    for path in flatten_schema(udfs, depth=depth):
        # Fetch the leaf value stored under the original, un-normalized key path.
        return_type = nested_get(udfs, *zip(path, path))
        nested_set(
            result,
            [self._normalize_name(segment, is_table=True) for segment in path],
            return_type,
        )

    return result
552+
def _normalize_udf(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.List[str]:
    """
    Extract and normalize UDF parts for lookup.

    Args:
        udf: the UDF expression or qualified string (e.g., "db.my_func()").
        dialect: the SQL dialect for parsing; defaults to the schema's dialect.
        normalize: whether to normalize identifiers; defaults to the schema's setting.

    Returns:
        A list of UDF parts as returned by `udf_parts` (reversed for trie
        lookup), normalized when normalization is enabled.
    """
    dialect = dialect or self.dialect
    normalize = self.normalize if normalize is None else normalize

    if isinstance(udf, str):
        parsed = exp.maybe_parse(udf, dialect=dialect)

        if isinstance(parsed, exp.Anonymous):
            # An unqualified call parses to the function node itself. Its `.this`
            # is the function's name rather than a node `udf_parts` can inspect,
            # so the node is used directly.
            # NOTE(review): relies on the parser returning a bare Anonymous for
            # "my_func()" - confirm against the expressions API.
            udf = parsed
        elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
            # Qualified call: the function is the right-hand operand of the Dot.
            udf = parsed.expression
        else:
            # Any other shape keeps the previous behavior.
            udf = t.cast(exp.Anonymous, parsed.this)

    parts = self.udf_parts(udf)

    if normalize:
        parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

    return parts
582+
417583 def _normalize_table (
418584 self ,
419585 table : exp .Table | str ,
0 commit comments