diff --git a/sqlglot/expressions/core.py b/sqlglot/expressions/core.py index c2cbee6d4f..99bf642e27 100644 --- a/sqlglot/expressions/core.py +++ b/sqlglot/expressions/core.py @@ -10,12 +10,14 @@ import sys import textwrap import typing as t +from builtins import type as Type from collections import deque +from collections.abc import Collection, Iterator, Mapping, MutableMapping, Sequence from copy import deepcopy from decimal import Decimal from functools import reduce -from collections.abc import Iterator, Sequence, Collection, Mapping, MutableMapping -from sqlglot._typing import E, T + +from sqlglot._typing import E, GeneratorNoDialectArgs, ParserNoDialectArgs, T from sqlglot.errors import ParseError from sqlglot.helper import ( camel_to_snake_case, @@ -24,17 +26,15 @@ to_bool, trait, ) - from sqlglot.tokenizer_core import Token -from builtins import type as Type -from sqlglot._typing import GeneratorNoDialectArgs, ParserNoDialectArgs if t.TYPE_CHECKING: - from typing_extensions import Self, Unpack, Concatenate + from typing_extensions import Concatenate, Self, Unpack + + from sqlglot._typing import P from sqlglot.dialects.dialect import DialectType from sqlglot.expressions.datatypes import DATA_TYPE, DataType, DType, Interval from sqlglot.expressions.query import Select - from sqlglot._typing import P R = t.TypeVar("R") @@ -1808,8 +1808,10 @@ def output_name(self) -> str: # https://docs.snowflake.com/en/sql-reference/identifier-literal +# "expressions" holds the arguments when the resolved identifier is invoked as a +# function, e.g. `IDENTIFIER('my_func')(1, 2)` class DynamicIdentifier(Expression, Func): - pass + arg_types = {"this": True, "expressions": False} class Opclass(Expression): @@ -2425,7 +2427,8 @@ def convert(value: t.Any, copy: bool = False) -> Expr: return _Array(expressions=[convert(v, copy=copy) for v in value]) if isinstance(value, dict): - from sqlglot.expressions.array import Array as _Array, Map as _Map + from sqlglot.expressions.array import Array as _Array + from sqlglot.expressions.array import Map as _Map return _Map( keys=_Array(expressions=[convert(k, copy=copy) for k in value]), diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 023304e55a..af1f42130c 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1960,7 +1960,14 @@ def index_sql(self, expression: exp.Index) -> str: def dynamicidentifier_sql(self, expression: exp.DynamicIdentifier) -> str: this = expression.this if this and this.is_string: - return maybe_parse(this.name).sql(self.dialect) + resolved = maybe_parse(this.name).sql(self.dialect) + if "expressions" in expression.args: + # `IDENTIFIER(...)` invoked as a function, e.g. `IDENTIFIER('my_func')(1, 2)` + # We can't safely emit the call to other dialects since name/arg semantics may differ + self.unsupported( + "Transpiling dynamically-invoked IDENTIFIER() functions is unsupported" + ) + return resolved self.unsupported("IDENTIFIER() with non-literal arguments is not supported") return self.func("IDENTIFIER", this) diff --git a/sqlglot/generators/snowflake.py b/sqlglot/generators/snowflake.py index 401188f4e1..dd8522094c 100644 --- a/sqlglot/generators/snowflake.py +++ b/sqlglot/generators/snowflake.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +from collections import defaultdict from sqlglot import exp, generator, transforms from sqlglot.dialects.dialect import ( @@ -34,7 +35,6 @@ build_object_construct, ) from sqlglot.tokens import TokenType -from collections import defaultdict if t.TYPE_CHECKING: from sqlglot._typing import E @@ -475,7 +475,6 @@ class SnowflakeGenerator(generator.Generator): exp.DayOfWeekIso: rename_func("DAYOFWEEKISO"), exp.DayOfYear: rename_func("DAYOFYEAR"), exp.DotProduct: rename_func("VECTOR_INNER_PRODUCT"), - exp.DynamicIdentifier: rename_func("IDENTIFIER"), exp.Explode: rename_func("FLATTEN"), exp.Extract: lambda self, e: self.func( "DATE_PART", map_date_part(e.this, self.dialect), e.expression @@ -617,6 +616,13 @@ class SnowflakeGenerator(generator.Generator): ), } + def dynamicidentifier_sql(self, expression: exp.DynamicIdentifier) -> str: + this = self.func("IDENTIFIER", expression.this) + if "expressions" in expression.args: + # `IDENTIFIER(...)` invoked as a function, e.g. `IDENTIFIER('my_func')(1, 2)` + return self.func(this, *expression.expressions, normalize=False) + return this + def sortarray_sql(self, expression: exp.SortArray) -> str: asc = expression.args.get("asc") nulls_first = expression.args.get("nulls_first") diff --git a/sqlglot/parsers/snowflake.py b/sqlglot/parsers/snowflake.py index 99dbb27595..be57cf67de 100644 --- a/sqlglot/parsers/snowflake.py +++ b/sqlglot/parsers/snowflake.py @@ -1156,6 +1156,31 @@ def _parse_table( return table + def _parse_function_call( + self, + functions: dict[str, t.Callable] | None = None, + anonymous: bool = False, + optional_parens: bool = True, + any_token: bool = False, + ) -> exp.Expr | None: + this = super()._parse_function_call( + functions=functions, + anonymous=anonymous, + optional_parens=optional_parens, + any_token=any_token, + ) + + # Snowflake can invoke a function whose name is dynamically resolved, e.g. + # `IDENTIFIER('my_func')(1, 2)`. The trailing argument list is the call's arguments. + # + # https://docs.snowflake.com/en/sql-reference/identifier-literal + if isinstance(this, exp.DynamicIdentifier) and self._match( + TokenType.L_PAREN, advance=False + ): + this.set("expressions", self._parse_wrapped_csv(self._parse_lambda)) + + return this + def _parse_id_var( self, any_token: bool = True, diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 35ba376744..8439b282f7 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -769,6 +769,15 @@ def test_snowflake(self): self.validate_identity("SELECT KURTOSIS(x) OVER (PARTITION BY 1)") self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')") self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT IDENTIFIER('foo') FROM x") + self.validate_identity("SELECT IDENTIFIER($my_function_name)()") + self.validate_identity("SELECT IDENTIFIER('speed_of_light')()") + self.validate_all( + "SELECT IDENTIFIER('my_func')(1, 2)", + write={ + "snowflake": "SELECT IDENTIFIER('my_func')(1, 2)", + "duckdb": UnsupportedError, + }, + ) self.validate_identity("INITCAP('iqamqinterestedqinqthisqtopic', 'q')") self.validate_identity("OBJECT_CONSTRUCT(*)") self.validate_identity("SELECT CAST('2021-01-01' AS DATE) + INTERVAL '1 DAY'")