Skip to content

Commit dffaf52

Browse files
feat(snowflake)!: Transpilation support for TO_DECIMAL, TO_NUMBER, TO_NUMERIC
1 parent 78c1d46 commit dffaf52

3 files changed

Lines changed: 163 additions & 6 deletions

File tree

sqlglot/dialects/duckdb.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4181,6 +4181,126 @@ def strtok_sql(self, expression: exp.Strtok) -> str:
41814181

41824182
return self.function_fallback_sql(expression)
41834183

4184+
def tonumber_sql(self, expression: exp.ToNumber) -> str:
4185+
# TO_NUMBER(expr) -> CAST(expr AS BIGINT)
4186+
# TO_NUMBER(expr, precision, scale) -> CAST(expr AS DECIMAL(precision, scale))
4187+
# TO_NUMBER(expr, format, precision, scale) -> CAST(REGEXP_REPLACE(expr, ...) AS DECIMAL(precision, scale))
4188+
4189+
this = expression.this
4190+
format_arg = expression.args.get("format")
4191+
precision_arg = expression.args.get("precision")
4192+
scale_arg = expression.args.get("scale")
4193+
4194+
# Determine if format_arg is actually a format string or precision
4195+
actual_precision = None
4196+
actual_scale = None
4197+
format_string = None
4198+
4199+
if format_arg:
4200+
if format_arg.is_string:
4201+
# It's a format string
4202+
format_string = format_arg.this # Get the string value
4203+
actual_precision = precision_arg
4204+
actual_scale = scale_arg
4205+
else:
4206+
# It's numeric, so it's actually the precision
4207+
actual_precision = format_arg
4208+
actual_scale = precision_arg
4209+
4210+
# Process format string to build preprocessing expression
4211+
if format_string:
4212+
# Determine what characters need to be stripped based on format
4213+
chars_to_remove = []
4214+
4215+
# Check for common format patterns
4216+
format_lower = format_string.lower()
4217+
4218+
# Comma separator (9,999.99 or 999,999)
4219+
if "," in format_string:
4220+
chars_to_remove.append(",")
4221+
4222+
# Currency symbols ($, £, €, etc.)
4223+
if "$" in format_string:
4224+
chars_to_remove.append("$")
4225+
if "£" in format_string:
4226+
chars_to_remove.append("£")
4227+
if "€" in format_string:
4228+
chars_to_remove.append("€")
4229+
if "¥" in format_string:
4230+
chars_to_remove.append("¥")
4231+
4232+
# Hexadecimal format (XXX or xxx)
4233+
if format_lower in ("xxx", "xxxx"):
4234+
# Hexadecimal conversion: ('0x' || input)::UBIGINT::BIGINT
4235+
hex_expr = exp.Cast(
4236+
this=exp.Cast(
4237+
this=exp.Concat(expressions=[exp.Literal.string("0x"), this]),
4238+
to=exp.DataType(this=exp.DataType.Type.UBIGINT),
4239+
),
4240+
to=exp.DataType(this=exp.DataType.Type.BIGINT),
4241+
)
4242+
return self.sql(hex_expr)
4243+
4244+
# Build REGEXP_REPLACE to strip formatting characters
4245+
if chars_to_remove:
4246+
# Build pattern: [$,] for multiple characters
4247+
if len(chars_to_remove) == 1:
4248+
pattern = chars_to_remove[0]
4249+
else:
4250+
# Escape special regex characters
4251+
escaped = [
4252+
c if c not in ".^$*+?{}[]\\|()" else f"\\{c}" for c in chars_to_remove
4253+
]
4254+
pattern = "[" + "".join(escaped) + "]"
4255+
4256+
# REGEXP_REPLACE(input, pattern, '', 'g')
4257+
this = exp.RegexpReplace(
4258+
this=this,
4259+
expression=exp.Literal.string(pattern),
4260+
replacement=exp.Literal.string(""),
4261+
modifiers=exp.Literal.string("g"),
4262+
)
4263+
4264+
# Determine target type
4265+
if actual_precision is None:
4266+
# No precision/scale -> BIGINT (matches NUMBER(38, 0))
4267+
target_type = exp.DataType(this=exp.DataType.Type.BIGINT)
4268+
else:
4269+
# Get precision value
4270+
if isinstance(actual_precision, exp.Literal):
4271+
prec_val = int(actual_precision.to_py())
4272+
else:
4273+
# Dynamic precision not supported
4274+
self.unsupported(
4275+
"TO_NUMBER with non-literal precision is not supported. "
4276+
"Using DOUBLE instead of DECIMAL."
4277+
)
4278+
return self.sql(exp.cast(this, exp.DataType.Type.DOUBLE))
4279+
4280+
# Get scale value (default to 0)
4281+
if actual_scale is not None:
4282+
if isinstance(actual_scale, exp.Literal):
4283+
scale_val = int(actual_scale.to_py())
4284+
else:
4285+
self.unsupported(
4286+
"TO_NUMBER with non-literal scale is not supported. "
4287+
"Using DOUBLE instead of DECIMAL."
4288+
)
4289+
return self.sql(exp.cast(this, exp.DataType.Type.DOUBLE))
4290+
else:
4291+
scale_val = 0
4292+
4293+
# Build DECIMAL(precision, scale) type
4294+
target_type = exp.DataType(
4295+
this=exp.DataType.Type.DECIMAL,
4296+
expressions=[
4297+
exp.DataTypeParam(this=exp.Literal.number(prec_val)),
4298+
exp.DataTypeParam(this=exp.Literal.number(scale_val)),
4299+
],
4300+
)
4301+
4302+
return self.sql(exp.Cast(this=this, to=target_type))
4303+
41844304
def approxquantile_sql(self, expression: exp.ApproxQuantile) -> str:
41854305
result = self.func("APPROX_QUANTILE", expression.this, expression.args.get("quantile"))
41864306

