Skip to content

Commit b921f72

Browse files
committed
Feat(optimizer): UDF annotation
1 parent 36211c2 commit b921f72

4 files changed

Lines changed: 229 additions & 21 deletions

File tree

sqlglot/optimizer/annotate_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,12 @@ def _annotate_div(self, expression: exp.Div) -> exp.Div:
722722

723723
def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
724724
self._set_type(expression, None)
725+
726+
# Propagate type from qualified UDF calls (e.g., db.my_udf(...))
727+
if isinstance(expression.expression, exp.Anonymous):
728+
self._set_type(expression, expression.expression.type)
729+
return expression
730+
725731
this_type = expression.this.type
726732

727733
if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
@@ -860,3 +866,7 @@ def _annotate_by_array_element(self, expression: exp.Expression) -> exp.Expressi
860866
self._set_type(expression, exp.DataType.Type.UNKNOWN)
861867

862868
return expression
869+
870+
def _annotate_anonymous(self, expression: exp.Anonymous) -> exp.Anonymous:
    """
    Annotate an `exp.Anonymous` function call with the return type the schema
    reports for it via `get_udf_type`.

    Args:
        expression: the anonymous function call to annotate.

    Returns:
        The same expression, with its type set.
    """
    udf_type = self.schema.get_udf_type(expression)
    self._set_type(expression, udf_type)
    return expression

sqlglot/schema.py

