diff --git a/sqlglot-integration-tests b/sqlglot-integration-tests index ead3ade9d4..ea97cd804a 160000 --- a/sqlglot-integration-tests +++ b/sqlglot-integration-tests @@ -1 +1 @@ -Subproject commit ead3ade9d4a8d212a044f950bf2e5e078043e4e7 +Subproject commit ea97cd804a409d5e28155d49d89ee69a8924e167 diff --git a/sqlglot/generators/duckdb.py b/sqlglot/generators/duckdb.py index a150bf192f..373238aab3 100644 --- a/sqlglot/generators/duckdb.py +++ b/sqlglot/generators/duckdb.py @@ -2352,6 +2352,20 @@ def tobinary_sql(self, expression: exp.ToBinary) -> str: result = self.func("TO_BINARY", value) return f"TRY({result})" if is_safe else result + def tonumber_sql(self, expression: exp.ToNumber) -> str: + fmt = expression.args.get("format") + precision = expression.args.get("precision") + scale = expression.args.get("scale") + + if not fmt and precision and scale: + return self.sql( + exp.cast( + expression.this, f"DECIMAL({precision.name}, {scale.name})", dialect="duckdb" + ) + ) + + return super().tonumber_sql(expression) + def _greatest_least_sql(self, expression: exp.Greatest | exp.Least) -> str: """ Handle GREATEST/LEAST functions with dialect-aware NULL behavior. diff --git a/sqlglot/generators/snowflake.py b/sqlglot/generators/snowflake.py index 2eaa635706..38a1c7c180 100644 --- a/sqlglot/generators/snowflake.py +++ b/sqlglot/generators/snowflake.py @@ -529,13 +529,6 @@ class SnowflakeGenerator(generator.Generator): exp.ToFile: lambda self, e: self.func( f"{'TRY_' if e.args.get('safe') else ''}TO_FILE", e.this, e.args.get("path") ), - exp.ToNumber: lambda self, e: self.func( - f"{'TRY_' if e.args.get('safe') else ''}TO_NUMBER", - e.this, - e.args.get("format"), - e.args.get("precision"), - e.args.get("scale"), - ), exp.JSONFormat: rename_func("TO_JSON"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.PercentileCont: transforms.preprocess([transforms.add_within_group_for_percentiles]), @@ -712,12 +705,26 @@ def datatype_sql(self, expression: exp.DataType) -> str: return super().datatype_sql(expression) def tonumber_sql(self, expression: exp.ToNumber) -> str: + precision = expression.args.get("precision") + scale = expression.args.get("scale") + + default_precision = isinstance(precision, exp.Literal) and precision.name == "38" + default_scale = isinstance(scale, exp.Literal) and scale.name == "0" + + if default_precision and default_scale: + precision = None + scale = None + elif default_scale: + scale = None + + func_name = "TRY_TO_NUMBER" if expression.args.get("safe") else "TO_NUMBER" + return self.func( - "TO_NUMBER", + func_name, expression.this, expression.args.get("format"), - expression.args.get("precision"), - expression.args.get("scale"), + precision, + scale, ) def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: diff --git a/sqlglot/parsers/snowflake.py b/sqlglot/parsers/snowflake.py index 10e614c046..e011bc9c21 100644 --- a/sqlglot/parsers/snowflake.py +++ b/sqlglot/parsers/snowflake.py @@ -39,6 +39,26 @@ def _build_approx_top_k(args: t.List) -> exp.ApproxTopK: return exp.ApproxTopK.from_arg_list(args) +def _build_to_number(args: t.List, safe: bool = False) -> exp.ToNumber: + second_arg = seq_get(args, 1) + if second_arg and second_arg.is_number: + fmt = None + precision = second_arg + scale = seq_get(args, 2) or exp.Literal.number(0) + else: + fmt = second_arg + precision = seq_get(args, 2) or exp.Literal.number(38) + scale = seq_get(args, 3) or exp.Literal.number(0) + + return exp.ToNumber( + this=seq_get(args, 0), + format=fmt, + precision=precision, + scale=scale, + safe=safe, + ) + + def _build_date_from_parts(args: t.List) -> exp.DateFromParts: return exp.DateFromParts( year=seq_get(args, 0), @@ -295,16 +315,6 @@ def _build_generator(args: t.List) -> exp.Generator: return exp.Generator(**gen_args) -def _build_try_to_number(args: t.List[exp.Expr]) -> exp.Expr: - return exp.ToNumber( - this=seq_get(args, 0), - format=seq_get(args, 1), - precision=seq_get(args, 2), - scale=seq_get(args, 3), - safe=True, - ) - - def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[SnowflakeParser], exp.Show]: def _parse(self: SnowflakeParser) -> exp.Show: return self._parse_show_snowflake(*args, **kwargs) @@ -638,7 +648,8 @@ class SnowflakeParser(parser.Parser): "TRY_TO_BOOLEAN": lambda args: exp.ToBoolean(this=seq_get(args, 0), safe=True), "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DType.DATE, safe=True), **dict.fromkeys( - ("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"), _build_try_to_number + ("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"), + lambda args: _build_to_number(args, safe=True), ), "TRY_TO_DOUBLE": lambda args: exp.ToDouble( this=seq_get(args, 0), format=seq_get(args, 1), safe=True @@ -661,12 +672,7 @@ class SnowflakeParser(parser.Parser): "TO_DATE": _build_datetime("TO_DATE", exp.DType.DATE), **dict.fromkeys( ("TO_DECIMAL", "TO_NUMBER", "TO_NUMERIC"), - lambda args: exp.ToNumber( - this=seq_get(args, 0), - format=seq_get(args, 1), - precision=seq_get(args, 2), - scale=seq_get(args, 3), - ), + lambda args: _build_to_number(args), ), "TO_TIME": _build_datetime("TO_TIME", exp.DType.TIME), "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DType.TIMESTAMP), diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 35aaf2949e..4f16a54f5a 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -775,6 +775,35 @@ def test_duckdb(self): "snowflake": "SELECT IFF(_u.pos = _u_2.pos_2, _u_2.col, NULL) AS col FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1))) AS _u(seq, key, path, index, pos, this) CROSS JOIN TABLE(FLATTEN(INPUT => [1, 2, 3])) AS _u_2(seq, key, path, pos_2, col, this) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND _u_2.pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))", }, ) + + self.validate_all( + "SELECT CAST('12.3456' AS DECIMAL(38, 0))", + read={ + "snowflake": "SELECT TO_NUMBER('12.3456')", + }, + write={ + "duckdb": "SELECT CAST('12.3456' AS DECIMAL(38, 0))", + }, + ) + self.validate_all( + "SELECT CAST('12.3456' AS DECIMAL(10, 0))", + read={ + "snowflake": "SELECT TO_NUMBER('12.3456', 10)", + }, + write={ + "duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 0))", + }, + ) + self.validate_all( + "SELECT CAST('12.3456' AS DECIMAL(10, 2))", + read={ + "snowflake": "SELECT TO_NUMBER('12.3456', 10, 2)", + }, + write={ + "duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 2))", + }, + ) + self.validate_all( "VAR_POP(x)", read={ diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 797a90c30c..f72b92ccb0 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -499,9 +499,39 @@ def test_snowflake(self): self.validate_identity( "TO_DECIMAL(expr, fmt, precision, scale)", "TO_NUMBER(expr, fmt, precision, scale)" ) - self.validate_identity("TO_NUMBER(expr)") - self.validate_identity("TO_NUMBER(expr, fmt)") - self.validate_identity("TO_NUMBER(expr, fmt, precision, scale)") + self.validate_identity("TO_NUMBER(expr, 38, 0)", "TO_NUMBER(expr)") + self.validate_identity("TO_NUMBER(expr, 38)", "TO_NUMBER(expr)") + + ast = self.validate_identity("TO_NUMBER('12.3456')") + self.assertIsInstance(ast, exp.ToNumber) + self.assertIsNone(ast.args.get("format")) + self.assertEqual(ast.args.get("precision").name, "38") + self.assertEqual(ast.args.get("scale").name, "0") + + ast = self.validate_identity("TO_NUMBER('12.3456', 10, 1)") + self.assertIsInstance(ast, exp.ToNumber) + self.assertIsNone(ast.args.get("format")) + self.assertEqual(ast.args.get("precision").name, "10") + self.assertEqual(ast.args.get("scale").name, "1") + + ast = self.validate_identity("TO_NUMBER('12.3456', '99.99')") + self.assertIsInstance(ast, exp.ToNumber) + self.assertEqual(ast.args.get("format").name, "99.99") + self.assertEqual(ast.args.get("precision").name, "38") + self.assertEqual(ast.args.get("scale").name, "0") + + ast = self.validate_identity("TO_NUMBER('12.3456', '99.99', 10, 1)") + self.assertIsInstance(ast, exp.ToNumber) + self.assertEqual(ast.args.get("format").name, "99.99") + self.assertEqual(ast.args.get("precision").name, "10") + self.assertEqual(ast.args.get("scale").name, "1") + + ast = self.validate_identity("TO_NUMBER('12.3456', 3)") + self.assertIsInstance(ast, exp.ToNumber) + self.assertIsNone(ast.args.get("format")) + self.assertEqual(ast.args.get("precision").name, "3") + self.assertEqual(ast.args.get("scale").name, "0") + self.validate_identity("TO_DECFLOAT('123.456')") self.validate_identity("TO_DECFLOAT('1,234.56', '999,999.99')") self.validate_identity("TRY_TO_DECFLOAT('123.456')") @@ -542,9 +572,44 @@ def test_snowflake(self): self.validate_identity("TRY_TO_FILE(object_col)") self.validate_identity("TRY_TO_FILE('file.csv')") self.validate_identity("TRY_TO_FILE('file.csv', 'relativepath/')") - self.validate_identity("TRY_TO_NUMBER('123.45')") - self.validate_identity("TRY_TO_NUMBER('123.45', '999.99')") - self.validate_identity("TRY_TO_NUMBER('123.45', '999.99', 10, 2)") + self.validate_identity("TRY_TO_NUMBER(expr, 38, 0)", "TRY_TO_NUMBER(expr)") + self.validate_identity("TRY_TO_NUMBER(expr, 38)", "TRY_TO_NUMBER(expr)") + + ast = self.validate_identity("TRY_TO_NUMBER('12.3456')") + self.assertIsInstance(ast, exp.ToNumber) + self.assertIsNone(ast.args.get("format")) + self.assertEqual(ast.args.get("precision").name, "38") + self.assertEqual(ast.args.get("scale").name, "0") + self.assertTrue(ast.args.get("safe")) + + ast = self.validate_identity("TRY_TO_NUMBER('12.3456', 10, 1)") + self.assertIsInstance(ast, exp.ToNumber) + self.assertIsNone(ast.args.get("format")) + self.assertEqual(ast.args.get("precision").name, "10") + self.assertEqual(ast.args.get("scale").name, "1") + self.assertTrue(ast.args.get("safe")) + + ast = self.validate_identity("TRY_TO_NUMBER('12.3456', '99.99')") + self.assertIsInstance(ast, exp.ToNumber) + self.assertEqual(ast.args.get("format").name, "99.99") + self.assertEqual(ast.args.get("precision").name, "38") + self.assertEqual(ast.args.get("scale").name, "0") + self.assertTrue(ast.args.get("safe")) + + ast = self.validate_identity("TRY_TO_NUMBER('12.3456', '99.99', 10, 1)") + self.assertIsInstance(ast, exp.ToNumber) + self.assertEqual(ast.args.get("format").name, "99.99") + self.assertEqual(ast.args.get("precision").name, "10") + self.assertEqual(ast.args.get("scale").name, "1") + self.assertTrue(ast.args.get("safe")) + + ast = self.validate_identity("TRY_TO_NUMBER('12.3456', 3)") + self.assertIsInstance(ast, exp.ToNumber) + self.assertIsNone(ast.args.get("format")) + self.assertEqual(ast.args.get("precision").name, "3") + self.assertEqual(ast.args.get("scale").name, "0") + self.assertTrue(ast.args.get("safe")) + self.validate_identity("TO_NUMERIC('123.45')", "TO_NUMBER('123.45')") self.validate_identity("TO_NUMERIC('123.45', '999.99')", "TO_NUMBER('123.45', '999.99')") self.validate_identity(