|
97 | 97 | _SEQ_RESTRICTED = (exp.Where, exp.Having, exp.AggFunc, exp.Order, exp.Select) |
98 | 98 |
|
99 | 99 |
|
def _apply_base64_alphabet_replacements(
    result: exp.Expression,
    alphabet: t.Optional[exp.Expression],
    reverse: bool = False,
) -> exp.Expression:
    """
    Wrap `result` in REPLACE expressions that map between the default and a custom base64 alphabet.

    The custom alphabet is a string literal of up to three characters, matched positionally
    against "+/=": 1st = index 62 ('+'), 2nd = index 63 ('/'), 3rd = padding ('=').
    Because zip stops at the shorter input, a 1-char alphabet only affects '+', a 2-char
    alphabet affects '+' and '/', and so on. Positions whose custom character equals the
    default one produce no replacement.

    NOTE(review): the replacements are nested one around the other, so each one operates on
    the output of the previous one. An alphabet that reuses a default character in a
    different position (e.g. swapping '+' and '/') would presumably collide - confirm
    whether such alphabets need to be supported.

    Args:
        result: The expression to apply replacements to
        alphabet: Custom alphabet literal (expected chars for +/=); anything that is not
            a string literal leaves `result` untouched
        reverse: If False, replace default with custom (encode)
                 If True, replace custom with default (decode)

    Returns:
        `result` itself when nothing needs replacing, otherwise the outermost Replace node.
    """
    # Only a string literal can be inspected character by character at transpile time
    if not (isinstance(alphabet, exp.Literal) and alphabet.is_string):
        return result

    rewritten = result
    for standard_char, custom_char in zip("+/=", alphabet.this):
        if standard_char == custom_char:
            continue

        # Encoding maps standard -> custom, decoding maps custom -> standard
        if reverse:
            source, target = custom_char, standard_char
        else:
            source, target = standard_char, custom_char

        rewritten = exp.Replace(
            this=rewritten,
            expression=exp.Literal.string(source),
            replacement=exp.Literal.string(target),
        )

    return rewritten
| 127 | + |
| 128 | + |
def _base64_decode_sql(self: DuckDB.Generator, expression: exp.Expression, to_string: bool) -> str:
    """
    Generate DuckDB SQL for Snowflake's BASE64_DECODE_STRING / BASE64_DECODE_BINARY.

    DuckDB decodes with FROM_BASE64(), which returns BLOB; when a string result is wanted
    the call is additionally wrapped in DECODE(). A custom alphabet is handled up front by
    REPLACE() calls that turn the non-standard characters back into standard base64 ones.

    Args:
        expression: The decode expression; its `this` is the encoded input and its optional
            `alphabet` arg is the custom alphabet
        to_string: If True, wrap the decoded BLOB in DECODE() to produce a string

    Returns:
        The generated SQL string.
    """
    # Map custom alphabet characters back to the standard ones before decoding
    normalized = _apply_base64_alphabet_replacements(
        expression.this, expression.args.get("alphabet"), reverse=True
    )

    # FROM_BASE64 yields a BLOB
    decoded: exp.Expression = exp.FromBase64(this=normalized)

    if to_string:
        decoded = exp.Decode(this=decoded)

    return self.sql(decoded)
| 149 | + |
| 150 | + |
100 | 151 | def _last_day_sql(self: DuckDB.Generator, expression: exp.LastDay) -> str: |
101 | 152 | """ |
102 | 153 | DuckDB's LAST_DAY only supports finding the last day of a month. |
@@ -1536,6 +1587,8 @@ class Generator(generator.Generator): |
1536 | 1587 | exp.ArrayUniqueAgg: lambda self, e: self.func( |
1537 | 1588 | "LIST", exp.Distinct(expressions=[e.this]) |
1538 | 1589 | ), |
| 1590 | + exp.Base64DecodeBinary: lambda self, e: _base64_decode_sql(self, e, to_string=False), |
| 1591 | + exp.Base64DecodeString: lambda self, e: _base64_decode_sql(self, e, to_string=True), |
1539 | 1592 | exp.BitwiseAnd: lambda self, e: self._bitwise_op(e, "&"), |
1540 | 1593 | exp.BitwiseAndAgg: _bitwise_agg_sql, |
1541 | 1594 | exp.BitwiseLeftShift: _bitshift_sql, |
@@ -2742,29 +2795,19 @@ def base64encode_sql(self, expression: exp.Base64Encode) -> str: |
2742 | 2795 | # DuckDB TO_BASE64 requires BLOB input |
2743 | 2796 | # Snowflake BASE64_ENCODE accepts both VARCHAR and BINARY - for VARCHAR it implicitly |
2744 | 2797 | # encodes UTF-8 bytes. We add ENCODE unless the input is a binary type. |
2745 | | - input_expr = expression.this |
| 2798 | + result = expression.this |
2746 | 2799 |
|
2747 | 2800 | # Check if input is a string type - ENCODE only accepts VARCHAR |
2748 | | - result = input_expr |
2749 | | - if input_expr.is_type(*exp.DataType.TEXT_TYPES): |
2750 | | - result = exp.Encode(this=input_expr) |
| 2801 | + if result.is_type(*exp.DataType.TEXT_TYPES): |
| 2802 | + result = exp.Encode(this=result) |
2751 | 2803 |
|
2752 | 2804 | result = exp.ToBase64(this=result) |
2753 | 2805 |
|
2754 | 2806 | max_line_length = expression.args.get("max_line_length") |
2755 | 2807 | alphabet = expression.args.get("alphabet") |
2756 | 2808 |
|
2757 | | - # Handle custom alphabet by replacing characters (applied before line breaks) |
2758 | | - # Alphabet can be 1-3 chars: 1st = index 62 ('+'), 2nd = index 63 ('/'), 3rd = padding ('=') |
2759 | | - # zip truncates to the shorter string, so 1-char alphabet only replaces '+', 2-char replaces '+/' |
2760 | | - if isinstance(alphabet, exp.Literal) and alphabet.is_string: |
2761 | | - for default_char, new_char in zip("+/=", alphabet.this): |
2762 | | - if new_char != default_char: |
2763 | | - result = exp.Replace( |
2764 | | - this=result, |
2765 | | - expression=exp.Literal.string(default_char), |
2766 | | - replacement=exp.Literal.string(new_char), |
2767 | | - ) |
| 2809 | + # Handle custom alphabet by replacing standard chars with custom ones |
| 2810 | + result = _apply_base64_alphabet_replacements(result, alphabet) |
2768 | 2811 |
|
2769 | 2812 | # Handle max_line_length by inserting newlines every N characters |
2770 | 2813 | line_length = ( |
|
0 commit comments