Skip to content

Commit 9a1e077

Browse files
timsaucerclaude
andcommitted
feat: expose array_compact, array_normalize, cosine_distance, inner_product
Adds Python bindings for four scalar functions from datafusion::functions_nested::expr_fn that were not previously surfaced: - array_compact / list_compact: drop NULLs from an array. - array_normalize / list_normalize: L2-normalize a numeric array. - cosine_distance: 1 - cosine_similarity(a, b). - inner_product: dot product of two numeric arrays. Implementation routes each through the existing array_fn! macro in crates/core/src/functions.rs, mirroring the other functions_nested wrappers. Python wrappers in python/datafusion/functions.py follow the established pattern with doctest examples; list_* aliases use the one-line + See Also form per project convention. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent d021e6a commit 9a1e077

2 files changed

Lines changed: 104 additions & 0 deletions

File tree

crates/core/src/functions.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,10 @@ array_fn!(array_replace, array from to);
654654
array_fn!(array_replace_n, array from to max);
655655
array_fn!(array_replace_all, array from to);
656656
array_fn!(array_sort, array desc null_first);
657+
array_fn!(array_compact, array);
658+
array_fn!(array_normalize, array);
659+
array_fn!(cosine_distance, array1 array2);
660+
array_fn!(inner_product, array1 array2);
657661
array_fn!(array_intersect, first_array second_array);
658662
array_fn!(array_union, array1 array2);
659663
array_fn!(array_except, first_array second_array);
@@ -1133,6 +1137,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
11331137
m.add_wrapped(wrap_pyfunction!(array_cat))?;
11341138
m.add_wrapped(wrap_pyfunction!(array_dims))?;
11351139
m.add_wrapped(wrap_pyfunction!(array_distinct))?;
1140+
m.add_wrapped(wrap_pyfunction!(array_compact))?;
1141+
m.add_wrapped(wrap_pyfunction!(array_normalize))?;
1142+
m.add_wrapped(wrap_pyfunction!(cosine_distance))?;
1143+
m.add_wrapped(wrap_pyfunction!(inner_product))?;
11361144
m.add_wrapped(wrap_pyfunction!(array_element))?;
11371145
m.add_wrapped(wrap_pyfunction!(array_empty))?;
11381146
m.add_wrapped(wrap_pyfunction!(array_length))?;

python/datafusion/functions.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
"array_any_value",
7777
"array_append",
7878
"array_cat",
79+
"array_compact",
7980
"array_concat",
8081
"array_contains",
8182
"array_dims",
@@ -96,6 +97,7 @@
9697
"array_max",
9798
"array_min",
9899
"array_ndims",
100+
"array_normalize",
99101
"array_pop_back",
100102
"array_pop_front",
101103
"array_position",
@@ -151,6 +153,7 @@
151153
"corr",
152154
"cos",
153155
"cosh",
156+
"cosine_distance",
154157
"cot",
155158
"count",
156159
"count_star",
@@ -192,6 +195,7 @@
192195
"ifnull",
193196
"in_list",
194197
"initcap",
198+
"inner_product",
195199
"instr",
196200
"isnan",
197201
"iszero",
@@ -209,6 +213,7 @@
209213
"list_any_value",
210214
"list_append",
211215
"list_cat",
216+
"list_compact",
212217
"list_concat",
213218
"list_contains",
214219
"list_dims",
@@ -229,6 +234,7 @@
229234
"list_max",
230235
"list_min",
231236
"list_ndims",
237+
"list_normalize",
232238
"list_overlap",
233239
"list_pop_back",
234240
"list_pop_front",
@@ -3204,6 +3210,78 @@ def array_distinct(array: Expr) -> Expr:
32043210
return Expr(f.array_distinct(array.expr))
32053211

32063212

