Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,12 @@ def _annotate_div(self, expression: exp.Div) -> exp.Div:

def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
self._set_type(expression, None)

# Propagate type from qualified UDF calls (e.g., db.my_udf(...))
if isinstance(expression.expression, exp.Anonymous):
self._set_type(expression, expression.expression.type)
return expression

this_type = expression.this.type

if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
Expand Down
199 changes: 185 additions & 14 deletions sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,25 @@ def has_column(
name = column if isinstance(column, str) else column.name
return name in self.column_names(table, dialect=dialect, normalize=normalize)

def get_udf_type(
self,
udf: exp.Anonymous | str,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.DataType:
"""
Get the return type of a UDF.

Args:
udf: the UDF expression or string.
dialect: the SQL dialect for parsing string arguments.
normalize: whether to normalize identifiers.

Returns:
The return type as a DataType, or UNKNOWN if not found.
"""
return exp.DataType.build("unknown")

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly done to preserve backwards-compatibility. We could also make it abstract and label this as breaking.

cc @barakalon


@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
Expand All @@ -128,11 +147,18 @@ class AbstractMappingSchema:
def __init__(
self,
mapping: t.Optional[t.Dict] = None,
udf_mapping: t.Optional[t.Dict] = None,
) -> None:
self.mapping = mapping or {}
self.mapping_trie = new_trie(
tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
)

self.udf_mapping = udf_mapping or {}
self.udf_trie = new_trie(
tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
)

self._supported_table_args: t.Tuple[str, ...] = tuple()

@property
Expand All @@ -142,6 +168,9 @@ def empty(self) -> bool:
def depth(self) -> int:
return dict_depth(self.mapping)

def udf_depth(self) -> int:
return dict_depth(self.udf_mapping)

@property
def supported_table_args(self) -> t.Tuple[str, ...]:
if not self._supported_table_args and self.mapping:
Expand All @@ -157,7 +186,39 @@ def supported_table_args(self) -> t.Tuple[str, ...]:
return self._supported_table_args

def table_parts(self, table: exp.Table) -> t.List[str]:
return [part.name for part in reversed(table.parts)]
return [p.name for p in reversed(table.parts)]

def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
# a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
parent = udf.parent
parts = [p.name for p in parent.flatten()] if isinstance(parent, exp.Dot) else [udf.name]
return list(reversed(parts))[0 : self.udf_depth()]

def _find_in_trie(
self,
parts: t.List[str],
trie: t.Dict,
raise_on_missing: bool,
) -> t.Optional[t.List[str]]:
value, trie = in_trie(trie, parts)

if value == TrieResult.FAILED:
return None

if value == TrieResult.PREFIX:
possibilities = flatten_schema(trie)

if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
if raise_on_missing:
joined_parts = ".".join(parts)
message = ", ".join(".".join(p) for p in possibilities)
raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")

return None

return parts

def find(
self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
Expand All @@ -174,23 +235,35 @@ def find(
The schema of the target table.
"""
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie, parts)
resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)

if value == TrieResult.FAILED:
if resolved_parts is None:
return None

if value == TrieResult.PREFIX:
possibilities = flatten_schema(trie)
return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)

if len(possibilities) == 1:
parts.extend(possibilities[0])
else:
message = ", ".join(".".join(parts) for parts in possibilities)
if raise_on_missing:
raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
return None
def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
"""
Returns the return type of a given UDF.

Args:
udf: the target UDF expression.
raise_on_missing: whether to raise if the UDF is not found.

Returns:
The return type of the UDF, or None if not found.
"""
parts = self.udf_parts(udf)
resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)

return self.nested_get(parts, raise_on_missing=raise_on_missing)
if resolved_parts is None:
return None

return nested_get(
self.udf_mapping,
*zip(resolved_parts, reversed(resolved_parts)),
raise_on_missing=raise_on_missing,
)

def nested_get(
self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
Expand Down Expand Up @@ -227,15 +300,20 @@ def __init__(
visible: t.Optional[t.Dict] = None,
dialect: DialectType = None,
normalize: bool = True,
udf_mapping: t.Optional[t.Dict] = None,
) -> None:
self.visible = {} if visible is None else visible
self.normalize = normalize
self._dialect = Dialect.get_or_raise(dialect)
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
self._depth = 0
schema = {} if schema is None else schema
udf_mapping = {} if udf_mapping is None else udf_mapping

super().__init__(self._normalize(schema) if self.normalize else schema)
super().__init__(
self._normalize(schema) if self.normalize else schema,
self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
)

@property
def dialect(self) -> Dialect:
Expand All @@ -249,6 +327,7 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
normalize=mapping_schema.normalize,
udf_mapping=mapping_schema.udf_mapping,
)

def find(
Expand All @@ -272,6 +351,7 @@ def copy(self, **kwargs) -> MappingSchema:
"visible": self.visible.copy(),
"dialect": self.dialect,
"normalize": self.normalize,
"udf_mapping": self.udf_mapping.copy(),
**kwargs,
}
)
Expand Down Expand Up @@ -360,6 +440,42 @@ def get_column_type(

return exp.DataType.build("unknown")

def get_udf_type(
self,
udf: exp.Anonymous | str,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> exp.DataType:
"""
Get the return type of a UDF.

Args:
udf: the UDF expression or string (e.g., "db.my_func()").
dialect: the SQL dialect for parsing string arguments.
normalize: whether to normalize identifiers.

Returns:
The return type as a DataType, or UNKNOWN if not found.
"""
parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)

if resolved_parts is None:
return exp.DataType.build("unknown")

udf_type = nested_get(
self.udf_mapping,
*zip(resolved_parts, reversed(resolved_parts)),
raise_on_missing=False,
)

if isinstance(udf_type, exp.DataType):
return udf_type
elif isinstance(udf_type, str):
return self._to_data_type(udf_type, dialect=dialect)

return exp.DataType.build("unknown")

def has_column(
self,
table: exp.Table | str,
Expand Down Expand Up @@ -414,6 +530,61 @@ def _normalize(self, schema: t.Dict) -> t.Dict:

return normalized_mapping

def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
"""
Normalizes all identifiers in the UDF mapping.

Args:
udfs: the UDF mapping to normalize.

Returns:
The normalized UDF mapping.
"""
normalized_mapping: t.Dict = {}

for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
udf_type = nested_get(udfs, *zip(keys, keys))
normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
nested_set(normalized_mapping, normalized_keys, udf_type)

return normalized_mapping

def _normalize_udf(
self,
udf: exp.Anonymous | str,
dialect: DialectType = None,
normalize: t.Optional[bool] = None,
) -> t.List[str]:
"""
Extract and normalize UDF parts for lookup.

Args:
udf: the UDF expression or qualified string (e.g., "db.my_func()").
dialect: the SQL dialect for parsing.
normalize: whether to normalize identifiers.

Returns:
A list of normalized UDF parts (reversed for trie lookup).
"""
dialect = dialect or self.dialect
normalize = self.normalize if normalize is None else normalize

if isinstance(udf, str):
parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)

if isinstance(parsed, exp.Anonymous):
udf = parsed
elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
udf = parsed.expression
else:
raise SchemaError(f"Unable to parse UDF from: {udf!r}")
parts = self.udf_parts(udf)

if normalize:
parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]

return parts

def _normalize_table(
self,
table: exp.Table | str,
Expand Down
8 changes: 1 addition & 7 deletions sqlglot/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,7 @@
exp.ArrayLast,
}
},
**{
expr_type: {"returns": exp.DataType.Type.UNKNOWN}
for expr_type in {
exp.Anonymous,
exp.Slice,
}
},
exp.Anonymous: {"annotator": lambda self, e: self._set_type(e, self.schema.get_udf_type(e))},
**{
expr_type: {"annotator": lambda self, e: self._annotate_timeunit(e)}
for expr_type in {
Expand Down
53 changes: 53 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,59 @@ def test_unknown_annotation(self):
exp.DataType.Type.UNKNOWN,
)

def test_udf_annotation(self):
# Unqualified UDF
schema = MappingSchema(
schema={"t": {"col": "INT"}},
udf_mapping={"my_func": "VARCHAR"},
)
expr = annotate_types(parse_one("SELECT my_func(col) FROM t"), schema=schema)
self.assertEqual(expr.selects[0].type.this, exp.DataType.Type.VARCHAR)

# Qualified UDF (2-level)
schema = MappingSchema(
schema={"db": {"t": {"col": "INT"}}},
udf_mapping={"db": {"my_func": "DOUBLE"}},
)
expr = annotate_types(parse_one("SELECT db.my_func(col) FROM db.t"), schema=schema)
anon = expr.selects[0].find(exp.Anonymous)
self.assertEqual(anon.type.this, exp.DataType.Type.DOUBLE)
# Dot parent should also have the type
self.assertEqual(expr.selects[0].type.this, exp.DataType.Type.DOUBLE)

# Qualified UDF (3-level)
schema = MappingSchema(
schema={"cat": {"db": {"t": {"col": "INT"}}}},
udf_mapping={"cat": {"db": {"my_func": "BOOLEAN"}}},
)
expr = annotate_types(parse_one("SELECT cat.db.my_func(col) FROM cat.db.t"), schema=schema)
anon = expr.selects[0].find(exp.Anonymous)
self.assertEqual(anon.type.this, exp.DataType.Type.BOOLEAN)
Comment thread
georgesittas marked this conversation as resolved.

# Unknown UDF returns UNKNOWN
schema = MappingSchema(
schema={"t": {"col": "INT"}},
udf_mapping={"known_func": "DATE"},
)
expr = annotate_types(parse_one("SELECT unknown_func(col) FROM t"), schema=schema)
self.assertEqual(expr.selects[0].type.this, exp.DataType.Type.UNKNOWN)

# Test get_udf_type with string input
schema = MappingSchema(udf_mapping={"my_func": "INT"})
self.assertEqual(schema.get_udf_type("my_func(x)").this, exp.DataType.Type.INT)

schema = MappingSchema(udf_mapping={"db": {"my_func": "FLOAT"}})
self.assertEqual(schema.get_udf_type("db.my_func(x, y)").this, exp.DataType.Type.FLOAT)

schema = MappingSchema(udf_mapping={"cat": {"db": {"my_func": "DATE"}}})
self.assertEqual(
schema.get_udf_type("cat.db.my_func(a, b, c)").this, exp.DataType.Type.DATE
)

# Unknown UDF string returns UNKNOWN
schema = MappingSchema(udf_mapping={"known": "INT"})
self.assertEqual(schema.get_udf_type("unknown(x)").this, exp.DataType.Type.UNKNOWN)

def test_predicate_annotation(self):
expression = annotate_types(parse_one("x BETWEEN a AND b"))
self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
Expand Down