Skip to content

Commit 84d7393

Browse files
timsaucer and claude committed
Improve array function APIs: optional params, better naming, restore comment
- Make null_string optional in string_to_array/string_to_list
- Make step optional in gen_series/generate_series
- Rename second_array to element in array_contains/list_has/list_contains
- Restore # Window Functions section comment in __all__
- Add tests for optional parameter variants

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a7c7de4 commit 84d7393

File tree

3 files changed

+75
-29
lines changed

3 files changed

+75
-29
lines changed

crates/core/src/functions.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,34 @@ fn arrays_zip(exprs: Vec<PyExpr>) -> PyExpr {
109109
datafusion::functions_nested::expr_fn::arrays_zip(exprs).into()
110110
}
111111

112+
/// Python-exposed builder for a `string_to_array` scalar-function expression.
///
/// `string` and `delimiter` are always forwarded to the UDF. `null_string`
/// defaults to `None` in the Python signature declared below and is appended
/// as a third argument only when the caller supplies it.
///
/// Returns the resulting `Expr::ScalarFunction` converted into a `PyExpr`.
#[pyfunction]
113+
#[pyo3(signature = (string, delimiter, null_string=None))]
114+
fn string_to_array(string: PyExpr, delimiter: PyExpr, null_string: Option<PyExpr>) -> PyExpr {
115+
// The two required arguments come first, in positional order.
let mut args = vec![string.into(), delimiter.into()];
116+
// Forward the optional third argument only when it was provided, so the
// UDF is otherwise invoked with exactly two arguments.
if let Some(null_string) = null_string {
117+
args.push(null_string.into());
118+
}
119+
// Build the call directly from the UDF handle because the argument count
// varies with whether `null_string` was given.
Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
120+
datafusion::functions_nested::string::string_to_array_udf(),
121+
args,
122+
))
123+
.into()
124+
}
125+
126+
/// Python-exposed builder for a `gen_series` scalar-function expression.
///
/// `start` and `stop` are always forwarded to the UDF. `step` defaults to
/// `None` in the Python signature declared below and is appended as a third
/// argument only when the caller supplies it.
///
/// Returns the resulting `Expr::ScalarFunction` converted into a `PyExpr`.
#[pyfunction]
127+
#[pyo3(signature = (start, stop, step=None))]
128+
fn gen_series(start: PyExpr, stop: PyExpr, step: Option<PyExpr>) -> PyExpr {
129+
// The two required arguments come first, in positional order.
let mut args = vec![start.into(), stop.into()];
130+
// Forward the optional step only when it was provided, so the UDF is
// otherwise invoked with exactly two arguments.
if let Some(step) = step {
131+
args.push(step.into());
132+
}
133+
// Build the call directly from the UDF handle because the argument count
// varies with whether `step` was given.
Expr::ScalarFunction(datafusion::logical_expr::expr::ScalarFunction::new_udf(
134+
datafusion::functions_nested::range::gen_series_udf(),
135+
args,
136+
))
137+
.into()
138+
}
139+
112140
#[pyfunction]
113141
#[pyo3(signature = (array, element, index=None))]
114142
fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
@@ -687,8 +715,6 @@ array_fn!(array_any_value, array);
687715
array_fn!(array_max, array);
688716
array_fn!(array_min, array);
689717
array_fn!(array_reverse, array);
690-
array_fn!(string_to_array, string delimiter null_string);
691-
array_fn!(gen_series, start stop step);
692718
array_fn!(cardinality, array);
693719
array_fn!(flatten, array);
694720
array_fn!(range, start stop step);

python/datafusion/functions.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@
317317
"var_samp",
318318
"var_sample",
319319
"when",
320+
# Window Functions
320321
"window",
321322
]
322323

@@ -2848,22 +2849,22 @@ def array_has_any(first_array: Expr, second_array: Expr) -> Expr:
28482849
return Expr(f.array_has_any(first_array.expr, second_array.expr))
28492850

28502851

