Skip to content
Merged
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
- `get_cloud_provider_token`

- Added support for the following scalar functions in `functions.py`:
- `array_remove_at`
- `as_boolean`
- `boolor_agg`
- `chr`
- `div0null`
- `dp_interval_high`
- `dp_interval_low`
- `h3_cell_to_boundary`
- `h3_cell_to_parent`
- `h3_cell_to_point`
Expand All @@ -24,6 +31,9 @@
- `h3_get_resolution`
- `h3_grid_disk`
- `h3_grid_distance`
- `hex_decode_binary`
- `last_query_id`
- `last_transaction`

### Snowpark pandas API Updates

Expand Down
9 changes: 9 additions & 0 deletions docs/source/snowpark/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Functions
array_position
array_prepend
array_remove
array_remove_at
array_reverse
array_size
array_slice
Expand All @@ -70,6 +71,7 @@ Functions
arrays_zip
as_array
as_binary
as_boolean
as_char
as_date
as_decimal
Expand Down Expand Up @@ -109,6 +111,7 @@ Functions
bitshiftright
bitxor
bitxor_agg
boolor_agg
build_stage_file_url
builtin
bround
Expand All @@ -123,6 +126,7 @@ Functions
charindex
check_json
check_xml
chr
coalesce
col
collate
Expand Down Expand Up @@ -187,6 +191,8 @@ Functions
desc_nulls_last
div0
divnull
dp_interval_high
dp_interval_low
editdistance
endswith
equal_nan
Expand Down Expand Up @@ -227,6 +233,7 @@ Functions
grouping_id
hash
hex
hex_decode_binary
hex_encode
hour
h3_cell_to_boundary
Expand Down Expand Up @@ -275,6 +282,8 @@ Functions
kurtosis
lag
last_day
last_query_id
last_transaction
last_value
lead
least
Expand Down
261 changes: 261 additions & 0 deletions src/snowflake/snowpark/_functions/scalar_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,264 @@ def h3_grid_distance(
cell_id_1 = _to_col_if_str(cell_id_1, "h3_grid_distance")
cell_id_2 = _to_col_if_str(cell_id_2, "h3_grid_distance")
return builtin("h3_grid_distance", _emit_ast=_emit_ast)(cell_id_1, cell_id_2)


@publicapi
def array_remove_at(
    array: ColumnOrName, position: ColumnOrName, _emit_ast: bool = True
) -> Column:
    """
    Removes the element at a given position from an ARRAY and returns the resulting ARRAY.

    Args:
        array (ColumnOrName): Column (or column name) holding the source ARRAY.
        position (ColumnOrName): Column (or column name) holding the zero-based position of the
            element to remove. A negative position counts from the back of the array
            (e.g. -1 removes the last element). As the example below shows, a position
            past the end of the array leaves the array unchanged.

    Returns:
        Column: The ARRAY with the element at ``position`` removed.

    Example::

        >>> df = session.create_dataframe([([2, 5, 7], 0), ([2, 5, 7], -1), ([2, 5, 7], 10)], schema=["array_col", "position_col"])
        >>> df.select(array_remove_at("array_col", "position_col").alias("result")).collect()
        [Row(RESULT='[\\n 5,\\n 7\\n]'), Row(RESULT='[\\n 2,\\n 5\\n]'), Row(RESULT='[\\n 2,\\n 5,\\n 7\\n]')]
    """
    # Resolve both arguments to Column objects before building the call.
    source_array = _to_col_if_str(array, "array_remove_at")
    index = _to_col_if_str(position, "array_remove_at")
    remove_at = builtin("array_remove_at", _emit_ast=_emit_ast)
    return remove_at(source_array, index)


@publicapi
def as_boolean(variant: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Casts a VARIANT value to a boolean.

    Args:
        variant (ColumnOrName): Column (or column name) holding the VARIANT values to cast.

    Returns:
        Column: The boolean values cast from the VARIANT input.

    Example::

        >>> from snowflake.snowpark.functions import to_variant, to_boolean
        >>> df = session.create_dataframe([
        ...     [True],
        ...     [False]
        ... ], schema=["a"])
        >>> df.select(as_boolean(to_variant(to_boolean(df["a"]))).alias("result")).collect()
        [Row(RESULT=True), Row(RESULT=False)]
    """
    variant_col = _to_col_if_str(variant, "as_boolean")
    cast_to_boolean = builtin("as_boolean", _emit_ast=_emit_ast)
    return cast_to_boolean(variant_col)


@publicapi
def boolor_agg(e: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Aggregates a group with logical OR over its non-NULL records.
    The result is NULL when every record in the group is NULL.

    Args:
        e (ColumnOrName): Column (or column name) holding the boolean values to aggregate.

    Returns:
        Column: The logical OR aggregation result.

    Example::

        >>> df = session.create_dataframe([
        ...     [True, False, True],
        ...     [False, False, False],
        ...     [True, True, False],
        ...     [False, True, True]
        ... ], schema=["a", "b", "c"])
        >>> df.select(
        ...     boolor_agg(df["a"]).alias("boolor_a"),
        ...     boolor_agg(df["b"]).alias("boolor_b"),
        ...     boolor_agg(df["c"]).alias("boolor_c")
        ... ).collect()
        [Row(BOOLOR_A=True, BOOLOR_B=True, BOOLOR_C=True)]
    """
    bool_col = _to_col_if_str(e, "boolor_agg")
    aggregate = builtin("boolor_agg", _emit_ast=_emit_ast)
    return aggregate(bool_col)


@publicapi
def chr(col: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Converts a Unicode code point (including 7-bit ASCII) into the character it denotes.

    Args:
        col (ColumnOrName): Column (or column name) holding integer Unicode code points.

    Returns:
        Column: The character corresponding to each code point.

    Example::

        >>> df = session.create_dataframe([83, 33, 169, 8364, None], schema=['a'])
        >>> df.select(df.a, chr(df.a).as_('char')).sort(df.a).show()
        -----------------
        |"A" |"CHAR" |
        -----------------
        |NULL |NULL |
        |33 |! |
        |83 |S |
        |169 |© |
        |8364 |€ |
        -----------------
        <BLANKLINE>
    """
    code_point = _to_col_if_str(col, "chr")
    to_char = builtin("chr", _emit_ast=_emit_ast)
    return to_char(code_point)


@publicapi
def div0null(
    dividend: Union[ColumnOrName, int, float],
    divisor: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
) -> Column:
    """
    Divides like the division operator (/), except that a divisor of 0 or NULL
    yields 0 instead of an error.

    Args:
        dividend (ColumnOrName, int, float): The dividend.
        divisor (ColumnOrName, int, float): The divisor.

    Returns:
        Column: The quotient, or 0 wherever the divisor is 0 or NULL.

    Example::

        >>> df = session.create_dataframe([[10, 2], [10, 0], [10, None]], schema=["dividend", "divisor"])
        >>> df.select(div0null(df["dividend"], df["divisor"]).alias("result")).collect()
        [Row(RESULT=Decimal('5.000000')), Row(RESULT=Decimal('0.000000')), Row(RESULT=Decimal('0.000000'))]
    """
    # Python numbers become literals; anything else is resolved as a column.
    if isinstance(dividend, (int, float)):
        numerator = lit(dividend)
    else:
        numerator = _to_col_if_str(dividend, "div0null")

    if isinstance(divisor, (int, float)):
        denominator = lit(divisor)
    else:
        denominator = _to_col_if_str(divisor, "div0null")

    divide = builtin("div0null", _emit_ast=_emit_ast)
    return divide(numerator, denominator)


@publicapi
def dp_interval_high(aggregated_column: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Returns the upper bound of the confidence interval for a differentially private aggregate.
    Intended to be applied to the result of a differential privacy aggregation function.

    Args:
        aggregated_column (ColumnOrName): Column (or column name) holding the result of a
            differential privacy aggregation function.

    Returns:
        Column: The high end of the confidence interval for the differentially private aggregate.

    Example::

        >>> from snowflake.snowpark.functions import sum as sum_
        >>> df = session.create_dataframe([[10], [20], [30]], schema=["num_claims"])
        >>> df.select(sum_(df["num_claims"]).alias("sum_claims")).select(dp_interval_high("sum_claims")).collect()
        [Row(DP_INTERVAL_HIGH("SUM_CLAIMS")=None)]
    """
    aggregate_col = _to_col_if_str(aggregated_column, "dp_interval_high")
    interval_high = builtin("dp_interval_high", _emit_ast=_emit_ast)
    return interval_high(aggregate_col)


@publicapi
def dp_interval_low(aggregated_column: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Returns the lower bound of the confidence interval for a differentially private aggregate.
    Intended to be applied to the result of a differential privacy aggregation function.

    Args:
        aggregated_column (ColumnOrName): Column (or column name) holding the differentially
            private aggregate result.

    Returns:
        Column: The lower bound of the confidence interval.

    Example::

        >>> from snowflake.snowpark.functions import sum as sum_
        >>> df = session.create_dataframe([[10], [20], [30]], schema=["num_claims"])
        >>> result = df.select(sum_("num_claims").alias("sum_claims")).select(dp_interval_low("sum_claims").alias("interval_low"))
        >>> result.collect()
        [Row(INTERVAL_LOW=None)]
    """
    aggregate_col = _to_col_if_str(aggregated_column, "dp_interval_low")
    interval_low = builtin("dp_interval_low", _emit_ast=_emit_ast)
    return interval_low(aggregate_col)


@publicapi
def hex_decode_binary(input_expr: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Decodes a hex-encoded string into binary data.

    Args:
        input_expr (:class:`ColumnOrName`): Column (or column name) holding the hex-encoded
            string to decode.

    Returns:
        :class:`Column`: The decoded binary data.

    Example::

        >>> df = session.create_dataframe(['48454C4C4F', '576F726C64'], schema=['hex_string'])
        >>> df.select(hex_decode_binary(df['hex_string']).alias('decoded_binary')).collect()
        [Row(DECODED_BINARY=bytearray(b'HELLO')), Row(DECODED_BINARY=bytearray(b'World'))]
    """
    hex_col = _to_col_if_str(input_expr, "hex_decode_binary")
    decode = builtin("hex_decode_binary", _emit_ast=_emit_ast)
    return decode(hex_col)


@publicapi
def last_query_id(
    num: Union[ColumnOrName, int, None] = None, _emit_ast: bool = True
) -> Column:
    """
    Returns the query ID of a statement executed in the current session.
    With no argument, this is the most recently executed statement.

    Args:
        num (ColumnOrName or int, optional): Position of the statement within the session.
            Positive values count from the first statement executed in the session
            (1 is the first statement); negative values count back from the most
            recent one (-1 is the last statement). If None, no argument is passed
            and the query ID of the last statement is returned.

    Returns:
        Column: The query ID as a string.

    Example::

        >>> df = session.create_dataframe([1], schema=["a"])
        >>> result1 = df.select(last_query_id().alias("QUERY_ID")).collect()
        >>> assert len(result1) == 1
        >>> assert isinstance(result1[0]["QUERY_ID"], str)
        >>> assert len(result1[0]["QUERY_ID"]) > 0
        >>> result2 = df.select(last_query_id(1).alias("QUERY_ID")).collect()
        >>> assert len(result2) == 1
        >>> assert isinstance(result2[0]["QUERY_ID"], str)
        >>> assert len(result2[0]["QUERY_ID"]) > 0
    """
    query_id_fn = builtin("last_query_id", _emit_ast=_emit_ast)
    if num is None:
        # Call with no arguments so the server-side default position applies.
        return query_id_fn()
    # ``num`` is forwarded unchanged (not resolved through _to_col_if_str) so that
    # plain integers, as used in the example above, are accepted.
    return query_id_fn(num)


@publicapi
def last_transaction(_emit_ast: bool = True) -> Column:
    """
    Returns the transaction ID of the last transaction that was committed or rolled back
    in the current session. If no transaction has been committed or rolled back in the
    current session, returns NULL.

    Returns:
        Column: The ID of the last transaction, or NULL if there is none.

    Example::

        >>> df = session.create_dataframe([1])
        >>> result = df.select(last_transaction()).collect()
        >>> # Result will be None if no transaction has occurred
        >>> assert result[0]['LAST_TRANSACTION()'] is None or isinstance(result[0]['LAST_TRANSACTION()'], str)
    """
    # The SQL function takes no arguments, so the builtin is invoked with none.
    last_txn = builtin("last_transaction", _emit_ast=_emit_ast)
    return last_txn()