3213+
def array_compact(array: Expr) -> Expr:
3214+
"""Removes NULL values from the array.
3215+
3216+
Examples:
3217+
>>> ctx = dfn.SessionContext()
3218+
>>> df = ctx.from_pydict({"a": [[1, None, 2, None, 3]]})
3219+
>>> result = df.select(
3220+
... dfn.functions.array_compact(dfn.col("a")).alias("result")
3221+
... )
3222+
>>> result.collect_column("result")[0].as_py()
3223+
[1, 2, 3]
3224+
"""
3225+
return Expr(f.array_compact(array.expr))
3226+
3227+
3228+
def array_normalize(array: Expr) -> Expr:
3229+
"""Returns the L2-normalized vector for a numeric array.
3230+
3231+
Examples:
3232+
>>> ctx = dfn.SessionContext()
3233+
>>> df = ctx.from_pydict({"a": [[3.0, 4.0]]})
3234+
>>> result = df.select(
3235+
... dfn.functions.array_normalize(dfn.col("a")).alias("result")
3236+
... )
3237+
>>> result.collect_column("result")[0].as_py()
3238+
[0.6, 0.8]
3239+
"""
3240+
return Expr(f.array_normalize(array.expr))
3241+
3242+
3243+
def cosine_distance(array1: Expr, array2: Expr) -> Expr:
3244+
"""Returns the cosine distance between two numeric arrays.
3245+
3246+
Computed as ``1 - cosine_similarity(array1, array2)``.
3247+
3248+
Examples:
3249+
>>> ctx = dfn.SessionContext()
3250+
>>> df = ctx.from_pydict(
3251+
... {"a": [[1.0, 2.0, 3.0]], "b": [[1.0, 2.0, 3.0]]}
3252+
... )
3253+
>>> result = df.select(
3254+
... dfn.functions.cosine_distance(
3255+
... dfn.col("a"), dfn.col("b")
3256+
... ).alias("result")
3257+
... )
3258+
>>> result.collect_column("result")[0].as_py()
3259+
0.0
3260+
"""
3261+
return Expr(f.cosine_distance(array1.expr, array2.expr))
3262+
3263+
3264+
def inner_product(array1: Expr, array2: Expr) -> Expr:
3265+
"""Returns the inner (dot) product of two numeric arrays.
3266+
3267+
The SQL name ``dot_product`` is an alias for this function in raw SQL.
3268+
3269+
Examples:
3270+
>>> ctx = dfn.SessionContext()
3271+
>>> df = ctx.from_pydict(
3272+
... {"a": [[1.0, 2.0, 3.0]], "b": [[4.0, 5.0, 6.0]]}
3273+
... )
3274+
>>> result = df.select(
3275+
... dfn.functions.inner_product(
3276+
... dfn.col("a"), dfn.col("b")
3277+
... ).alias("result")
3278+
... )
3279+
>>> result.collect_column("result")[0].as_py()
3280+
32.0
3281+
"""
3282+
return Expr(f.inner_product(array1.expr, array2.expr))
3283+
3284+
32073285
def list_cat(*args: Expr) -> Expr:
32083286
"""Concatenates the input arrays.
32093287
@@ -3231,6 +3309,24 @@ def list_distinct(array: Expr) -> Expr:
32313309
return array_distinct(array)
32323310

32333311

3312+
def list_compact(array: Expr) -> Expr:
3313+
"""Removes NULL values from the array.
3314+
3315+
See Also:
3316+
This is an alias for :py:func:`array_compact`.
3317+
"""
3318+
return array_compact(array)
3319+
3320+
3321+
def list_normalize(array: Expr) -> Expr:
3322+
"""Returns the L2-normalized vector for a numeric array.
3323+
3324+
See Also:
3325+
This is an alias for :py:func:`array_normalize`.
3326+
"""
3327+
return array_normalize(array)
3328+
3329+
32343330
def list_dims(array: Expr) -> Expr:
32353331
"""Returns an array of the array's dimensions.
32363332

0 commit comments

Comments
 (0)