Skip to content

Commit 36211c2

Browse files
feat(snowflake)!: transpile BASE64_DECODE_STRING/BINARY to DuckDB (#6837)
* transpilation of base64_decode* * fixed tests, refactored replacement
1 parent 6ebe5cc commit 36211c2

2 files changed

Lines changed: 86 additions & 26 deletions

File tree

sqlglot/dialects/duckdb.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,57 @@
9797
_SEQ_RESTRICTED = (exp.Where, exp.Having, exp.AggFunc, exp.Order, exp.Select)
9898

9999

100+
def _apply_base64_alphabet_replacements(
101+
result: exp.Expression,
102+
alphabet: t.Optional[exp.Expression],
103+
reverse: bool = False,
104+
) -> exp.Expression:
105+
"""
106+
Apply base64 alphabet character replacements.
107+
108+
Base64 alphabet can be 1-3 chars: 1st = index 62 ('+'), 2nd = index 63 ('/'), 3rd = padding ('=').
109+
zip truncates to the shorter string, so 1-char alphabet only replaces '+', 2-char replaces '+/', etc.
110+
111+
Args:
112+
result: The expression to apply replacements to
113+
alphabet: Custom alphabet literal (expected chars for +/=)
114+
reverse: If False, replace default with custom (encode)
115+
If True, replace custom with default (decode)
116+
"""
117+
if isinstance(alphabet, exp.Literal) and alphabet.is_string:
118+
for default_char, new_char in zip("+/=", alphabet.this):
119+
if new_char != default_char:
120+
find, replace = (new_char, default_char) if reverse else (default_char, new_char)
121+
result = exp.Replace(
122+
this=result,
123+
expression=exp.Literal.string(find),
124+
replacement=exp.Literal.string(replace),
125+
)
126+
return result
127+
128+
129+
def _base64_decode_sql(self: DuckDB.Generator, expression: exp.Expression, to_string: bool) -> str:
130+
"""
131+
Transpile Snowflake BASE64_DECODE_STRING/BINARY to DuckDB.
132+
133+
DuckDB uses FROM_BASE64() which returns BLOB. For string output, wrap with DECODE().
134+
Custom alphabets require REPLACE() calls to convert to standard base64.
135+
"""
136+
input_expr = expression.this
137+
alphabet = expression.args.get("alphabet")
138+
139+
# Handle custom alphabet by replacing non-standard chars with standard ones
140+
input_expr = _apply_base64_alphabet_replacements(input_expr, alphabet, reverse=True)
141+
142+
# FROM_BASE64 returns BLOB
143+
input_expr = exp.FromBase64(this=input_expr)
144+
145+
if to_string:
146+
input_expr = exp.Decode(this=input_expr)
147+
148+
return self.sql(input_expr)
149+
150+
100151
def _last_day_sql(self: DuckDB.Generator, expression: exp.LastDay) -> str:
101152
"""
102153
DuckDB's LAST_DAY only supports finding the last day of a month.
@@ -1536,6 +1587,8 @@ class Generator(generator.Generator):
15361587
exp.ArrayUniqueAgg: lambda self, e: self.func(
15371588
"LIST", exp.Distinct(expressions=[e.this])
15381589
),
1590+
exp.Base64DecodeBinary: lambda self, e: _base64_decode_sql(self, e, to_string=False),
1591+
exp.Base64DecodeString: lambda self, e: _base64_decode_sql(self, e, to_string=True),
15391592
exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"),
15401593
exp.BitwiseAndAgg: _bitwise_agg_sql,
15411594
exp.BitwiseLeftShift: _bitshift_sql,
@@ -2742,29 +2795,19 @@ def base64encode_sql(self, expression: exp.Base64Encode) -> str:
27422795
# DuckDB TO_BASE64 requires BLOB input
27432796
# Snowflake BASE64_ENCODE accepts both VARCHAR and BINARY - for VARCHAR it implicitly
27442797
# encodes UTF-8 bytes. We add ENCODE unless the input is a binary type.
2745-
input_expr = expression.this
2798+
result = expression.this
27462799

27472800
# Check if input is a string type - ENCODE only accepts VARCHAR
2748-
result = input_expr
2749-
if input_expr.is_type(*exp.DataType.TEXT_TYPES):
2750-
result = exp.Encode(this=input_expr)
2801+
if result.is_type(*exp.DataType.TEXT_TYPES):
2802+
result = exp.Encode(this=result)
27512803

27522804
result = exp.ToBase64(this=result)
27532805

27542806
max_line_length = expression.args.get("max_line_length")
27552807
alphabet = expression.args.get("alphabet")
27562808

2757-
# Handle custom alphabet by replacing characters (applied before line breaks)
2758-
# Alphabet can be 1-3 chars: 1st = index 62 ('+'), 2nd = index 63 ('/'), 3rd = padding ('=')
2759-
# zip truncates to the shorter string, so 1-char alphabet only replaces '+', 2-char replaces '+/'
2760-
if isinstance(alphabet, exp.Literal) and alphabet.is_string:
2761-
for default_char, new_char in zip("+/=", alphabet.this):
2762-
if new_char != default_char:
2763-
result = exp.Replace(
2764-
this=result,
2765-
expression=exp.Literal.string(default_char),
2766-
replacement=exp.Literal.string(new_char),
2767-
)
2809+
# Handle custom alphabet by replacing standard chars with custom ones
2810+
result = _apply_base64_alphabet_replacements(result, alphabet)
27682811

27692812
# Handle max_line_length by inserting newlines every N characters
27702813
line_length = (

tests/dialects/test_snowflake.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2562,12 +2562,6 @@ def test_snowflake(self):
25622562
"SELECT ILIKE(col, 'pattern', '!')", "SELECT col ILIKE 'pattern' ESCAPE '!'"
25632563
)
25642564

2565-
self.validate_identity("SELECT BASE64_DECODE_BINARY('SGVsbG8=')")
2566-
self.validate_identity("SELECT BASE64_DECODE_BINARY('SGVsbG8=')")
2567-
2568-
self.validate_identity("SELECT BASE64_DECODE_STRING('SGVsbG8gV29ybGQ=')")
2569-
self.validate_identity("SELECT BASE64_DECODE_STRING('SGVsbG8gV29ybGQ=', '+/=')")
2570-
25712565
expr = self.validate_identity("SELECT BASE64_ENCODE('Hello World')")
25722566
annotated = annotate_types(expr, dialect="snowflake")
25732567
self.assertEqual(annotated.sql("duckdb"), "SELECT TO_BASE64(ENCODE('Hello World'))")
@@ -2593,11 +2587,34 @@ def test_snowflake(self):
25932587
},
25942588
)
25952589

2596-
self.validate_identity("SELECT TRY_BASE64_DECODE_BINARY('SGVsbG8=')")
2597-
self.validate_identity("SELECT TRY_BASE64_DECODE_BINARY('SGVsbG8=', '+/=')")
2598-
2599-
self.validate_identity("SELECT TRY_BASE64_DECODE_STRING('SGVsbG8gV29ybGQ=')")
2600-
self.validate_identity("SELECT TRY_BASE64_DECODE_STRING('SGVsbG8gV29ybGQ=', '+/=')")
2590+
self.validate_all(
2591+
"SELECT BASE64_DECODE_STRING('U25vd2ZsYWtl')",
2592+
write={
2593+
"snowflake": "SELECT BASE64_DECODE_STRING('U25vd2ZsYWtl')",
2594+
"duckdb": "SELECT DECODE(FROM_BASE64('U25vd2ZsYWtl'))",
2595+
},
2596+
)
2597+
self.validate_all(
2598+
"SELECT BASE64_DECODE_STRING('U25vd2ZsYWtl', '-_+')",
2599+
write={
2600+
"snowflake": "SELECT BASE64_DECODE_STRING('U25vd2ZsYWtl', '-_+')",
2601+
"duckdb": "SELECT DECODE(FROM_BASE64(REPLACE(REPLACE(REPLACE('U25vd2ZsYWtl', '-', '+'), '_', '/'), '+', '=')))",
2602+
},
2603+
)
2604+
self.validate_all(
2605+
"SELECT BASE64_DECODE_BINARY(x)",
2606+
write={
2607+
"snowflake": "SELECT BASE64_DECODE_BINARY(x)",
2608+
"duckdb": "SELECT FROM_BASE64(x)",
2609+
},
2610+
)
2611+
self.validate_all(
2612+
"SELECT BASE64_DECODE_BINARY(x, '-_+')",
2613+
write={
2614+
"snowflake": "SELECT BASE64_DECODE_BINARY(x, '-_+')",
2615+
"duckdb": "SELECT FROM_BASE64(REPLACE(REPLACE(REPLACE(x, '-', '+'), '_', '/'), '+', '='))",
2616+
},
2617+
)
26012618

26022619
self.validate_identity("SELECT TRY_HEX_DECODE_BINARY('48656C6C6F')")
26032620

0 commit comments

Comments
 (0)