diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index db861b0ed5..6f36e3e18f 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -722,6 +722,12 @@ def _annotate_div(self, expression: exp.Div) -> exp.Div:
 
     def _annotate_dot(self, expression: exp.Dot) -> exp.Dot:
         self._set_type(expression, None)
+
+        # Propagate type from qualified UDF calls (e.g., db.my_udf(...))
+        if isinstance(expression.expression, exp.Anonymous):
+            self._set_type(expression, expression.expression.type)
+            return expression
+
         this_type = expression.this.type
 
         if this_type and this_type.is_type(exp.DataType.Type.STRUCT):
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 73e21f44c6..fc22c38a90 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -111,6 +111,25 @@ def has_column(
         name = column if isinstance(column, str) else column.name
         return name in self.column_names(table, dialect=dialect, normalize=normalize)
 
+    def get_udf_type(
+        self,
+        udf: exp.Anonymous | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> exp.DataType:
+        """
+        Get the return type of a UDF.
+
+        Args:
+            udf: the UDF expression or string.
+            dialect: the SQL dialect for parsing string arguments.
+            normalize: whether to normalize identifiers.
+
+        Returns:
+            The return type as a DataType, or UNKNOWN if not found.
+        """
+        return exp.DataType.build("unknown")
+
     @property
     @abc.abstractmethod
     def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -128,11 +147,18 @@ class AbstractMappingSchema:
     def __init__(
         self,
         mapping: t.Optional[t.Dict] = None,
+        udf_mapping: t.Optional[t.Dict] = None,
     ) -> None:
         self.mapping = mapping or {}
         self.mapping_trie = new_trie(
             tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
         )
+
+        self.udf_mapping = udf_mapping or {}
+        self.udf_trie = new_trie(
+            tuple(reversed(t)) for t in flatten_schema(self.udf_mapping, depth=self.udf_depth())
+        )
+
         self._supported_table_args: t.Tuple[str, ...] = tuple()
 
     @property
@@ -142,6 +168,9 @@ def empty(self) -> bool:
     def depth(self) -> int:
         return dict_depth(self.mapping)
 
+    def udf_depth(self) -> int:
+        return dict_depth(self.udf_mapping)
+
     @property
     def supported_table_args(self) -> t.Tuple[str, ...]:
         if not self._supported_table_args and self.mapping:
@@ -157,7 +186,58 @@ def supported_table_args(self) -> t.Tuple[str, ...]:
         return self._supported_table_args
 
     def table_parts(self, table: exp.Table) -> t.List[str]:
-        return [part.name for part in reversed(table.parts)]
+        return [p.name for p in reversed(table.parts)]
+
+    def udf_parts(self, udf: exp.Anonymous) -> t.List[str]:
+        # a.b.c(...) is represented as Dot(Dot(a, b), Anonymous(c, ...))
+        parent = udf.parent
+
+        # The Dot chain only qualifies the UDF when the call is its right-hand operand;
+        # in f(...).field the Dot is a field access on the call's result instead
+        if isinstance(parent, exp.Dot) and parent.expression is udf:
+            parts = [p.name for p in parent.flatten()]
+        else:
+            parts = [udf.name]
+
+        return list(reversed(parts))[0 : self.udf_depth()]
+
+    def _find_in_trie(
+        self,
+        parts: t.List[str],
+        trie: t.Dict,
+        raise_on_missing: bool,
+    ) -> t.Optional[t.List[str]]:
+        """
+        Resolves a possibly partial key against a trie.
+
+        Args:
+            parts: the key parts, most specific first; extended in place on a unique match.
+            trie: the trie to search.
+            raise_on_missing: whether to raise if the key is an ambiguous prefix.
+
+        Returns:
+            The fully resolved parts, or None if the key is missing or ambiguous.
+        """
+        value, trie = in_trie(trie, parts)
+
+        if value == TrieResult.FAILED:
+            return None
+
+        if value == TrieResult.PREFIX:
+            possibilities = flatten_schema(trie)
+
+            if len(possibilities) == 1:
+                parts.extend(possibilities[0])
+            else:
+                if raise_on_missing:
+                    # parts are ordered name-first, so reverse them to display catalog.db.name
+                    joined_parts = ".".join(reversed(parts))
+                    message = ", ".join(".".join(p) for p in possibilities)
+                    raise SchemaError(f"Ambiguous mapping for {joined_parts}: {message}.")
+
+                return None
+
+        return parts
 
     def find(
         self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False
@@ -174,23 +254,35 @@ def find(
             The schema of the target table.
         """
         parts = self.table_parts(table)[0 : len(self.supported_table_args)]
-        value, trie = in_trie(self.mapping_trie, parts)
+        resolved_parts = self._find_in_trie(parts, self.mapping_trie, raise_on_missing)
 
-        if value == TrieResult.FAILED:
+        if resolved_parts is None:
             return None
 
-        if value == TrieResult.PREFIX:
-            possibilities = flatten_schema(trie)
+        return self.nested_get(resolved_parts, raise_on_missing=raise_on_missing)
 
-            if len(possibilities) == 1:
-                parts.extend(possibilities[0])
-            else:
-                message = ", ".join(".".join(parts) for parts in possibilities)
-                if raise_on_missing:
-                    raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
-                return None
+    def find_udf(self, udf: exp.Anonymous, raise_on_missing: bool = False) -> t.Optional[t.Any]:
+        """
+        Returns the return type of a given UDF.
+
+        Args:
+            udf: the target UDF expression.
+            raise_on_missing: whether to raise if the UDF is not found.
+
+        Returns:
+            The return type of the UDF, or None if not found.
+        """
+        parts = self.udf_parts(udf)
+        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing)
 
-        return self.nested_get(parts, raise_on_missing=raise_on_missing)
+        if resolved_parts is None:
+            return None
+
+        return nested_get(
+            self.udf_mapping,
+            *zip(resolved_parts, reversed(resolved_parts)),
+            raise_on_missing=raise_on_missing,
+        )
 
     def nested_get(
         self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
@@ -227,6 +319,7 @@ def __init__(
         visible: t.Optional[t.Dict] = None,
         dialect: DialectType = None,
         normalize: bool = True,
+        udf_mapping: t.Optional[t.Dict] = None,
     ) -> None:
         self.visible = {} if visible is None else visible
         self.normalize = normalize
@@ -234,8 +327,12 @@ def __init__(
         self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
         self._depth = 0
         schema = {} if schema is None else schema
+        udf_mapping = {} if udf_mapping is None else udf_mapping
 
-        super().__init__(self._normalize(schema) if self.normalize else schema)
+        super().__init__(
+            self._normalize(schema) if self.normalize else schema,
+            self._normalize_udfs(udf_mapping) if self.normalize else udf_mapping,
+        )
 
     @property
     def dialect(self) -> Dialect:
@@ -249,6 +346,7 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
             visible=mapping_schema.visible,
             dialect=mapping_schema.dialect,
             normalize=mapping_schema.normalize,
+            udf_mapping=mapping_schema.udf_mapping,
         )
 
     def find(
@@ -272,6 +370,7 @@ def copy(self, **kwargs) -> MappingSchema:
                 "visible": self.visible.copy(),
                 "dialect": self.dialect,
                 "normalize": self.normalize,
+                "udf_mapping": self.udf_mapping.copy(),
                 **kwargs,
             }
         )
@@ -360,6 +459,42 @@ def get_column_type(
 
         return exp.DataType.build("unknown")
 
+    def get_udf_type(
+        self,
+        udf: exp.Anonymous | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> exp.DataType:
+        """
+        Get the return type of a UDF.
+
+        Args:
+            udf: the UDF expression or string (e.g., "db.my_func()").
+            dialect: the SQL dialect for parsing string arguments.
+            normalize: whether to normalize identifiers.
+
+        Returns:
+            The return type as a DataType, or UNKNOWN if not found.
+        """
+        parts = self._normalize_udf(udf, dialect=dialect, normalize=normalize)
+        resolved_parts = self._find_in_trie(parts, self.udf_trie, raise_on_missing=False)
+
+        if resolved_parts is None:
+            return exp.DataType.build("unknown")
+
+        udf_type = nested_get(
+            self.udf_mapping,
+            *zip(resolved_parts, reversed(resolved_parts)),
+            raise_on_missing=False,
+        )
+
+        if isinstance(udf_type, exp.DataType):
+            return udf_type
+        elif isinstance(udf_type, str):
+            return self._to_data_type(udf_type, dialect=dialect)
+
+        return exp.DataType.build("unknown")
+
     def has_column(
         self,
         table: exp.Table | str,
@@ -414,6 +549,61 @@ def _normalize(self, schema: t.Dict) -> t.Dict:
 
         return normalized_mapping
 
+    def _normalize_udfs(self, udfs: t.Dict) -> t.Dict:
+        """
+        Normalizes all identifiers in the UDF mapping.
+
+        Args:
+            udfs: the UDF mapping to normalize.
+
+        Returns:
+            The normalized UDF mapping.
+        """
+        normalized_mapping: t.Dict = {}
+
+        for keys in flatten_schema(udfs, depth=dict_depth(udfs)):
+            udf_type = nested_get(udfs, *zip(keys, keys))
+            normalized_keys = [self._normalize_name(key, is_table=True) for key in keys]
+            nested_set(normalized_mapping, normalized_keys, udf_type)
+
+        return normalized_mapping
+
+    def _normalize_udf(
+        self,
+        udf: exp.Anonymous | str,
+        dialect: DialectType = None,
+        normalize: t.Optional[bool] = None,
+    ) -> t.List[str]:
+        """
+        Extract and normalize UDF parts for lookup.
+
+        Args:
+            udf: the UDF expression or qualified string (e.g., "db.my_func()").
+            dialect: the SQL dialect for parsing.
+            normalize: whether to normalize identifiers.
+
+        Returns:
+            A list of normalized UDF parts (reversed for trie lookup).
+        """
+        dialect = dialect or self.dialect
+        normalize = self.normalize if normalize is None else normalize
+
+        if isinstance(udf, str):
+            parsed: exp.Expression = exp.maybe_parse(udf, dialect=dialect)
+
+            if isinstance(parsed, exp.Anonymous):
+                udf = parsed
+            elif isinstance(parsed, exp.Dot) and isinstance(parsed.expression, exp.Anonymous):
+                udf = parsed.expression
+            else:
+                raise SchemaError(f"Unable to parse UDF from: {udf!r}")
+        parts = self.udf_parts(udf)
+
+        if normalize:
+            parts = [self._normalize_name(part, dialect=dialect, is_table=True) for part in parts]
+
+        return parts
+
     def _normalize_table(
         self,
         table: exp.Table | str,
diff --git a/sqlglot/typing/__init__.py b/sqlglot/typing/__init__.py
index 832049f707..a44c254132 100644
--- a/sqlglot/typing/__init__.py
+++ b/sqlglot/typing/__init__.py
@@ -250,13 +250,8 @@
             exp.ArrayLast,
         }
     },
-    **{
-        expr_type: {"returns": exp.DataType.Type.UNKNOWN}
-        for expr_type in {
-            exp.Anonymous,
-            exp.Slice,
-        }
-    },
+    exp.Anonymous: {"annotator": lambda self, e: self._set_type(e, self.schema.get_udf_type(e))},
+    exp.Slice: {"returns": exp.DataType.Type.UNKNOWN},
     **{
         expr_type: {"annotator": lambda self, e: self._annotate_timeunit(e)}
         for expr_type in {
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 1d74117d23..34ef9c1546 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1420,6 +1420,64 @@ def test_unknown_annotation(self):
             exp.DataType.Type.UNKNOWN,
         )
 
+    def test_udf_annotation(self):
+        # Unqualified UDF
+        schema = MappingSchema(
+            schema={"t": {"col": "INT"}},
+            udf_mapping={"my_func": "VARCHAR"},
+        )
+        expr = annotate_types(parse_one("SELECT my_func(col) FROM t"), schema=schema)
+        self.assertEqual(expr.selects[0].type.this, exp.DataType.Type.VARCHAR)
+
+        # Qualified UDF (2-level)
+        schema = MappingSchema(
+            schema={"db": {"t": {"col": "INT"}}},
+            udf_mapping={"db": {"my_func": "DOUBLE"}},
+        )
+        expr = annotate_types(parse_one("SELECT db.my_func(col) FROM db.t"), schema=schema)
+        anon = expr.selects[0].find(exp.Anonymous)
+        self.assertEqual(anon.type.this, exp.DataType.Type.DOUBLE)
+        # Dot parent should also have the type
+        self.assertEqual(expr.selects[0].type.this, exp.DataType.Type.DOUBLE)
+
+        # Qualified UDF (3-level)
+        schema = MappingSchema(
+            schema={"cat": {"db": {"t": {"col": "INT"}}}},
+            udf_mapping={"cat": {"db": {"my_func": "BOOLEAN"}}},
+        )
+        expr = annotate_types(parse_one("SELECT cat.db.my_func(col) FROM cat.db.t"), schema=schema)
+        anon = expr.selects[0].find(exp.Anonymous)
+        self.assertEqual(anon.type.this, exp.DataType.Type.BOOLEAN)
+
+        # Unknown UDF returns UNKNOWN
+        schema = MappingSchema(
+            schema={"t": {"col": "INT"}},
+            udf_mapping={"known_func": "DATE"},
+        )
+        expr = annotate_types(parse_one("SELECT unknown_func(col) FROM t"), schema=schema)
+        self.assertEqual(expr.selects[0].type.this, exp.DataType.Type.UNKNOWN)
+
+        # Test get_udf_type with string input
+        schema = MappingSchema(udf_mapping={"my_func": "INT"})
+        self.assertEqual(schema.get_udf_type("my_func(x)").this, exp.DataType.Type.INT)
+
+        schema = MappingSchema(udf_mapping={"db": {"my_func": "FLOAT"}})
+        self.assertEqual(schema.get_udf_type("db.my_func(x, y)").this, exp.DataType.Type.FLOAT)
+
+        schema = MappingSchema(udf_mapping={"cat": {"db": {"my_func": "DATE"}}})
+        self.assertEqual(
+            schema.get_udf_type("cat.db.my_func(a, b, c)").this, exp.DataType.Type.DATE
+        )
+
+        # Unknown UDF string returns UNKNOWN
+        schema = MappingSchema(udf_mapping={"known": "INT"})
+        self.assertEqual(schema.get_udf_type("unknown(x)").this, exp.DataType.Type.UNKNOWN)
+
+        # A field access on a UDF's result must not be treated as a qualifier
+        schema = MappingSchema(udf_mapping={"my_func": "INT"})
+        dot = exp.Dot(this=exp.Anonymous(this="my_func"), expression=exp.to_identifier("field"))
+        self.assertEqual(schema.get_udf_type(dot.this).this, exp.DataType.Type.INT)
+
     def test_predicate_annotation(self):
         expression = annotate_types(parse_one("x BETWEEN a AND b"))
         self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)