Skip to content

Commit 3618118

Browse files
SNOW-2314365: Added support for conditional expression functions (#3767)
Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
1 parent 550e5b2 commit 3618118

4 files changed

Lines changed: 300 additions & 6 deletions

File tree

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@
2020
- Added support for the following scalar functions in `functions.py`:
2121
- `array_remove_at`
2222
- `as_boolean`
23+
- `booland`
24+
- `boolnot`
25+
- `boolor`
2326
- `boolor_agg`
27+
- `boolxor`
2428
- `chr`
29+
- `decode`
2530
- `div0null`
2631
- `dp_interval_high`
2732
- `dp_interval_low`
33+
- `greatest_ignore_nulls`
2834
- `h3_cell_to_boundary`
2935
- `h3_cell_to_parent`
3036
- `h3_cell_to_point`
@@ -38,6 +44,11 @@
3844
- `hex_decode_binary`
3945
- `last_query_id`
4046
- `last_transaction`
47+
- `least_ignore_nulls`
48+
- `nullif`
49+
- `nvl2`
50+
- `regr_valx`
51+
4152

4253
### Snowpark pandas API Updates
4354

docs/source/snowpark/functions.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ Functions
112112
bitxor
113113
bitxor_agg
114114
boolor_agg
115+
booland
116+
boolnot
117+
boolor
118+
boolxor
115119
build_stage_file_url
116120
builtin
117121
bround
@@ -184,6 +188,7 @@ Functions
184188
dayofmonth
185189
dayofweek
186190
dayofyear
191+
decode
187192
degrees
188193
dense_rank
189194
desc
@@ -229,6 +234,7 @@ Functions
229234
getdate
230235
getvariable
231236
greatest
237+
greatest_ignore_nulls
232238
grouping
233239
grouping_id
234240
hash
@@ -287,6 +293,7 @@ Functions
287293
last_value
288294
lead
289295
least
296+
least_ignore_nulls
290297
left
291298
length
292299
listagg
@@ -327,8 +334,10 @@ Functions
327334
not_
328335
nth_value
329336
ntile
337+
nullif
330338
nullifzero
331339
nvl
340+
nvl2
332341
object_agg
333342
object_construct
334343
object_construct_keep_null
@@ -365,6 +374,7 @@ Functions
365374
regr_sxx
366375
regr_sxy
367376
regr_syy
377+
regr_valx
368378
repeat
369379
replace
370380
right

src/snowflake/snowpark/_functions/scalar_functions.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,3 +1024,276 @@ def last_transaction(_emit_ast: bool = True) -> Column:
10241024
>>> assert result[0]['LAST_TRANSACTION()'] is None or isinstance(result[0]['LAST_TRANSACTION()'], str)
10251025
"""
10261026
return builtin("last_transaction", _emit_ast=_emit_ast)()
1027+
1028+
1029+
@publicapi
1030+
def booland(expr1: ColumnOrName, expr2: ColumnOrName, _emit_ast: bool = True) -> Column:
1031+
"""
1032+
Computes the Boolean AND of two numeric expressions. In accordance with Boolean semantics:
1033+
- Non-zero values (including negative numbers) are regarded as True.
1034+
- Zero values are regarded as False.
1035+
1036+
Args:
1037+
expr1 (ColumnOrName): The first boolean expression.
1038+
expr2 (ColumnOrName): The second boolean expression.
1039+
1040+
Returns:
1041+
- True if both expressions are non-zero.
1042+
- False if both expressions are zero or one expression is zero and the other expression is non-zero or NULL.
1043+
- NULL if both expressions are NULL or one expression is NULL and the other expression is non-zero.
1044+
1045+
Example::
1046+
>>> from snowflake.snowpark.functions import col
1047+
>>> df = session.create_dataframe([[1, -2], [0, 2], [0, 0], [5, 3]], schema=["a", "b"])
1048+
>>> df.select(booland(col("a"), col("b")).alias("result")).collect()
1049+
[Row(RESULT=True), Row(RESULT=False), Row(RESULT=False), Row(RESULT=True)]
1050+
"""
1051+
c1 = _to_col_if_str(expr1, "booland")
1052+
c2 = _to_col_if_str(expr2, "booland")
1053+
return builtin("booland", _emit_ast=_emit_ast)(c1, c2)
1054+
1055+
1056+
@publicapi
1057+
def boolnot(e: ColumnOrName, _emit_ast: bool = True) -> Column:
1058+
"""
1059+
Computes the Boolean NOT of a single numeric expression. In accordance with Boolean semantics:
1060+
- Non-zero values (including negative numbers) are regarded as True.
1061+
- Zero values are regarded as False.
1062+
1063+
Args:
1064+
e (ColumnOrName): A numeric expression to be evaluated.
1065+
1066+
Returns:
1067+
- True if the expression is zero.
1068+
- False if the expression is non-zero.
1069+
- NULL if the expression is NULL.
1070+
1071+
Example::
1072+
1073+
>>> df = session.create_dataframe([0, 10, -5], schema=["a"])
1074+
>>> df.select(boolnot("a")).collect()
1075+
[Row(BOOLNOT("A")=True), Row(BOOLNOT("A")=False), Row(BOOLNOT("A")=False)]
1076+
"""
1077+
c = _to_col_if_str(e, "boolnot")
1078+
return builtin("boolnot", _emit_ast=_emit_ast)(c)
1079+
1080+
1081+
@publicapi
1082+
def boolor(expr1: ColumnOrName, expr2: ColumnOrName, _emit_ast: bool = True) -> Column:
1083+
"""
1084+
Computes the Boolean OR of two numeric expressions. In accordance with Boolean semantics:
1085+
- Non-zero values (including negative numbers) are regarded as True.
1086+
- Zero values are regarded as False.
1087+
1088+
Args:
1089+
expr1 (ColumnOrName): The first boolean expression.
1090+
expr2 (ColumnOrName): The second boolean expression.
1091+
1092+
Returns:
1093+
- True if both expressions are non-zero or the first expression is non-zero and the second expression is zero or None.
1094+
- False if both expressions are zero.
1095+
- None if both expressions are None or the first expression is None and the second expression is zero.
1096+
1097+
Example::
1098+
1099+
>>> from snowflake.snowpark.functions import col
1100+
>>> df = session.create_dataframe([
1101+
... [1, 2],
1102+
... [-1, 0],
1103+
... [3, None],
1104+
... [0, 0],
1105+
... [None, 0],
1106+
... [None, None]
1107+
... ], schema=["expr1", "expr2"])
1108+
>>> df.select(boolor(col("expr1"), col("expr2")).alias("result")).collect()
1109+
[Row(RESULT=True), Row(RESULT=True), Row(RESULT=True), Row(RESULT=False), Row(RESULT=None), Row(RESULT=None)]
1110+
"""
1111+
c1 = _to_col_if_str(expr1, "boolor")
1112+
c2 = _to_col_if_str(expr2, "boolor")
1113+
return builtin("boolor", _emit_ast=_emit_ast)(c1, c2)
1114+
1115+
1116+
@publicapi
1117+
def boolxor(expr1: ColumnOrName, expr2: ColumnOrName, _emit_ast: bool = True) -> Column:
1118+
"""
1119+
Computes the Boolean XOR of two numeric expressions (i.e. one of the expressions, but not both expressions, is True). In accordance with Boolean semantics:
1120+
- Non-zero values (including negative numbers) are regarded as True.
1121+
- Zero values are regarded as False.
1122+
1123+
Args:
1124+
expr1 (ColumnOrName): First numeric expression or a string name of the column.
1125+
expr2 (ColumnOrName): Second numeric expression or a string name of the column.
1126+
1127+
Returns:
1128+
- True if exactly one of the expressions is non-zero.
1129+
- False if both expressions are zero or both expressions are non-zero.
1130+
- None if both expressions are None, or one expression is None and the other expression is zero.
1131+
1132+
Example::
1133+
>>> from snowflake.snowpark.functions import col
1134+
>>> df = session.create_dataframe([[2, 0], [1, -1], [0, 0], [None, 3]], schema=["a", "b"])
1135+
>>> df.select(boolxor(col("a"), col("b")).alias("result")).collect()
1136+
[Row(RESULT=True), Row(RESULT=False), Row(RESULT=False), Row(RESULT=None)]
1137+
"""
1138+
c1 = _to_col_if_str(expr1, "boolxor")
1139+
c2 = _to_col_if_str(expr2, "boolxor")
1140+
return builtin("boolxor", _emit_ast=_emit_ast)(c1, c2)
1141+
1142+
1143+
@publicapi
1144+
def decode(expr: ColumnOrName, *args: ColumnOrName, _emit_ast: bool = True) -> Column:
1145+
"""Decodes an expression by comparing it with search values and returning corresponding result values.
1146+
1147+
Similar to a Case statement, this function compares an expression to one or more search values
1148+
and returns the corresponding result when a match is found.
1149+
1150+
Args:
1151+
expr (ColumnOrName): The expression to decode.
1152+
*args (ColumnOrName): Variable length argument list containing pairs of search values and
1153+
result values, with an optional default value at the end.
1154+
1155+
1156+
Returns:
1157+
Column: The decoded result.
1158+
1159+
Example:
1160+
1161+
>>> from snowflake.snowpark.functions import col, lit
1162+
>>> df = session.create_dataframe([[1, 1], [2, 4], [16, 24]], schema=["a", "b"])
1163+
>>> df.select(decode(col("a"), lit(1), lit("one"), lit(2), lit("two"), lit("default")).alias("RESULT")).collect()
1164+
[Row(RESULT='one'), Row(RESULT='two'), Row(RESULT='default')]
1165+
"""
1166+
expr_col = _to_col_if_str(expr, "decode")
1167+
arg_cols = [_to_col_if_str(arg, "decode") for arg in args]
1168+
return builtin("decode", _emit_ast=_emit_ast)(expr_col, *arg_cols)
1169+
1170+
1171+
@publicapi
1172+
def greatest_ignore_nulls(*columns: ColumnOrName, _emit_ast: bool = True) -> Column:
1173+
"""
1174+
Returns the largest value from a list of expressions, ignoring None values.
1175+
If all argument values are None, the result is None.
1176+
1177+
Args:
1178+
columns (ColumnOrName): The name strings to compare.
1179+
1180+
Returns:
1181+
Column: The greatest value, ignoring None values.
1182+
1183+
Examples::
1184+
1185+
>>> df = session.create_dataframe([[1, 2, 3, 4.25], [2, 4, -1, None], [3, 6, None, -2.75]], schema=["a", "b", "c", "d"])
1186+
>>> df.select(greatest_ignore_nulls(df["a"], df["b"], df["c"], df["d"]).alias("greatest_ignore_nulls")).collect()
1187+
[Row(GREATEST_IGNORE_NULLS=4.25), Row(GREATEST_IGNORE_NULLS=4.0), Row(GREATEST_IGNORE_NULLS=6.0)]
1188+
"""
1189+
c = [_to_col_if_str(ex, "greatest_ignore_nulls") for ex in columns]
1190+
return builtin("greatest_ignore_nulls", _emit_ast=_emit_ast)(*c)
1191+
1192+
1193+
@publicapi
1194+
def least_ignore_nulls(*columns: ColumnOrName, _emit_ast: bool = True) -> Column:
1195+
"""
1196+
Returns the smallest value from a list of expressions, ignoring None values.
1197+
If all argument values are None, the result is None.
1198+
1199+
Args:
1200+
columns (ColumnOrName): list of column or column names to compare.
1201+
1202+
Returns:
1203+
Column: The smallest value from the list of expressions, ignoring None values.
1204+
1205+
Example::
1206+
1207+
>>> df = session.create_dataframe([[1, 2, 3], [2, 4, -1], [3, 6, None]], schema=["a", "b", "c"])
1208+
>>> df.select(least_ignore_nulls(df["a"], df["b"], df["c"]).alias("least_ignore_nulls")).collect()
1209+
[Row(LEAST_IGNORE_NULLS=1), Row(LEAST_IGNORE_NULLS=-1), Row(LEAST_IGNORE_NULLS=3)]
1210+
"""
1211+
c = [_to_col_if_str(ex, "least_ignore_nulls") for ex in columns]
1212+
return builtin("least_ignore_nulls", _emit_ast=_emit_ast)(*c)
1213+
1214+
1215+
@publicapi
1216+
def nullif(expr1: ColumnOrName, expr2: ColumnOrName, _emit_ast: bool = True) -> Column:
1217+
"""
1218+
Returns None if expr1 is equal to expr2, otherwise returns expr1.
1219+
1220+
Args:
1221+
expr1 (ColumnOrName): The first expression to compare.
1222+
expr2 (ColumnOrName): The second expression to compare.
1223+
1224+
Returns:
1225+
Column: None if expr1 is equal to expr2, otherwise expr1.
1226+
1227+
Example::
1228+
1229+
>>> df = session.create_dataframe([[0, 0], [0, 1], [1, 0], [1, 1], [None, 0]], schema=["a", "b"])
1230+
>>> df.select(nullif(df["a"], df["b"]).alias("result")).collect()
1231+
[Row(RESULT=None), Row(RESULT=0), Row(RESULT=1), Row(RESULT=None), Row(RESULT=None)]
1232+
"""
1233+
c1 = _to_col_if_str(expr1, "nullif")
1234+
c2 = _to_col_if_str(expr2, "nullif")
1235+
return builtin("nullif", _emit_ast=_emit_ast)(c1, c2)
1236+
1237+
1238+
@publicapi
1239+
def nvl2(
1240+
expr1: ColumnOrName,
1241+
expr2: ColumnOrName,
1242+
expr3: ColumnOrName,
1243+
_emit_ast: bool = True,
1244+
) -> Column:
1245+
"""
1246+
Returns expr2 if expr1 is not None, otherwise returns expr3.
1247+
1248+
Args:
1249+
expr1 (ColumnOrName): The expression to test for None.
1250+
expr2 (ColumnOrName): The value to return if expr1 is not None.
1251+
expr3 (ColumnOrName): The value to return if expr1 is None.
1252+
1253+
Returns:
1254+
Column: The result of the nvl2 function.
1255+
1256+
Example::
1257+
1258+
>>> from snowflake.snowpark.functions import col
1259+
>>> df = session.create_dataframe([
1260+
... [0, 5, 3],
1261+
... [0, 5, None],
1262+
... [0, None, 3],
1263+
... [None, 5, 3],
1264+
... [None, None, 3]
1265+
... ], schema=["a", "b", "c"])
1266+
>>> df.select(nvl2(col("a"), col("b"), col("c")).alias("nvl2_result")).collect()
1267+
[Row(NVL2_RESULT=5), Row(NVL2_RESULT=5), Row(NVL2_RESULT=None), Row(NVL2_RESULT=3), Row(NVL2_RESULT=3)]
1268+
"""
1269+
c1 = _to_col_if_str(expr1, "nvl2")
1270+
c2 = _to_col_if_str(expr2, "nvl2")
1271+
c3 = _to_col_if_str(expr3, "nvl2")
1272+
return builtin("nvl2", _emit_ast=_emit_ast)(c1, c2, c3)
1273+
1274+
1275+
@publicapi
1276+
def regr_valx(y: ColumnOrName, x: ColumnOrName, _emit_ast: bool = True) -> Column:
1277+
"""
1278+
Returns None if either argument is None; otherwise, returns the second argument.
1279+
Note that REGR_VALX is a None-preserving function, while the more commonly-used NVL is a None-replacing function.
1280+
1281+
Args:
1282+
y (ColumnOrName): The dependent variable column.
1283+
x (ColumnOrName): The independent variable column.
1284+
1285+
Returns:
1286+
Column: The result of the regr_valx function.
1287+
1288+
Example::
1289+
1290+
>>> from snowflake.snowpark import Row
1291+
>>> df = session.create_dataframe([[2.0, 1.0], [None, 3.0], [6.0, None]], schema=["col_y", "col_x"])
1292+
>>> result = df.select(regr_valx(df["col_y"], df["col_x"]).alias("result")).collect()
1293+
>>> assert result == [Row(RESULT=1.0), Row(RESULT=None), Row(RESULT=None)]
1294+
1295+
Important: Note the order of the arguments; y precedes x
1296+
"""
1297+
y_col = _to_col_if_str(y, "regr_valx")
1298+
x_col = _to_col_if_str(x, "regr_valx")
1299+
return builtin("regr_valx", _emit_ast=_emit_ast)(y_col, x_col)

tests/mock/test_functions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,30 +340,30 @@ def test_patch_unsupported_function(session):
340340
df = session.create_dataframe([[3, 1], [3, 2], [4, 3]], schema=["a", "b"])
341341
with pytest.raises(NotImplementedError):
342342
df.select(
343-
call_function("greatest_ignore_nulls", df["a"], df["b"]).alias("greatest")
343+
call_function("my_function", df["a"], df["b"]).alias("greatest")
344344
).collect()
345345

346-
@patch("greatest_ignore_nulls")
347-
def mock_greatest_ignore_nulls(
346+
@patch("my_mocked_function")
347+
def mock_my_mocked_function(
348348
*columns: Iterable[ColumnEmulator],
349349
) -> ColumnEmulator:
350350
return ColumnEmulator(
351351
[1] * len(columns[0]), sf_type=ColumnType(IntegerType(), False)
352352
)
353353

354354
assert df.select(
355-
call_function("greatest_ignore_nulls", df["a"], df["b"]).alias("greatest")
355+
call_function("my_mocked_function", df["a"], df["b"]).alias("greatest")
356356
).collect() == [Row(1), Row(1), Row(1)]
357357

358-
@patch("greatest_ignore_nulls")
358+
@patch("my_mocked_function_2")
359359
def mock_wrong_patch(columns: Iterable[ColumnEmulator]) -> ColumnEmulator:
360360
return ColumnEmulator(
361361
[1] * len(columns[0]), sf_type=ColumnType(IntegerType(), False)
362362
)
363363

364364
with pytest.raises(SnowparkLocalTestingException) as exc:
365365
df.select(
366-
call_function("greatest_ignore_nulls", df["a"], df["b"]).alias("greatest")
366+
call_function("my_mocked_function_2", df["a"], df["b"]).alias("greatest")
367367
).collect()
368368
assert "Please ensure the implementation follows specifications" in str(exc.value)
369369

0 commit comments

Comments
 (0)