Skip to content

Commit 61bb18c

Browse files
fivetran-ashashankargeooo109georgesittas
authored
feat(snowflake)!: Transpilation support for TO_DECIMAL, TO_NUMBER,NUMERIC (#7315)
* feat(snowflake)!: Transpilation support for TO_NUMBER transpilation. * feat(snowflake)!: Transpilation support for TO_NUMBER transpilation. * feat(snowflake)!: Transpilation support for TO_NUMBER transpilation. * refactor * ref 2 * ref 3 * sf tests * tests final * tests 4 * ref * Sync w/ integration tests Signed-off-by: George Sittas <giwrgos.sittas@gmail.com> --------- Signed-off-by: George Sittas <giwrgos.sittas@gmail.com> Co-authored-by: geooo109 <geomichas96@gmail.com> Co-authored-by: George Sittas <giwrgos.sittas@gmail.com>
1 parent 0b46246 commit 61bb18c

File tree

6 files changed

+155
-34
lines changed

6 files changed

+155
-34
lines changed

sqlglot-integration-tests

sqlglot/generators/duckdb.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2352,6 +2352,20 @@ def tobinary_sql(self, expression: exp.ToBinary) -> str:
23522352
result = self.func("TO_BINARY", value)
23532353
return f"TRY({result})" if is_safe else result
23542354

2355+
def tonumber_sql(self, expression: exp.ToNumber) -> str:
2356+
fmt = expression.args.get("format")
2357+
precision = expression.args.get("precision")
2358+
scale = expression.args.get("scale")
2359+
2360+
if not fmt and precision and scale:
2361+
return self.sql(
2362+
exp.cast(
2363+
expression.this, f"DECIMAL({precision.name}, {scale.name})", dialect="duckdb"
2364+
)
2365+
)
2366+
2367+
return super().tonumber_sql(expression)
2368+
23552369
def _greatest_least_sql(self, expression: exp.Greatest | exp.Least) -> str:
23562370
"""
23572371
Handle GREATEST/LEAST functions with dialect-aware NULL behavior.

sqlglot/generators/snowflake.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -529,13 +529,6 @@ class SnowflakeGenerator(generator.Generator):
529529
exp.ToFile: lambda self, e: self.func(
530530
f"{'TRY_' if e.args.get('safe') else ''}TO_FILE", e.this, e.args.get("path")
531531
),
532-
exp.ToNumber: lambda self, e: self.func(
533-
f"{'TRY_' if e.args.get('safe') else ''}TO_NUMBER",
534-
e.this,
535-
e.args.get("format"),
536-
e.args.get("precision"),
537-
e.args.get("scale"),
538-
),
539532
exp.JSONFormat: rename_func("TO_JSON"),
540533
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
541534
exp.PercentileCont: transforms.preprocess([transforms.add_within_group_for_percentiles]),
@@ -712,12 +705,26 @@ def datatype_sql(self, expression: exp.DataType) -> str:
712705
return super().datatype_sql(expression)
713706

714707
def tonumber_sql(self, expression: exp.ToNumber) -> str:
708+
precision = expression.args.get("precision")
709+
scale = expression.args.get("scale")
710+
711+
default_precision = isinstance(precision, exp.Literal) and precision.name == "38"
712+
default_scale = isinstance(scale, exp.Literal) and scale.name == "0"
713+
714+
if default_precision and default_scale:
715+
precision = None
716+
scale = None
717+
elif default_scale:
718+
scale = None
719+
720+
func_name = "TRY_TO_NUMBER" if expression.args.get("safe") else "TO_NUMBER"
721+
715722
return self.func(
716-
"TO_NUMBER",
723+
func_name,
717724
expression.this,
718725
expression.args.get("format"),
719-
expression.args.get("precision"),
720-
expression.args.get("scale"),
726+
precision,
727+
scale,
721728
)
722729

723730
def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:

sqlglot/parsers/snowflake.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,26 @@ def _build_approx_top_k(args: t.List) -> exp.ApproxTopK:
3939
return exp.ApproxTopK.from_arg_list(args)
4040

4141

42+
def _build_to_number(args: t.List, safe: bool = False) -> exp.ToNumber:
43+
second_arg = seq_get(args, 1)
44+
if second_arg and second_arg.is_number:
45+
fmt = None
46+
precision = second_arg
47+
scale = seq_get(args, 2) or exp.Literal.number(0)
48+
else:
49+
fmt = second_arg
50+
precision = seq_get(args, 2) or exp.Literal.number(38)
51+
scale = seq_get(args, 3) or exp.Literal.number(0)
52+
53+
return exp.ToNumber(
54+
this=seq_get(args, 0),
55+
format=fmt,
56+
precision=precision,
57+
scale=scale,
58+
safe=safe,
59+
)
60+
61+
4262
def _build_date_from_parts(args: t.List) -> exp.DateFromParts:
4363
return exp.DateFromParts(
4464
year=seq_get(args, 0),
@@ -295,16 +315,6 @@ def _build_generator(args: t.List) -> exp.Generator:
295315
return exp.Generator(**gen_args)
296316

297317

298-
def _build_try_to_number(args: t.List[exp.Expr]) -> exp.Expr:
299-
return exp.ToNumber(
300-
this=seq_get(args, 0),
301-
format=seq_get(args, 1),
302-
precision=seq_get(args, 2),
303-
scale=seq_get(args, 3),
304-
safe=True,
305-
)
306-
307-
308318
def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[SnowflakeParser], exp.Show]:
309319
def _parse(self: SnowflakeParser) -> exp.Show:
310320
return self._parse_show_snowflake(*args, **kwargs)
@@ -638,7 +648,8 @@ class SnowflakeParser(parser.Parser):
638648
"TRY_TO_BOOLEAN": lambda args: exp.ToBoolean(this=seq_get(args, 0), safe=True),
639649
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DType.DATE, safe=True),
640650
**dict.fromkeys(
641-
("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"), _build_try_to_number
651+
("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"),
652+
lambda args: _build_to_number(args, safe=True),
642653
),
643654
"TRY_TO_DOUBLE": lambda args: exp.ToDouble(
644655
this=seq_get(args, 0), format=seq_get(args, 1), safe=True
@@ -661,12 +672,7 @@ class SnowflakeParser(parser.Parser):
661672
"TO_DATE": _build_datetime("TO_DATE", exp.DType.DATE),
662673
**dict.fromkeys(
663674
("TO_DECIMAL", "TO_NUMBER", "TO_NUMERIC"),
664-
lambda args: exp.ToNumber(
665-
this=seq_get(args, 0),
666-
format=seq_get(args, 1),
667-
precision=seq_get(args, 2),
668-
scale=seq_get(args, 3),
669-
),
675+
lambda args: _build_to_number(args),
670676
),
671677
"TO_TIME": _build_datetime("TO_TIME", exp.DType.TIME),
672678
"TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DType.TIMESTAMP),

tests/dialects/test_duckdb.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,35 @@ def test_duckdb(self):
775775
"snowflake": "SELECT IFF(_u.pos = _u_2.pos_2, _u_2.col, NULL) AS col FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1))) AS _u(seq, key, path, index, pos, this) CROSS JOIN TABLE(FLATTEN(INPUT => [1, 2, 3])) AS _u_2(seq, key, path, pos_2, col, this) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND _u_2.pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))",
776776
},
777777
)
778+
779+
self.validate_all(
780+
"SELECT CAST('12.3456' AS DECIMAL(38, 0))",
781+
read={
782+
"snowflake": "SELECT TO_NUMBER('12.3456')",
783+
},
784+
write={
785+
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(38, 0))",
786+
},
787+
)
788+
self.validate_all(
789+
"SELECT CAST('12.3456' AS DECIMAL(10, 0))",
790+
read={
791+
"snowflake": "SELECT TO_NUMBER('12.3456', 10)",
792+
},
793+
write={
794+
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 0))",
795+
},
796+
)
797+
self.validate_all(
798+
"SELECT CAST('12.3456' AS DECIMAL(10, 2))",
799+
read={
800+
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 2)",
801+
},
802+
write={
803+
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 2))",
804+
},
805+
)
806+
778807
self.validate_all(
779808
"VAR_POP(x)",
780809
read={

tests/dialects/test_snowflake.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -499,9 +499,39 @@ def test_snowflake(self):
499499
self.validate_identity(
500500
"TO_DECIMAL(expr, fmt, precision, scale)", "TO_NUMBER(expr, fmt, precision, scale)"
501501
)
502-
self.validate_identity("TO_NUMBER(expr)")
503-
self.validate_identity("TO_NUMBER(expr, fmt)")
504-
self.validate_identity("TO_NUMBER(expr, fmt, precision, scale)")
502+
self.validate_identity("TO_NUMBER(expr, 38, 0)", "TO_NUMBER(expr)")
503+
self.validate_identity("TO_NUMBER(expr, 38)", "TO_NUMBER(expr)")
504+
505+
ast = self.validate_identity("TO_NUMBER('12.3456')")
506+
self.assertIsInstance(ast, exp.ToNumber)
507+
self.assertIsNone(ast.args.get("format"))
508+
self.assertEqual(ast.args.get("precision").name, "38")
509+
self.assertEqual(ast.args.get("scale").name, "0")
510+
511+
ast = self.validate_identity("TO_NUMBER('12.3456', 10, 1)")
512+
self.assertIsInstance(ast, exp.ToNumber)
513+
self.assertIsNone(ast.args.get("format"))
514+
self.assertEqual(ast.args.get("precision").name, "10")
515+
self.assertEqual(ast.args.get("scale").name, "1")
516+
517+
ast = self.validate_identity("TO_NUMBER('12.3456', '99.99')")
518+
self.assertIsInstance(ast, exp.ToNumber)
519+
self.assertEqual(ast.args.get("format").name, "99.99")
520+
self.assertEqual(ast.args.get("precision").name, "38")
521+
self.assertEqual(ast.args.get("scale").name, "0")
522+
523+
ast = self.validate_identity("TO_NUMBER('12.3456', '99.99', 10, 1)")
524+
self.assertIsInstance(ast, exp.ToNumber)
525+
self.assertEqual(ast.args.get("format").name, "99.99")
526+
self.assertEqual(ast.args.get("precision").name, "10")
527+
self.assertEqual(ast.args.get("scale").name, "1")
528+
529+
ast = self.validate_identity("TO_NUMBER('12.3456', 3)")
530+
self.assertIsInstance(ast, exp.ToNumber)
531+
self.assertIsNone(ast.args.get("format"))
532+
self.assertEqual(ast.args.get("precision").name, "3")
533+
self.assertEqual(ast.args.get("scale").name, "0")
534+
505535
self.validate_identity("TO_DECFLOAT('123.456')")
506536
self.validate_identity("TO_DECFLOAT('1,234.56', '999,999.99')")
507537
self.validate_identity("TRY_TO_DECFLOAT('123.456')")
@@ -542,9 +572,44 @@ def test_snowflake(self):
542572
self.validate_identity("TRY_TO_FILE(object_col)")
543573
self.validate_identity("TRY_TO_FILE('file.csv')")
544574
self.validate_identity("TRY_TO_FILE('file.csv', 'relativepath/')")
545-
self.validate_identity("TRY_TO_NUMBER('123.45')")
546-
self.validate_identity("TRY_TO_NUMBER('123.45', '999.99')")
547-
self.validate_identity("TRY_TO_NUMBER('123.45', '999.99', 10, 2)")
575+
self.validate_identity("TRY_TO_NUMBER(expr, 38, 0)", "TRY_TO_NUMBER(expr)")
576+
self.validate_identity("TRY_TO_NUMBER(expr, 38)", "TRY_TO_NUMBER(expr)")
577+
578+
ast = self.validate_identity("TRY_TO_NUMBER('12.3456')")
579+
self.assertIsInstance(ast, exp.ToNumber)
580+
self.assertIsNone(ast.args.get("format"))
581+
self.assertEqual(ast.args.get("precision").name, "38")
582+
self.assertEqual(ast.args.get("scale").name, "0")
583+
self.assertTrue(ast.args.get("safe"))
584+
585+
ast = self.validate_identity("TRY_TO_NUMBER('12.3456', 10, 1)")
586+
self.assertIsInstance(ast, exp.ToNumber)
587+
self.assertIsNone(ast.args.get("format"))
588+
self.assertEqual(ast.args.get("precision").name, "10")
589+
self.assertEqual(ast.args.get("scale").name, "1")
590+
self.assertTrue(ast.args.get("safe"))
591+
592+
ast = self.validate_identity("TRY_TO_NUMBER('12.3456', '99.99')")
593+
self.assertIsInstance(ast, exp.ToNumber)
594+
self.assertEqual(ast.args.get("format").name, "99.99")
595+
self.assertEqual(ast.args.get("precision").name, "38")
596+
self.assertEqual(ast.args.get("scale").name, "0")
597+
self.assertTrue(ast.args.get("safe"))
598+
599+
ast = self.validate_identity("TRY_TO_NUMBER('12.3456', '99.99', 10, 1)")
600+
self.assertIsInstance(ast, exp.ToNumber)
601+
self.assertEqual(ast.args.get("format").name, "99.99")
602+
self.assertEqual(ast.args.get("precision").name, "10")
603+
self.assertEqual(ast.args.get("scale").name, "1")
604+
self.assertTrue(ast.args.get("safe"))
605+
606+
ast = self.validate_identity("TRY_TO_NUMBER('12.3456', 3)")
607+
self.assertIsInstance(ast, exp.ToNumber)
608+
self.assertIsNone(ast.args.get("format"))
609+
self.assertEqual(ast.args.get("precision").name, "3")
610+
self.assertEqual(ast.args.get("scale").name, "0")
611+
self.assertTrue(ast.args.get("safe"))
612+
548613
self.validate_identity("TO_NUMERIC('123.45')", "TO_NUMBER('123.45')")
549614
self.validate_identity("TO_NUMERIC('123.45', '999.99')", "TO_NUMBER('123.45', '999.99')")
550615
self.validate_identity(

0 commit comments

Comments
 (0)