Skip to content

Commit ff27d17

Browse files
committed
refactor
1 parent c7c8002 commit ff27d17

4 files changed

Lines changed: 30 additions & 74 deletions

File tree

sqlglot/generators/duckdb.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2352,26 +2352,13 @@ def tobinary_sql(self, expression: exp.ToBinary) -> str:
23522352

23532353
@unsupported_args("format")
23542354
def tonumber_sql(self, expression: exp.ToNumber) -> str:
2355-
"""
2356-
Snowflake's TO_NUMBER without precision/scale defaults to NUMBER(38, 0),
2357-
which truncates decimals. The parser sets these defaults at parse time.
2358-
Always cast to DECIMAL(precision, scale) using the values from the AST.
2359-
2360-
Oracle's TO_NUMBER without precision/scale should convert to DOUBLE.
2361-
"""
23622355
precision = expression.args.get("precision")
23632356
scale = expression.args.get("scale")
23642357

2365-
# Build DECIMAL type with precision and scale from AST
23662358
if precision and scale:
2367-
# Snowflake parser ensures defaults (38, 0) are set when not specified
23682359
decimal_type = exp.DataType.build(f"DECIMAL({precision.name}, {scale.name})")
2369-
elif precision is None and scale is None:
2370-
# Oracle or other dialects that don't set defaults - convert to DOUBLE
2371-
decimal_type = exp.DataType.build("DOUBLE")
23722360
else:
2373-
# Fallback for partial specification
2374-
decimal_type = exp.DataType.build("DECIMAL(38, 0)")
2361+
decimal_type = exp.DataType.build("DOUBLE")
23752362

23762363
return self.sql(exp.cast(expression.this, decimal_type))
23772364

sqlglot/generators/snowflake.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -741,39 +741,6 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) ->
741741
if expression.is_type(exp.DType.GEOMETRY):
742742
return self.func("TO_GEOMETRY", expression.this)
743743

744-
# Convert CAST to DECIMAL/NUMERIC to TO_NUMBER only for string inputs
745-
# Don't convert TryCast - it's handled by trycast_sql
746-
if expression.is_type(exp.DType.DECIMAL) and not isinstance(expression, exp.TryCast):
747-
value = expression.this
748-
749-
# Annotate types if not already done
750-
if value.type is None:
751-
from sqlglot.optimizer.annotate_types import annotate_types
752-
753-
value = annotate_types(value, dialect=self.dialect)
754-
755-
# Only convert to TO_NUMBER for string inputs
756-
if value.is_string or value.is_type(*exp.DataType.TEXT_TYPES):
757-
# Extract precision and scale from DECIMAL(p, s)
758-
params = expression.to.expressions or []
759-
precision = (
760-
params[0].this
761-
if len(params) >= 1 and isinstance(params[0], exp.DataTypeParam)
762-
else None
763-
)
764-
scale = (
765-
params[1].this
766-
if len(params) >= 2 and isinstance(params[1], exp.DataTypeParam)
767-
else None
768-
)
769-
770-
to_number = exp.ToNumber(
771-
this=value,
772-
precision=precision,
773-
scale=scale,
774-
)
775-
return self.tonumber_sql(to_number)
776-
777744
return super().cast_sql(expression, safe_prefix=safe_prefix)
778745

779746
def trycast_sql(self, expression: exp.TryCast) -> str:

sqlglot/parsers/snowflake.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,26 @@ def _build_approx_top_k(args: t.List) -> exp.ApproxTopK:
3939
return exp.ApproxTopK.from_arg_list(args)
4040

4141

