Skip to content
Merged
2 changes: 1 addition & 1 deletion sqlglot-integration-tests
14 changes: 14 additions & 0 deletions sqlglot/generators/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2352,6 +2352,20 @@ def tobinary_sql(self, expression: exp.ToBinary) -> str:
result = self.func("TO_BINARY", value)
return f"TRY({result})" if is_safe else result

def tonumber_sql(self, expression: exp.ToNumber) -> str:
fmt = expression.args.get("format")
precision = expression.args.get("precision")
scale = expression.args.get("scale")

if not fmt and precision and scale:
return self.sql(
exp.cast(
expression.this, f"DECIMAL({precision.name}, {scale.name})", dialect="duckdb"
)
)

return super().tonumber_sql(expression)

def _greatest_least_sql(self, expression: exp.Greatest | exp.Least) -> str:
"""
Handle GREATEST/LEAST functions with dialect-aware NULL behavior.
Expand Down
27 changes: 17 additions & 10 deletions sqlglot/generators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,13 +529,6 @@ class SnowflakeGenerator(generator.Generator):
exp.ToFile: lambda self, e: self.func(
f"{'TRY_' if e.args.get('safe') else ''}TO_FILE", e.this, e.args.get("path")
),
exp.ToNumber: lambda self, e: self.func(
f"{'TRY_' if e.args.get('safe') else ''}TO_NUMBER",
e.this,
e.args.get("format"),
e.args.get("precision"),
e.args.get("scale"),
),
exp.JSONFormat: rename_func("TO_JSON"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess([transforms.add_within_group_for_percentiles]),
Expand Down Expand Up @@ -712,12 +705,26 @@ def datatype_sql(self, expression: exp.DataType) -> str:
return super().datatype_sql(expression)

def tonumber_sql(self, expression: exp.ToNumber) -> str:
precision = expression.args.get("precision")
scale = expression.args.get("scale")

default_precision = isinstance(precision, exp.Literal) and precision.name == "38"
default_scale = isinstance(scale, exp.Literal) and scale.name == "0"

if default_precision and default_scale:
precision = None
scale = None
elif default_scale:
scale = None

func_name = "TRY_TO_NUMBER" if expression.args.get("safe") else "TO_NUMBER"

return self.func(
"TO_NUMBER",
func_name,
expression.this,
expression.args.get("format"),
expression.args.get("precision"),
expression.args.get("scale"),
precision,
scale,
)

def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
Expand Down
40 changes: 23 additions & 17 deletions sqlglot/parsers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ def _build_approx_top_k(args: t.List) -> exp.ApproxTopK:
return exp.ApproxTopK.from_arg_list(args)


def _build_to_number(args: t.List, safe: bool = False) -> exp.ToNumber:
second_arg = seq_get(args, 1)
if second_arg and second_arg.is_number:
fmt = None
precision = second_arg
scale = seq_get(args, 2) or exp.Literal.number(0)
else:
fmt = second_arg
precision = seq_get(args, 2) or exp.Literal.number(38)
scale = seq_get(args, 3) or exp.Literal.number(0)

return exp.ToNumber(
this=seq_get(args, 0),
format=fmt,
precision=precision,
scale=scale,
safe=safe,
)


def _build_date_from_parts(args: t.List) -> exp.DateFromParts:
return exp.DateFromParts(
year=seq_get(args, 0),
Expand Down Expand Up @@ -295,16 +315,6 @@ def _build_generator(args: t.List) -> exp.Generator:
return exp.Generator(**gen_args)


def _build_try_to_number(args: t.List[exp.Expr]) -> exp.Expr:
return exp.ToNumber(
this=seq_get(args, 0),
format=seq_get(args, 1),
precision=seq_get(args, 2),
scale=seq_get(args, 3),
safe=True,
)


def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[SnowflakeParser], exp.Show]:
def _parse(self: SnowflakeParser) -> exp.Show:
return self._parse_show_snowflake(*args, **kwargs)
Expand Down Expand Up @@ -638,7 +648,8 @@ class SnowflakeParser(parser.Parser):
"TRY_TO_BOOLEAN": lambda args: exp.ToBoolean(this=seq_get(args, 0), safe=True),
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DType.DATE, safe=True),
**dict.fromkeys(
("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"), _build_try_to_number
("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"),
lambda args: _build_to_number(args, safe=True),
),
"TRY_TO_DOUBLE": lambda args: exp.ToDouble(
this=seq_get(args, 0), format=seq_get(args, 1), safe=True
Expand All @@ -661,12 +672,7 @@ class SnowflakeParser(parser.Parser):
"TO_DATE": _build_datetime("TO_DATE", exp.DType.DATE),
**dict.fromkeys(
("TO_DECIMAL", "TO_NUMBER", "TO_NUMERIC"),
lambda args: exp.ToNumber(
this=seq_get(args, 0),
format=seq_get(args, 1),
precision=seq_get(args, 2),
scale=seq_get(args, 3),
),
lambda args: _build_to_number(args),
),
"TO_TIME": _build_datetime("TO_TIME", exp.DType.TIME),
"TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DType.TIMESTAMP),
Expand Down
29 changes: 29 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,35 @@ def test_duckdb(self):
"snowflake": "SELECT IFF(_u.pos = _u_2.pos_2, _u_2.col, NULL) AS col FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1))) AS _u(seq, key, path, index, pos, this) CROSS JOIN TABLE(FLATTEN(INPUT => [1, 2, 3])) AS _u_2(seq, key, path, pos_2, col, this) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND _u_2.pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))",
},
)

self.validate_all(
"SELECT CAST('12.3456' AS DECIMAL(38, 0))",
read={
"snowflake": "SELECT TO_NUMBER('12.3456')",
},
write={
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(38, 0))",
},
)
self.validate_all(
"SELECT CAST('12.3456' AS DECIMAL(10, 0))",
read={
"snowflake": "SELECT TO_NUMBER('12.3456', 10)",
},
write={
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 0))",
},
)
self.validate_all(
"SELECT CAST('12.3456' AS DECIMAL(10, 2))",
read={
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 2)",
},
write={
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 2))",
},
)

self.validate_all(
"VAR_POP(x)",
read={
Expand Down
77 changes: 71 additions & 6 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,39 @@ def test_snowflake(self):
self.validate_identity(
"TO_DECIMAL(expr, fmt, precision, scale)", "TO_NUMBER(expr, fmt, precision, scale)"
)
self.validate_identity("TO_NUMBER(expr)")
self.validate_identity("TO_NUMBER(expr, fmt)")
self.validate_identity("TO_NUMBER(expr, fmt, precision, scale)")
self.validate_identity("TO_NUMBER(expr, 38, 0)", "TO_NUMBER(expr)")
self.validate_identity("TO_NUMBER(expr, 38)", "TO_NUMBER(expr)")

ast = self.validate_identity("TO_NUMBER('12.3456')")
self.assertIsInstance(ast, exp.ToNumber)
self.assertIsNone(ast.args.get("format"))
self.assertEqual(ast.args.get("precision").name, "38")
self.assertEqual(ast.args.get("scale").name, "0")

ast = self.validate_identity("TO_NUMBER('12.3456', 10, 1)")
self.assertIsInstance(ast, exp.ToNumber)
self.assertIsNone(ast.args.get("format"))
self.assertEqual(ast.args.get("precision").name, "10")
self.assertEqual(ast.args.get("scale").name, "1")

ast = self.validate_identity("TO_NUMBER('12.3456', '99.99')")
self.assertIsInstance(ast, exp.ToNumber)
self.assertEqual(ast.args.get("format").name, "99.99")
self.assertEqual(ast.args.get("precision").name, "38")
self.assertEqual(ast.args.get("scale").name, "0")

ast = self.validate_identity("TO_NUMBER('12.3456', '99.99', 10, 1)")
self.assertIsInstance(ast, exp.ToNumber)
self.assertEqual(ast.args.get("format").name, "99.99")
self.assertEqual(ast.args.get("precision").name, "10")
self.assertEqual(ast.args.get("scale").name, "1")

ast = self.validate_identity("TO_NUMBER('12.3456', 3)")
self.assertIsInstance(ast, exp.ToNumber)
self.assertIsNone(ast.args.get("format"))
self.assertEqual(ast.args.get("precision").name, "3")
self.assertEqual(ast.args.get("scale").name, "0")

self.validate_identity("TO_DECFLOAT('123.456')")
self.validate_identity("TO_DECFLOAT('1,234.56', '999,999.99')")
self.validate_identity("TRY_TO_DECFLOAT('123.456')")
Expand Down Expand Up @@ -542,9 +572,44 @@ def test_snowflake(self):
self.validate_identity("TRY_TO_FILE(object_col)")
self.validate_identity("TRY_TO_FILE('file.csv')")
self.validate_identity("TRY_TO_FILE('file.csv', 'relativepath/')")
self.validate_identity("TRY_TO_NUMBER('123.45')")
self.validate_identity("TRY_TO_NUMBER('123.45', '999.99')")
self.validate_identity("TRY_TO_NUMBER('123.45', '999.99', 10, 2)")
self.validate_identity("TRY_TO_NUMBER(expr, 38, 0)", "TRY_TO_NUMBER(expr)")
self.validate_identity("TRY_TO_NUMBER(expr, 38)", "TRY_TO_NUMBER(expr)")

ast = self.validate_identity("TRY_TO_NUMBER('12.3456')")
self.assertIsInstance(ast, exp.ToNumber)
self.assertIsNone(ast.args.get("format"))
self.assertEqual(ast.args.get("precision").name, "38")
self.assertEqual(ast.args.get("scale").name, "0")
self.assertTrue(ast.args.get("safe"))

ast = self.validate_identity("TRY_TO_NUMBER('12.3456', 10, 1)")
self.assertIsInstance(ast, exp.ToNumber)
self.assertIsNone(ast.args.get("format"))
self.assertEqual(ast.args.get("precision").name, "10")
self.assertEqual(ast.args.get("scale").name, "1")
self.assertTrue(ast.args.get("safe"))

ast = self.validate_identity("TRY_TO_NUMBER('12.3456', '99.99')")
self.assertIsInstance(ast, exp.ToNumber)
self.assertEqual(ast.args.get("format").name, "99.99")
self.assertEqual(ast.args.get("precision").name, "38")
self.assertEqual(ast.args.get("scale").name, "0")
self.assertTrue(ast.args.get("safe"))

ast = self.validate_identity("TRY_TO_NUMBER('12.3456', '99.99', 10, 1)")
self.assertIsInstance(ast, exp.ToNumber)
self.assertEqual(ast.args.get("format").name, "99.99")
self.assertEqual(ast.args.get("precision").name, "10")
self.assertEqual(ast.args.get("scale").name, "1")
self.assertTrue(ast.args.get("safe"))

ast = self.validate_identity("TRY_TO_NUMBER('12.3456', 3)")
self.assertIsInstance(ast, exp.ToNumber)
self.assertIsNone(ast.args.get("format"))
self.assertEqual(ast.args.get("precision").name, "3")
self.assertEqual(ast.args.get("scale").name, "0")
self.assertTrue(ast.args.get("safe"))

self.validate_identity("TO_NUMERIC('123.45')", "TO_NUMBER('123.45')")
self.validate_identity("TO_NUMERIC('123.45', '999.99')", "TO_NUMBER('123.45', '999.99')")
self.validate_identity(
Expand Down
Loading