2851-
def array_contains(first_array: Expr, second_array: Expr) -> Expr:
2852-
"""Returns true if the element appears in the first array, otherwise false.
2852+
def array_contains(array: Expr, element: Expr) -> Expr:
2853+
"""Returns true if the element appears in the array, otherwise false.
28532854
28542855
See Also:
28552856
This is an alias for :py:func:`array_has`.
28562857
"""
2857-
return array_has(first_array, second_array)
2858+
return array_has(array, element)
28582859

28592860

2860-
def list_has(first_array: Expr, second_array: Expr) -> Expr:
2861-
"""Returns true if the element appears in the first array, otherwise false.
2861+
def list_has(array: Expr, element: Expr) -> Expr:
2862+
"""Returns true if the element appears in the array, otherwise false.
28622863
28632864
See Also:
28642865
This is an alias for :py:func:`array_has`.
28652866
"""
2866-
return array_has(first_array, second_array)
2867+
return array_has(array, element)
28672868

28682869

28692870
def list_has_all(first_array: Expr, second_array: Expr) -> Expr:
@@ -2884,13 +2885,13 @@ def list_has_any(first_array: Expr, second_array: Expr) -> Expr:
28842885
return array_has_any(first_array, second_array)
28852886

28862887

2887-
def list_contains(first_array: Expr, second_array: Expr) -> Expr:
2888-
"""Returns true if the element appears in the first array, otherwise false.
2888+
def list_contains(array: Expr, element: Expr) -> Expr:
2889+
"""Returns true if the element appears in the array, otherwise false.
28892890
28902891
See Also:
28912892
This is an alias for :py:func:`array_has`.
28922893
"""
2893-
return array_has(first_array, second_array)
2894+
return array_has(array, element)
28942895

28952896

28962897
def array_position(array: Expr, element: Expr, index: int | None = 1) -> Expr:
@@ -3590,25 +3591,30 @@ def list_zip(*arrays: Expr) -> Expr:
35903591
return arrays_zip(*arrays)
35913592

35923593

3593-
def string_to_array(string: Expr, delimiter: Expr, null_string: Expr) -> Expr:
3594+
def string_to_array(
3595+
string: Expr, delimiter: Expr, null_string: Expr | None = None
3596+
) -> Expr:
35943597
"""Splits a string based on a delimiter and returns an array of parts.
35953598
3596-
Any parts matching the ``null_string`` will be replaced with ``NULL``.
3599+
Any parts matching the optional ``null_string`` will be replaced with ``NULL``.
35973600
35983601
Examples:
35993602
>>> ctx = dfn.SessionContext()
36003603
>>> df = ctx.from_pydict({"a": ["hello,world"]})
36013604
>>> result = df.select(
36023605
... dfn.functions.string_to_array(
3603-
... dfn.col("a"), dfn.lit(","), dfn.lit(""),
3606+
... dfn.col("a"), dfn.lit(","),
36043607
... ).alias("result"))
36053608
>>> result.collect_column("result")[0].as_py()
36063609
['hello', 'world']
36073610
"""
3608-
return Expr(f.string_to_array(string.expr, delimiter.expr, null_string.expr))
3611+
null_expr = null_string.expr if null_string is not None else None
3612+
return Expr(f.string_to_array(string.expr, delimiter.expr, null_expr))
36093613

36103614

3611-
def string_to_list(string: Expr, delimiter: Expr, null_string: Expr) -> Expr:
3615+
def string_to_list(
3616+
string: Expr, delimiter: Expr, null_string: Expr | None = None
3617+
) -> Expr:
36123618
"""Splits a string based on a delimiter and returns an array of parts.
36133619
36143620
See Also:
@@ -3617,7 +3623,7 @@ def string_to_list(string: Expr, delimiter: Expr, null_string: Expr) -> Expr:
36173623
return string_to_array(string, delimiter, null_string)
36183624

36193625