42+
def _build_to_number(args: t.List, safe: bool = False) -> exp.ToNumber:
43+
second_arg = seq_get(args, 1)
44+
if second_arg and second_arg.is_number:
45+
fmt = None
46+
precision = second_arg
47+
scale = seq_get(args, 2) or exp.Literal.number(0)
48+
else:
49+
fmt = second_arg
50+
precision = seq_get(args, 2) or exp.Literal.number(38)
51+
scale = seq_get(args, 3) or exp.Literal.number(0)
52+
53+
return exp.ToNumber(
54+
this=seq_get(args, 0),
55+
format=fmt,
56+
precision=precision,
57+
scale=scale,
58+
safe=safe,
59+
)
60+
61+
4262
def _build_date_from_parts(args: t.List) -> exp.DateFromParts:
4363
return exp.DateFromParts(
4464
year=seq_get(args, 0),
@@ -623,15 +643,7 @@ class SnowflakeParser(parser.Parser):
623643
"TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DType.DATE, safe=True),
624644
**dict.fromkeys(
625645
("TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"),
626-
lambda args: exp.ToNumber(
627-
this=seq_get(args, 0),
628-
format=seq_get(args, 1) if len(args) in (2, 4) else None,
629-
precision=(seq_get(args, 2) if len(args) in (2, 4) else seq_get(args, 1))
630-
or exp.Literal.number(38),
631-
scale=(seq_get(args, 3) if len(args) in (2, 4) else seq_get(args, 2))
632-
or exp.Literal.number(0),
633-
safe=True,
634-
),
646+
lambda args: _build_to_number(args, safe=True),
635647
),
636648
"TRY_TO_DOUBLE": lambda args: exp.ToDouble(
637649
this=seq_get(args, 0), format=seq_get(args, 1), safe=True
@@ -654,14 +666,7 @@ class SnowflakeParser(parser.Parser):
654666
"TO_DATE": _build_datetime("TO_DATE", exp.DType.DATE),
655667
**dict.fromkeys(
656668
("TO_DECIMAL", "TO_NUMBER", "TO_NUMERIC"),
657-
lambda args: exp.ToNumber(
658-
this=seq_get(args, 0),
659-
format=seq_get(args, 1) if len(args) in (2, 4) else None,
660-
precision=(seq_get(args, 2) if len(args) in (2, 4) else seq_get(args, 1))
661-
or exp.Literal.number(38),
662-
scale=(seq_get(args, 3) if len(args) in (2, 4) else seq_get(args, 2))
663-
or exp.Literal.number(0),
664-
),
669+
lambda args: _build_to_number(args),
665670
),
666671
"TO_TIME": _build_datetime("TO_TIME", exp.DType.TIME),
667672
"TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DType.TIMESTAMP),

tests/dialects/test_duckdb.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -762,37 +762,34 @@ def test_duckdb(self):
762762
},
763763
)
764764

765-
# TO_NUMBER transpilation from Snowflake to DuckDB
766765
self.validate_all(
767766
"SELECT CAST('12.3456' AS DECIMAL(38, 0))",
768767
read={
769768
"snowflake": "SELECT TO_NUMBER('12.3456')",
770769
},
771770
write={
772771
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(38, 0))",
773-
"snowflake": "SELECT TO_NUMBER('12.3456')",
774772
},
775773
)
776774
self.validate_all(
777-
"SELECT CAST('12.3456' AS DECIMAL(10, 1))",
775+
"SELECT CAST('12.3456' AS DECIMAL(10, 0))",
778776
read={
779-
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 1)",
777+
"snowflake": "SELECT TO_NUMBER('12.3456', 10)",
780778
},
781779
write={
782-
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 1))",
783-
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 1)",
780+
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 0))",
784781
},
785782
)
786783
self.validate_all(
787-
"SELECT CAST('3,741.72' AS DECIMAL(6, 2))",
784+
"SELECT CAST('12.3456' AS DECIMAL(10, 2))",
788785
read={
789-
"snowflake": "SELECT TO_DECIMAL('3,741.72', '9,999.99', 6, 2)",
786+
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 2)",
790787
},
791788
write={
792-
"duckdb": "SELECT CAST('3,741.72' AS DECIMAL(6, 2))",
793-
"snowflake": "SELECT TO_NUMBER('3,741.72', 6, 2)", # Format is lost during transpilation
789+
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 2))",
794790
},
795791
)
792+
796793
self.validate_all(
797794
"VAR_POP(x)",
798795
read={

0 commit comments

Comments
 (0)