sqlglot/transforms.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -894,9 +894,9 @@ def eliminate_join_marks(expression: exp.Expr) -> exp.Expr:
894894
if not left_join_table:
895895
continue
896896

897-
assert not (len(left_join_table) > 1), (
898-
"Cannot combine JOIN predicates from different tables"
899-
)
897+
assert not (
898+
len(left_join_table) > 1
899+
), "Cannot combine JOIN predicates from different tables"
900900

901901
for col in join_cols:
902902
col.set("join_mark", False)
@@ -927,9 +927,9 @@ def eliminate_join_marks(expression: exp.Expr) -> exp.Expr:
927927

928928
if query_from.alias_or_name in new_joins:
929929
only_old_joins = old_joins.keys() - new_joins.keys()
930-
assert len(only_old_joins) >= 1, (
931-
"Cannot determine which table to use in the new FROM clause"
932-
)
930+
assert (
931+
len(only_old_joins) >= 1
932+
), "Cannot determine which table to use in the new FROM clause"
933933

934934
new_from_name = list(only_old_joins)[0]
935935
query.set("from_", exp.From(this=old_joins[new_from_name].this))

tests/dialects/test_duckdb.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2704,3 +2704,40 @@ def test_map_insert(self):
27042704
"snowflake": "SELECT TO_VARIANT('1')",
27052705
},
27062706
)
2707+
2708+
# TO_NUMBER / TO_DECIMAL / TO_NUMERIC transpilation tests
2709+
self.validate_all(
2710+
"SELECT TO_NUMBER('12.3456')",
2711+
write={
2712+
"duckdb": "SELECT CAST('12.3456' AS BIGINT)",
2713+
"snowflake": "SELECT TO_NUMBER('12.3456')",
2714+
},
2715+
)
2716+
self.validate_all(
2717+
"SELECT TO_NUMBER('12.3456', 10, 1)",
2718+
write={
2719+
"duckdb": "SELECT CAST('12.3456' AS DECIMAL(10, 1))",
2720+
"snowflake": "SELECT TO_NUMBER('12.3456', 10, 1)",
2721+
},
2722+
)
2723+
self.validate_all(
2724+
"SELECT TO_DECIMAL('3,741.72', '9,999.99', 6, 2)",
2725+
write={
2726+
"duckdb": "SELECT CAST(REGEXP_REPLACE('3,741.72', ',', '', 'g') AS DECIMAL(6, 2))",
2727+
"snowflake": "SELECT TO_DECIMAL('3,741.72', '9,999.99', 6, 2)",
2728+
},
2729+
)
2730+
self.validate_all(
2731+
"SELECT TO_DECIMAL('$3,741.72', '$9,999.99', 6, 2)",
2732+
write={
2733+
"duckdb": "SELECT CAST(REGEXP_REPLACE('$3,741.72', '[,\\$]', '', 'g') AS DECIMAL(6, 2))",
2734+
"snowflake": "SELECT TO_DECIMAL('$3,741.72', '$9,999.99', 6, 2)",
2735+
},
2736+
)
2737+
self.validate_all(
2738+
"SELECT TO_DECIMAL('ae5', 'XXX')",
2739+
write={
2740+
"duckdb": "SELECT CAST(CAST('0x' || 'ae5' AS UBIGINT) AS BIGINT)",
2741+
"snowflake": "SELECT TO_DECIMAL('ae5', 'XXX')",
2742+
},
2743+
)

0 commit comments

Comments
 (0)