3620-
def gen_series(start: Expr, stop: Expr, step: Expr) -> Expr:
3626+
def gen_series(start: Expr, stop: Expr, step: Expr | None = None) -> Expr:
36213627
"""Creates a list of values in the range between start and stop.
36223628
36233629
Unlike :py:func:`range`, this includes the upper bound.
@@ -3627,15 +3633,16 @@ def gen_series(start: Expr, stop: Expr, step: Expr) -> Expr:
36273633
>>> df = ctx.from_pydict({"a": [0]})
36283634
>>> result = df.select(
36293635
... dfn.functions.gen_series(
3630-
... dfn.lit(1), dfn.lit(5), dfn.lit(1),
3636+
... dfn.lit(1), dfn.lit(5),
36313637
... ).alias("result"))
36323638
>>> result.collect_column("result")[0].as_py()
36333639
[1, 2, 3, 4, 5]
36343640
"""
3635-
return Expr(f.gen_series(start.expr, stop.expr, step.expr))
3641+
step_expr = step.expr if step is not None else None
3642+
return Expr(f.gen_series(start.expr, stop.expr, step_expr))
36363643

36373644

3638-
def generate_series(start: Expr, stop: Expr, step: Expr) -> Expr:
3645+
def generate_series(start: Expr, stop: Expr, step: Expr | None = None) -> Expr:
36393646
"""Creates a list of values in the range between start and stop.
36403647
36413648
Unlike :py:func:`range`, this includes the upper bound.

python/tests/test_functions.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,35 +1567,48 @@ def test_string_to_array():
15671567
ctx = SessionContext()
15681568
df = ctx.from_pydict({"a": ["hello,world,foo"]})
15691569
result = df.select(
1570-
f.string_to_array(column("a"), literal(","), literal("")).alias("v")
1570+
f.string_to_array(column("a"), literal(",")).alias("v")
15711571
).collect()
15721572
assert result[0].column(0)[0].as_py() == ["hello", "world", "foo"]
15731573

15741574

1575-
def test_string_to_list():
1575+
def test_string_to_array_with_null_string():
15761576
ctx = SessionContext()
1577-
df = ctx.from_pydict({"a": ["a-b-c"]})
1577+
df = ctx.from_pydict({"a": ["hello,NA,world"]})
15781578
result = df.select(
1579-
f.string_to_list(column("a"), literal("-"), literal("")).alias("v")
1579+
f.string_to_array(column("a"), literal(","), literal("NA")).alias("v")
15801580
).collect()
1581+
values = result[0].column(0)[0].as_py()
1582+
assert values == ["hello", None, "world"]
1583+
1584+
1585+
def test_string_to_list():
1586+
ctx = SessionContext()
1587+
df = ctx.from_pydict({"a": ["a-b-c"]})
1588+
result = df.select(f.string_to_list(column("a"), literal("-")).alias("v")).collect()
15811589
assert result[0].column(0)[0].as_py() == ["a", "b", "c"]
15821590

15831591

15841592
def test_gen_series():
15851593
ctx = SessionContext()
15861594
df = ctx.from_pydict({"a": [0]})
1587-
result = df.select(
1588-
f.gen_series(literal(1), literal(5), literal(1)).alias("v")
1589-
).collect()
1595+
result = df.select(f.gen_series(literal(1), literal(5)).alias("v")).collect()
15901596
assert result[0].column(0)[0].as_py() == [1, 2, 3, 4, 5]
15911597

15921598

1593-
def test_generate_series():
1599+
def test_gen_series_with_step():
15941600
ctx = SessionContext()
15951601
df = ctx.from_pydict({"a": [0]})
15961602
result = df.select(
1597-
f.generate_series(literal(1), literal(3), literal(1)).alias("v")
1603+
f.gen_series(literal(1), literal(10), literal(3)).alias("v")
15981604
).collect()
1605+
assert result[0].column(0)[0].as_py() == [1, 4, 7, 10]
1606+
1607+
1608+
def test_generate_series():
1609+
ctx = SessionContext()
1610+
df = ctx.from_pydict({"a": [0]})
1611+
result = df.select(f.generate_series(literal(1), literal(3)).alias("v")).collect()
15991612
assert result[0].column(0)[0].as_py() == [1, 2, 3]
16001613

16011614

0 commit comments

Comments
 (0)