Lines changed: 180 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,25 @@ def has_column(
111111
name = column if isinstance(column, str) else column.name
112112
return name in self.column_names(table, dialect=dialect, normalize=normalize)
113113

114+
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """
    Look up the return type of a UDF.

    This implementation performs no lookup: every argument is ignored and the
    result is always the UNKNOWN type.

    Args:
        udf: the UDF expression or string.
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        An UNKNOWN DataType.
    """
    unknown_type = exp.DataType.build("unknown")
    return unknown_type
132+
114133
@property
115134
@abc.abstractmethod
116135
def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -128,11 +147,18 @@ class AbstractMappingSchema:
128147
def __init__(
    self,
    mapping: t.Optional[t.Dict] = None,
    udf_mapping: t.Optional[t.Dict] = None,
) -> None:
    """
    Store the table and UDF mappings and build a lookup trie for each.

    Args:
        mapping: nested table mapping; a falsy value is replaced by an empty dict.
        udf_mapping: nested UDF mapping; a falsy value is replaced by an empty dict.
    """
    # `self.mapping` has to exist before `depth()` runs, because `depth()` reads it.
    self.mapping = mapping or {}
    table_paths = flatten_schema(self.mapping, depth=self.depth())
    # Trie keys are the flattened paths in reverse order.
    self.mapping_trie = new_trie(tuple(reversed(path)) for path in table_paths)

    # Same construction for UDFs; `udf_depth()` reads `self.udf_mapping`.
    self.udf_mapping = udf_mapping or {}
    udf_paths = flatten_schema(self.udf_mapping, depth=self.udf_depth())
    self.udf_trie = new_trie(tuple(reversed(path)) for path in udf_paths)

    # Starts empty; the `supported_table_args` property fills it on demand.
    self._supported_table_args: t.Tuple[str, ...] = tuple()
137163

138164
@property
@@ -142,6 +168,9 @@ def empty(self) -> bool:
142168
def depth(self) -> int:
    """Return the nesting depth of the table mapping, as computed by `dict_depth`."""
    mapping_depth = dict_depth(self.mapping)
    return mapping_depth
144170

171+
def udf_depth(self) -> int:
    """Return the nesting depth of the UDF mapping, as computed by `dict_depth`."""
    mapping_depth = dict_depth(self.udf_mapping)
    return mapping_depth
173+
145174
@property
146175
def supported_table_args(self) -> t.Tuple[str, ...]:
147176
if not self._supported_table_args and self.mapping:
@@ -157,7 +186,39 @@ def supported_table_args(self) -> t.Tuple[str, ...]:
157186
return self._supported_table_args
158187

159188
def table_parts(self, table: exp.Table) -> t.List[str]:
    """
    Return the names of the table's parts in reverse order.

    Args:
        table: the table expression whose `parts` are read.

    Returns:
        The part names, last part first.
    """
    names = [part.name for part in table.parts]
    names.reverse()
    return names
190+
191+
def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
    """
    Return the name parts of a UDF call in reverse order.

    A qualified call such as a.b.c(...) is represented as
    Dot(Dot(a, b), Anonymous(c, ...)). When the call's parent is a Dot, the
    names come from flattening that Dot; otherwise only the call's own name
    is used.

    Args:
        udf: the UDF call expression.

    Returns:
        The name parts, reversed.
    """
    enclosing = udf.parent

    if isinstance(enclosing, exp.Dot):
        names = [node.name for node in enclosing.flatten()]
    else:
        names = [udf.name]

    names.reverse()
    return names
196+
197+
def _find_in_trie(
    self,
    parts: t.List[str],
    trie: t.Dict,
    raise_on_missing: bool,
) -> t.Optional[t.List[str]]:
    """
    Resolve `parts` against `trie`, completing a partial path when exactly one
    completion exists.

    Args:
        parts: the lookup keys; extended in place when a single completion is found.
        trie: the trie to search.
        raise_on_missing: whether to raise when a partial path does not have
            exactly one completion.

    Returns:
        `parts` (possibly extended), or None when the lookup fails or the
        partial path cannot be completed uniquely.

    Raises:
        SchemaError: if the partial path cannot be completed uniquely and
            `raise_on_missing` is set.
    """
    result, subtrie = in_trie(trie, parts)

    if result == TrieResult.FAILED:
        return None

    # Anything other than a prefix match needs no completion.
    if result != TrieResult.PREFIX:
        return parts

    candidates = flatten_schema(subtrie)

    if len(candidates) == 1:
        # Only one way to finish the path, so adopt it.
        parts.extend(candidates[0])
        return parts

    if raise_on_missing:
        joined_parts = ".".join(parts)
        message = ", ".join(".".join(candidate) for candidate in candidates)
        raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

    return None
161222

162223
def find(
163224
self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
@@ -174,23 +235,35 @@ def find(
174235
The schema of the target table.
175236
"""
176237
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
177-
value, trie = in_trie(self.mapping_trie, parts)
238+
resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)
178239

179-
if value == TrieResult.FAILED:
240+
if resolved_parts is None:
180241
return None
181242

182-
if value == TrieResult.PREFIX:
183-
possibilities = flatten_schema(trie)
243+
return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)
184244

185-
if len(possibilities) == 1:
186-
parts.extend(possibilities[0])
187-
else:
188-
message = ", ".join(".".join(parts) for parts in possibilities)
189-
if raise_on_missing:
190-
raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
191-
return None
245+
def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
    """
    Look up the entry stored in the UDF mapping for a given UDF call.

    Args:
        udf: the target UDF expression.
        raise_on_missing: whether to raise if the UDF is not found.

    Returns:
        The value stored for the UDF, or None if the trie lookup does not resolve.
    """
    # Only as many name parts as the UDF mapping is deep can take part in the lookup.
    lookup_keys = self.udf_parts(udf)[: self.udf_depth()]
    resolved = self._find_in_trie(lookup_keys, self.udf_trie, raise_on_missing)

    if resolved is None:
        return None

    # Each path step pairs a resolved part with the key taken from the reversed list.
    path = zip(resolved, reversed(resolved))
    return nested_get(self.udf_mapping, *path, raise_on_missing=raise_on_missing)
194267

195268
def nested_get(
196269
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
@@ -227,15 +300,20 @@ def __init__(
227300
visible: t.Optional[t.Dict] = None,
228301
dialect: DialectType = None,
229302
normalize: bool = True,
303+
udf_mapping: t.Optional[t.Dict] = None,
230304
) -> None:
231305
self.visible = {} if visible is None else visible
232306
self.normalize = normalize
233307
self._dialect = Dialect.get_or_raise(dialect)
234308
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
235309
self._depth = 0
236310
schema = {} if schema is None else schema
311+
udf_mapping = {} if udf_mapping is None else udf_mapping
237312

238-
super().__init__(self._normalize(schema) if self.normalize else schema)
313+
super().__init__(
314+
self._normalize(schema) if self.normalize else schema,
315+
self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
316+
)
239317

240318
@property
241319
def dialect(self) -> Dialect:
@@ -249,6 +327,7 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
249327
visible=mapping_schema.visible,
250328
dialect=mapping_schema.dialect,
251329
normalize=mapping_schema.normalize,
330+
udf_mapping=mapping_schema.udf_mapping,
252331
)
253332

254333
def find(
@@ -272,6 +351,7 @@ def copy(self, **kwargs) -> MappingSchema:
272351
"visible": self.visible.copy(),
273352
"dialect": self.dialect,
274353
"normalize": self.normalize,
354+
"udf_mapping": self.udf_mapping.copy(),
275355
**kwargs,
276356
}
277357
)
@@ -360,6 +440,43 @@ def get_column_type(
360440

361441
return exp.DataType.build("unknown")
362442

443+
def get_udf_type(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> exp.DataType:
    """
    Resolve the return type of a UDF from the UDF mapping.

    Args:
        udf: the UDF expression or string (e.g., "db.my_func()").
        dialect: the SQL dialect for parsing string arguments.
        normalize: whether to normalize identifiers.

    Returns:
        The mapped return type as a DataType. UNKNOWN is returned when the UDF
        cannot be resolved or the mapped value is neither a DataType nor a string.
    """
    lookup_keys = self._normalize_udf(udf, dialect=dialect, normalize=normalize)[
        : self.udf_depth()
    ]
    resolved = self._find_in_trie(lookup_keys, self.udf_trie, raise_on_missing=False)

    if resolved is not None:
        declared = nested_get(
            self.udf_mapping,
            *zip(resolved, reversed(resolved)),
            raise_on_missing=False,
        )

        if isinstance(declared, exp.DataType):
            return declared

        if isinstance(declared, str):
            # String entries are converted to a DataType on demand.
            return self._to_data_type(declared, dialect=dialect)

    return exp.DataType.build("unknown")
479+
363480
def has_column(
364481
self,
365482
table: exp.Table | str,
@@ -414,6 +531,55 @@ def _normalize(self, schema: t.Dict) -> t.Dict:
414531

415532
return normalized_mapping
416533

534+
def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
    """
    Build a copy of the UDF mapping whose keys are all normalized.

    Args:
        udfs: the UDF mapping to normalize.

    Returns:
        A new mapping with normalized keys and the original leaf values.
    """
    result: t.Dict = {}
    depth = dict_depth(udfs)

    for path in flatten_schema(udfs, depth=depth):
        return_type = nested_get(udfs, *zip(path, path))
        normalized_path = [self._normalize_name(name, is_table=True) for name in path]
        nested_set(result, normalized_path, return_type)

    return result
552+
553+
def _normalize_udf(
    self,
    udf: exp.Anonymous | str,
    dialect: DialectType = None,
    normalize: t.Optional[bool] = None,
) -> t.List[str]:
    """
    Extract and normalize UDF parts for lookup.

    Args:
        udf: the UDF expression, or a call string that may or may not be
            qualified (e.g., "db.my_func()" or "my_func()").
        dialect: the SQL dialect for parsing; falls back to the schema's dialect.
        normalize: whether to normalize identifiers; falls back to the
            schema's `normalize` setting when None.

    Returns:
        A list of UDF parts (reversed for trie lookup), normalized if requested.
    """
    dialect = dialect or self.dialect
    normalize = self.normalize if normalize is None else normalize

    if isinstance(udf, str):
        parsed = exp.maybe_parse(udf, dialect=dialect)

        # A qualified call is represented as Dot(..., Anonymous(...)), so the call
        # node is the Dot's right-hand side; an unqualified call is the parsed node
        # itself. Using `parsed.this` here would pass `udf_parts` the qualifier, or,
        # for an unqualified call, presumably the function's bare name instead of an
        # expression node (NOTE(review): confirm against exp.Anonymous).
        if isinstance(parsed, exp.Dot):
            parsed = parsed.expression

        udf = t.cast(exp.Anonymous, parsed)

    parts = self.udf_parts(udf)

    if normalize:
        parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

    return parts
582+
417583
def _normalize_table(
418584
self,
419585
table: exp.Table | str,

sqlglot/typing/__init__.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -250,13 +250,8 @@
250250
exp.ArrayLast,
251251
}
252252
},
253-
**{
254-
expr_type: {"returns": exp.DataType.Type.UNKNOWN}
255-
for expr_type in {
256-
exp.Anonymous,
257-
exp.Slice,
258-
}
259-
},
253+
exp.Anonymous: {"annotator": lambda self, e: self._annotate_anonymous(e)},
254+
exp.Slice: {"returns": exp.DataType.Type.UNKNOWN},
260255
**{
261256
expr_type: {"annotator": lambda self, e: self._annotate_timeunit(e)}
262257
for expr_type in {

tests/test_optimizer.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,43 @@ def test_unknown_annotation(self):
14201420
exp.DataType.Type.UNKNOWN,
14211421
)
14221422

1423+
def test_udf_annotation(self):
    """Anonymous function calls are annotated from the schema's `udf_mapping`."""
    # Unqualified UDF
    unqualified = annotate_types(
        parse_one("SELECT my_func(col) FROM t"),
        schema=MappingSchema(
            schema={"t": {"col": "INT"}},
            udf_mapping={"my_func": "VARCHAR"},
        ),
    )
    self.assertEqual(unqualified.selects[0].type.this, exp.DataType.Type.VARCHAR)

    # Qualified UDF (2-level)
    two_level = annotate_types(
        parse_one("SELECT db.my_func(col) FROM db.t"),
        schema=MappingSchema(
            schema={"db": {"t": {"col": "INT"}}},
            udf_mapping={"db": {"my_func": "DOUBLE"}},
        ),
    )
    two_level_call = two_level.selects[0].find(exp.Anonymous)
    self.assertEqual(two_level_call.type.this, exp.DataType.Type.DOUBLE)
    # Dot parent should also have the type
    self.assertEqual(two_level.selects[0].type.this, exp.DataType.Type.DOUBLE)

    # Qualified UDF (3-level)
    three_level = annotate_types(
        parse_one("SELECT cat.db.my_func(col) FROM cat.db.t"),
        schema=MappingSchema(
            schema={"cat": {"db": {"t": {"col": "INT"}}}},
            udf_mapping={"cat": {"db": {"my_func": "BOOLEAN"}}},
        ),
    )
    three_level_call = three_level.selects[0].find(exp.Anonymous)
    self.assertEqual(three_level_call.type.this, exp.DataType.Type.BOOLEAN)

    # Unknown UDF returns UNKNOWN
    unmapped = annotate_types(
        parse_one("SELECT unknown_func(col) FROM t"),
        schema=MappingSchema(
            schema={"t": {"col": "INT"}},
            udf_mapping={"known_func": "DATE"},
        ),
    )
    self.assertEqual(unmapped.selects[0].type.this, exp.DataType.Type.UNKNOWN)
1459+
14231460
def test_predicate_annotation(self):
14241461
expression = annotate_types(parse_one("x BETWEEN a AND b"))
14251462
self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)

0 commit comments

Comments
 (0)