diff --git a/CHANGELOG.md b/CHANGELOG.md index b6278fa768..5b3a8a0ce4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,13 @@ - `get_cloud_provider_token` - Added support for the following scalar functions in `functions.py`: + - `array_remove_at` + - `as_boolean` + - `boolor_agg` + - `chr` + - `div0null` + - `dp_interval_high` + - `dp_interval_low` - `h3_cell_to_boundary` - `h3_cell_to_parent` - `h3_cell_to_point` @@ -24,6 +31,9 @@ - `h3_get_resolution` - `h3_grid_disk` - `h3_grid_distance` + - `hex_decode_binary` + - `last_query_id` + - `last_transaction` ### Snowpark pandas API Updates diff --git a/docs/source/snowpark/functions.rst b/docs/source/snowpark/functions.rst index 9ae84e13f8..4624fe359c 100644 --- a/docs/source/snowpark/functions.rst +++ b/docs/source/snowpark/functions.rst @@ -59,6 +59,7 @@ Functions array_position array_prepend array_remove + array_remove_at array_reverse array_size array_slice @@ -70,6 +71,7 @@ Functions arrays_zip as_array as_binary + as_boolean as_char as_date as_decimal @@ -109,6 +111,7 @@ Functions bitshiftright bitxor bitxor_agg + boolor_agg build_stage_file_url builtin bround @@ -123,6 +126,7 @@ Functions charindex check_json check_xml + chr coalesce col collate @@ -187,6 +191,8 @@ Functions desc_nulls_last div0 divnull + dp_interval_high + dp_interval_low editdistance endswith equal_nan @@ -227,6 +233,7 @@ Functions grouping_id hash hex + hex_decode_binary hex_encode hour h3_cell_to_boundary @@ -275,6 +282,8 @@ Functions kurtosis lag last_day + last_query_id + last_transaction last_value lead least diff --git a/src/snowflake/snowpark/_functions/scalar_functions.py b/src/snowflake/snowpark/_functions/scalar_functions.py index 924feccf95..0e92ee624c 100644 --- a/src/snowflake/snowpark/_functions/scalar_functions.py +++ b/src/snowflake/snowpark/_functions/scalar_functions.py @@ -763,3 +763,264 @@ def h3_grid_distance( cell_id_1 = _to_col_if_str(cell_id_1, "h3_grid_distance") cell_id_2 = _to_col_if_str(cell_id_2, "h3_grid_distance") return builtin("h3_grid_distance", _emit_ast=_emit_ast)(cell_id_1, cell_id_2) + + +@publicapi +def array_remove_at( + array: ColumnOrName, position: ColumnOrName, _emit_ast: bool = True +) -> Column: + """ + Returns an ARRAY with the element at the specified position removed. + + Args: + array (ColumnOrName): Column containing the source ARRAY. + position (ColumnOrName): Column containing a (zero-based) position in the source ARRAY. + The element at this position is removed from the resulting ARRAY. + A negative position is interpreted as an index from the back of the array (e.g. -1 removes the last element in the array). + + Returns: + Column: The resulting ARRAY with the specified element removed. + + Example:: + + >>> df = session.create_dataframe([([2, 5, 7], 0), ([2, 5, 7], -1), ([2, 5, 7], 10)], schema=["array_col", "position_col"]) + >>> df.select(array_remove_at("array_col", "position_col").alias("result")).collect() + [Row(RESULT='[\\n 5,\\n 7\\n]'), Row(RESULT='[\\n 2,\\n 5\\n]'), Row(RESULT='[\\n 2,\\n 5,\\n 7\\n]')] + """ + a = _to_col_if_str(array, "array_remove_at") + p = _to_col_if_str(position, "array_remove_at") + return builtin("array_remove_at", _emit_ast=_emit_ast)(a, p) + + +@publicapi +def as_boolean(variant: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Casts a VARIANT value to a boolean. + + Args: + variant (ColumnOrName): A Column or column name containing VARIANT values to be cast to boolean. + + Returns: + ColumnL The boolean values cast from the VARIANT input. + + Example:: + >>> from snowflake.snowpark.functions import to_variant, to_boolean + >>> df = session.create_dataframe([ + ... [True], + ... [False] + ... ], schema=["a"]) + >>> df.select(as_boolean(to_variant(to_boolean(df["a"]))).alias("result")).collect() + [Row(RESULT=True), Row(RESULT=False)] + """ + c = _to_col_if_str(variant, "as_boolean") + return builtin("as_boolean", _emit_ast=_emit_ast)(c) + + +@publicapi +def boolor_agg(e: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Returns the logical OR of all non-NULL records in a group. If all records are NULL, returns NULL. + + Args: + e (ColumnOrName): Boolean values to aggregate. + + Returns: + Column: The logical OR aggregation result. + + Example:: + + >>> df = session.create_dataframe([ + ... [True, False, True], + ... [False, False, False], + ... [True, True, False], + ... [False, True, True] + ... ], schema=["a", "b", "c"]) + >>> df.select( + ... boolor_agg(df["a"]).alias("boolor_a"), + ... boolor_agg(df["b"]).alias("boolor_b"), + ... boolor_agg(df["c"]).alias("boolor_c") + ... ).collect() + [Row(BOOLOR_A=True, BOOLOR_B=True, BOOLOR_C=True)] + """ + c = _to_col_if_str(e, "boolor_agg") + return builtin("boolor_agg", _emit_ast=_emit_ast)(c) + + +@publicapi +def chr(col: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Converts a Unicode code point (including 7-bit ASCII) into the character that matches the input Unicode. + + Args: + col (ColumnOrName): Integer Unicode code points. + + Returns: + Column: The corresponding character for each code point. + + Example:: + + >>> df = session.create_dataframe([83, 33, 169, 8364, None], schema=['a']) + >>> df.select(df.a, chr(df.a).as_('char')).sort(df.a).show() + ----------------- + |"A" |"CHAR" | + ----------------- + |NULL |NULL | + |33 |! | + |83 |S | + |169 |© | + |8364 |€ | + ----------------- + + """ + c = _to_col_if_str(col, "chr") + return builtin("chr", _emit_ast=_emit_ast)(c) + + +@publicapi +def div0null( + dividend: Union[ColumnOrName, int, float], + divisor: Union[ColumnOrName, int, float], + _emit_ast: bool = True, +) -> Column: + """ + Performs division like the division operator (/), but returns 0 when the divisor is 0 or NULL (rather than reporting an error). + + Args: + dividend (ColumnOrName, int, float): The dividend. + divisor (ColumnOrName, int, float): The divisor. + + Returns: + Column: The result of the division, with 0 returned for cases where the divisor is 0 or NULL. + + Example:: + + >>> df = session.create_dataframe([[10, 2], [10, 0], [10, None]], schema=["dividend", "divisor"]) + >>> df.select(div0null(df["dividend"], df["divisor"]).alias("result")).collect() + [Row(RESULT=Decimal('5.000000')), Row(RESULT=Decimal('0.000000')), Row(RESULT=Decimal('0.000000'))] + """ + dividend_col = ( + lit(dividend) + if isinstance(dividend, (int, float)) + else _to_col_if_str(dividend, "div0null") + ) + divisor_col = ( + lit(divisor) + if isinstance(divisor, (int, float)) + else _to_col_if_str(divisor, "div0null") + ) + return builtin("div0null", _emit_ast=_emit_ast)(dividend_col, divisor_col) + + +@publicapi +def dp_interval_high(aggregated_column: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Returns the high end of the confidence interval for a differentially private aggregate. + This function is used with differential privacy aggregation functions to provide + the upper bound of the confidence interval for the aggregated result. + + Args: + aggregated_column (ColumnOrName): The result of a differential privacy aggregation function. + + Returns: + Column: The high end of the confidence interval for the differentially private aggregate. + + Example:: + + >>> from snowflake.snowpark.functions import sum as sum_ + >>> df = session.create_dataframe([[10], [20], [30]], schema=["num_claims"]) + >>> df.select(sum_(df["num_claims"]).alias("sum_claims")).select(dp_interval_high("sum_claims")).collect() + [Row(DP_INTERVAL_HIGH("SUM_CLAIMS")=None)] + """ + c = _to_col_if_str(aggregated_column, "dp_interval_high") + return builtin("dp_interval_high", _emit_ast=_emit_ast)(c) + + +@publicapi +def dp_interval_low(aggregated_column: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Returns the lower bound of the confidence interval for a differentially private aggregate. This function is used with differential privacy aggregation functions to provide statistical bounds on the results. + + Args: + aggregated_column (ColumnOrName): The differentially private aggregate result. + + Returns: + Column: The lower bound of the confidence interval. + + Example:: + + >>> from snowflake.snowpark.functions import sum as sum_ + >>> df = session.create_dataframe([[10], [20], [30]], schema=["num_claims"]) + >>> result = df.select(sum_("num_claims").alias("sum_claims")).select(dp_interval_low("sum_claims").alias("interval_low")) + >>> result.collect() + [Row(INTERVAL_LOW=None)] + """ + c = _to_col_if_str(aggregated_column, "dp_interval_low") + return builtin("dp_interval_low", _emit_ast=_emit_ast)(c) + + +@publicapi +def hex_decode_binary(input_expr: ColumnOrName, _emit_ast: bool = True) -> Column: + """ + Decodes a hex-encoded string to binary data. + + Args: + input_expr (:class:`ColumnOrName`): the hex-encoded string to decode. + Returns: + :class:`Column`: the decoded binary data. + + Example:: + + >>> df = session.create_dataframe(['48454C4C4F', '576F726C64'], schema=['hex_string']) + >>> df.select(hex_decode_binary(df['hex_string']).alias('decoded_binary')).collect() + [Row(DECODED_BINARY=bytearray(b'HELLO')), Row(DECODED_BINARY=bytearray(b'World'))] + """ + c = _to_col_if_str(input_expr, "hex_decode_binary") + return builtin("hex_decode_binary", _emit_ast=_emit_ast)(c) + + +@publicapi +def last_query_id(num: ColumnOrName = None, _emit_ast: bool = True) -> Column: + """ + Returns the query ID of the last statement executed in the current session. + If num is specified, returns the query ID of the nth statement executed in the current session. + + Args: + num (ColumnOrName, optional): The number of statements back to retrieve the query ID for. If None, returns the query ID of the last statement. + + Returns: + Column: The query ID as a string. + + Example:: + + >>> df = session.create_dataframe([1], schema=["a"]) + >>> result1 = df.select(last_query_id().alias("QUERY_ID")).collect() + >>> assert len(result1) == 1 + >>> assert isinstance(result1[0]["QUERY_ID"], str) + >>> assert len(result1[0]["QUERY_ID"]) > 0 + >>> result2 = df.select(last_query_id(1).alias("QUERY_ID")).collect() + >>> assert len(result2) == 1 + >>> assert isinstance(result2[0]["QUERY_ID"], str) + >>> assert len(result2[0]["QUERY_ID"]) > 0 + """ + if num is None: + return builtin("last_query_id", _emit_ast=_emit_ast)() + else: + return builtin("last_query_id", _emit_ast=_emit_ast)(num) + + +@publicapi +def last_transaction(_emit_ast: bool = True) -> Column: + """ + Returns the query ID of the last transaction committed or rolled back in the current session. If no transaction has been committed or rolled back in the current session, returns NULL. + + Returns: + Column: The last transaction. + + Example:: + + >>> df = session.create_dataframe([1]) + >>> result = df.select(last_transaction()).collect() + >>> # Result will be None if no transaction has occurred + >>> assert result[0]['LAST_TRANSACTION()'] is None or isinstance(result[0]['LAST_TRANSACTION()'], str) + """ + return builtin("last_transaction", _emit_ast=_emit_ast)()