diff --git a/sqlglot/expressions/properties.py b/sqlglot/expressions/properties.py index 261c868165..9f8cc32054 100644 --- a/sqlglot/expressions/properties.py +++ b/sqlglot/expressions/properties.py @@ -5,8 +5,8 @@ import typing as t from enum import auto +from sqlglot.expressions.core import ColumnConstraintKind, Expression, Literal, convert from sqlglot.helper import AutoName -from sqlglot.expressions.core import Expression, ColumnConstraintKind, Literal, convert class Property(Expression): @@ -63,6 +63,10 @@ class BlockCompressionProperty(Property): } +class CalledOnNullInputProperty(Property): + arg_types = {} + + class CatalogProperty(Property): arg_types = {} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 6dd4294d34..023304e55a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -149,6 +149,7 @@ class Generator: exp.CaseSpecificColumnConstraint: lambda _, e: ( f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC" ), + exp.CalledOnNullInputProperty: lambda *_: "CALLED ON NULL INPUT", exp.Ceil: lambda self, e: self.ceil_floor(e), exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", exp.CharacterSetProperty: lambda self, e: ( @@ -677,6 +678,7 @@ class Generator: exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA, exp.BackupProperty: exp.Properties.Location.POST_SCHEMA, exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, + exp.CalledOnNullInputProperty: exp.Properties.Location.POST_SCHEMA, exp.CatalogProperty: exp.Properties.Location.POST_CREATE, exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA, exp.ChecksumProperty: exp.Properties.Location.POST_NAME, diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 57a645f9d6..c4de06bb53 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -4,7 +4,9 @@ import logging import re import typing as t +from builtins import type as Type from collections import defaultdict +from collections.abc import Sequence from sqlglot import exp from sqlglot.errors import ( @@ -17,20 +19,17 @@ ) from sqlglot.expressions import apply_index_offset from sqlglot.helper import ensure_list, i64, seq_get -from sqlglot.trie import new_trie from sqlglot.time import format_time from sqlglot.tokens import Token, Tokenizer, TokenType -from sqlglot.trie import TrieResult, in_trie -from collections.abc import Sequence -from builtins import type as Type +from sqlglot.trie import TrieResult, in_trie, new_trie if t.TYPE_CHECKING: - from sqlglot.expressions import ExpOrStr - from sqlglot._typing import E, BuilderArgs - from sqlglot.dialects.dialect import Dialect, DialectType - from re import Pattern + from sqlglot._typing import BuilderArgs, E + from sqlglot.dialects.dialect import Dialect, DialectType + from sqlglot.expressions import ExpOrStr + T = t.TypeVar("T") TCeilFloor = t.TypeVar("TCeilFloor", exp.Ceil, exp.Floor) @@ -1236,6 +1235,7 @@ class Parser: exp.BackupProperty(this=self._parse_var(any_token=True)) ), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), + "CALLED": lambda self: self._parse_called_on_null_input_property(), "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs), "CHARACTER SET": lambda self, **kwargs: self._parse_character_set(**kwargs), "CHECKSUM": lambda self: self._parse_checksum(), @@ -2900,6 +2900,13 @@ def _parse_settings_property(self) -> exp.SettingsProperty: exp.SettingsProperty(expressions=self._parse_csv(self._parse_assignment)) ) + def _parse_called_on_null_input_property(self) -> exp.CalledOnNullInputProperty | None: + if not self._match_text_seq("ON", "NULL", "INPUT"): + self._retreat(self._index - 1) + return None + + return self.expression(exp.CalledOnNullInputProperty()) + def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProperty: if self._index >= 2: pre_volatile_token = self._tokens[self._index - 2] diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 99077113ea..4f4f354845 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -1292,8 +1292,8 @@ def test_ddl(self): ) self.validate_identity( "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT", - check_command_warning=True, - ) + "CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT AS 'select $1 + $2;'", + ).assert_is(exp.Create) self.validate_identity( "CREATE CONSTRAINT TRIGGER my_trigger AFTER INSERT OR DELETE OR UPDATE OF col_a, col_b ON public.my_table DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION DO_STH()" ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 852896052f..0757d1a126 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -4322,7 +4322,7 @@ def test_ddl(self): ) self.validate_identity( """CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT RETURNS NULL ON NULL INPUT AS ' return Object.values(obj) '""" - ) + ).assert_is(exp.Create) self.validate_identity( """CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT STRICT AS ' return Object.values(obj) '""" ) @@ -4427,7 +4427,9 @@ def test_user_defined_functions(self): "snowflake": "CREATE FUNCTION a() RETURNS INT IMMUTABLE AS 'SELECT 1'", }, ) - + self.validate_identity( + "CREATE FUNCTION a(x DOUBLE) RETURNS DOUBLE LANGUAGE SQL CALLED ON NULL INPUT AS ' x * 2 '" + ).assert_is(exp.Create) self.validate_identity( "CREATE OR REPLACE FUNCTION repro_fn() RETURNS INT LANGUAGE PYTHON HANDLER = 'fn' RUNTIME_VERSION='3.11' PACKAGES=() AS '\\ndef fn():\\n return 1\\n'" ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 33fc0e7908..7144686303 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -632,6 +632,7 @@ CREATE FUNCTION a.b(x INT) RETURNS INT AS RETURN x + 1 CREATE FUNCTION a.b(x TEXT) RETURNS TEXT CONTAINS SQL AS RETURN x CREATE FUNCTION a.b(x TEXT) RETURNS TEXT LANGUAGE SQL MODIFIES SQL DATA AS RETURN x CREATE FUNCTION a.b(x TEXT) LANGUAGE SQL READS SQL DATA RETURNS TEXT AS RETURN x +CREATE FUNCTION a(x INT) RETURNS INT LANGUAGE SQL CALLED ON NULL INPUT AS 'SELECT 1' CREATE FUNCTION a.b.c() CREATE INDEX abc ON t(a) CREATE INDEX "abc" ON t(a)