Skip to content

Commit d265045

Browse files
committed
add any_match higher-order function
1 parent 720aaff commit d265045

3 files changed

Lines changed: 422 additions & 1 deletion

File tree

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`HigherOrderUDF`] definitions for array_any_match function.
19+
20+
use arrow::{
21+
array::{Array, AsArray, BooleanArray, new_null_array},
22+
datatypes::{ArrowNativeType, DataType, Field, FieldRef},
23+
};
24+
use datafusion_common::{
25+
Result, exec_err, plan_err,
26+
utils::{list_values, take_function_args},
27+
};
28+
use datafusion_expr::{
29+
ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs,
30+
HigherOrderSignature, HigherOrderUDF, LambdaParametersProgress, ValueOrLambda,
31+
Volatility,
32+
};
33+
use datafusion_macros::user_doc;
34+
use std::{fmt::Debug, sync::Arc};
35+
36+
// Declares the public entry points for `ArrayAnyMatch`. Judging by how the
// rest of the crate refers to them, this presumably expands to the
// `array_any_match(array, lambda)` expression builder and the
// `array_any_match_higher_order_function()` accessor — confirm against the
// macro definition in `macros_lambda`.
make_higher_order_function_expr_and_func!(
    ArrayAnyMatch,
    array_any_match,
    array lambda,
    "returns true if any element in the array satisfies the predicate",
    array_any_match_higher_order_function
);
43+
44+
#[user_doc(
45+
doc_section(label = "Array Functions"),
46+
description = "Returns whether any elements of an array match the given predicate. Returns true if one or more elements match, false if none match (including empty arrays), and null if the predicate returns null for some elements and false for all others.",
47+
syntax_example = "any_match(array, predicate)",
48+
sql_example = r#"```sql
49+
> select any_match([1, 2, 3], x -> x > 2);
50+
+----------------------------------+
51+
| any_match([1, 2, 3], x -> x > 2) |
52+
+----------------------------------+
53+
| true |
54+
+----------------------------------+
55+
```"#,
56+
argument(
57+
name = "array",
58+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
59+
),
60+
argument(
61+
name = "predicate",
62+
description = "Lambda predicate that returns a boolean"
63+
)
64+
)]
65+
#[derive(Debug, PartialEq, Eq, Hash)]
66+
pub struct ArrayAnyMatch {
67+
signature: HigherOrderSignature,
68+
aliases: Vec<String>,
69+
}
70+
71+
impl Default for ArrayAnyMatch {
72+
fn default() -> Self {
73+
Self::new()
74+
}
75+
}
76+
77+
impl ArrayAnyMatch {
78+
pub fn new() -> Self {
79+
Self {
80+
signature: HigherOrderSignature::user_defined(Volatility::Immutable),
81+
aliases: vec![String::from("any_match"), String::from("list_any_match")],
82+
}
83+
}
84+
}
85+
86+
// Returns Some(true) if any element in [start, end) is true,
87+
// None if no element is true but some are null,
88+
// Some(false) if all are false or range is empty.
89+
fn any_match_for_range(
90+
predicate: &BooleanArray,
91+
start: usize,
92+
end: usize,
93+
) -> Option<bool> {
94+
let any_true = (start..end).any(|j| predicate.is_valid(j) && predicate.value(j));
95+
if any_true {
96+
return Some(true);
97+
}
98+
let any_null = (start..end).any(|j| predicate.is_null(j));
99+
if any_null { None } else { Some(false) }
100+
}
101+
102+
impl HigherOrderUDF for ArrayAnyMatch {
103+
fn name(&self) -> &str {
104+
"array_any_match"
105+
}
106+
107+
fn aliases(&self) -> &[String] {
108+
&self.aliases
109+
}
110+
111+
fn signature(&self) -> &HigherOrderSignature {
112+
&self.signature
113+
}
114+
115+
fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
116+
let list = if arg_types.len() == 1 {
117+
&arg_types[0]
118+
} else {
119+
return plan_err!(
120+
"{} function requires 1 value argument, got {}",
121+
self.name(),
122+
arg_types.len()
123+
);
124+
};
125+
126+
let coerced = match list {
127+
DataType::List(_) | DataType::LargeList(_) => list.clone(),
128+
DataType::ListView(field) | DataType::FixedSizeList(field, _) => {
129+
DataType::List(Arc::clone(field))
130+
}
131+
DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)),
132+
_ => {
133+
return plan_err!(
134+
"{} expected a list as first argument, got {}",
135+
self.name(),
136+
list
137+
);
138+
}
139+
};
140+
141+
Ok(vec![coerced])
142+
}
143+
144+
fn lambda_parameters(
145+
&self,
146+
_step: usize,
147+
fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
148+
) -> Result<LambdaParametersProgress> {
149+
let [list, _lambda] = take_function_args(self.name(), fields)?;
150+
151+
let field = match list {
152+
ValueOrLambda::Value(f) => match f.data_type() {
153+
DataType::List(field) => field,
154+
DataType::LargeList(field) => field,
155+
other => return plan_err!("expected list, got {other}"),
156+
},
157+
_ => return plan_err!("{} expected a value as first argument", self.name()),
158+
};
159+
160+
Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone(
161+
field,
162+
)]]))
163+
}
164+
165+
fn return_field_from_args(
166+
&self,
167+
args: HigherOrderReturnFieldArgs,
168+
) -> Result<Arc<Field>> {
169+
let [list, _lambda] = take_function_args(self.name(), args.arg_fields)?;
170+
let nullable = matches!(list, ValueOrLambda::Value(f) if f.is_nullable());
171+
Ok(Arc::new(Field::new("", DataType::Boolean, nullable)))
172+
}
173+
174+
fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
175+
let [list, lambda] = take_function_args(self.name(), &args.args)?;
176+
177+
let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
178+
else {
179+
return exec_err!("{} expects a value followed by a lambda", self.name());
180+
};
181+
182+
let list_array = list.to_array(args.number_rows)?;
183+
184+
// fast path: fully null input — also required for FixedSizeList which can't be
185+
// handled by clear_null_values when fully null
186+
if list_array.null_count() == list_array.len() {
187+
return Ok(ColumnarValue::Array(new_null_array(
188+
args.return_type(),
189+
list_array.len(),
190+
)));
191+
}
192+
193+
let list_values = list_values(&list_array)?;
194+
195+
let values_param = || Ok(Arc::clone(&list_values));
196+
197+
let predicate_results = lambda
198+
.evaluate(&[&values_param])?
199+
.into_array(list_values.len())?;
200+
201+
let predicate_bool = predicate_results
202+
.as_any()
203+
.downcast_ref::<BooleanArray>()
204+
.ok_or_else(|| {
205+
datafusion_common::DataFusionError::Execution(format!(
206+
"{} predicate must return boolean array",
207+
self.name()
208+
))
209+
})?;
210+
211+
let mut results = Vec::with_capacity(list_array.len());
212+
213+
macro_rules! process_list {
214+
($list_typed:expr) => {{
215+
let offsets = $list_typed.offsets();
216+
for i in 0..$list_typed.len() {
217+
if $list_typed.is_null(i) {
218+
results.push(None);
219+
continue;
220+
}
221+
let start = offsets[i].as_usize();
222+
let end = offsets[i + 1].as_usize();
223+
results.push(any_match_for_range(predicate_bool, start, end));
224+
}
225+
}};
226+
}
227+
228+
match list_array.data_type() {
229+
DataType::List(_) => {
230+
process_list!(list_array.as_list::<i32>());
231+
}
232+
DataType::LargeList(_) => {
233+
process_list!(list_array.as_list::<i64>());
234+
}
235+
other => return exec_err!("expected list, got {other}"),
236+
}
237+
238+
Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(results))))
239+
}
240+
241+
fn documentation(&self) -> Option<&Documentation> {
242+
self.doc()
243+
}
244+
}
245+
246+
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow::{
        array::{ArrayRef, BooleanArray, Int32Array, ListArray, RecordBatch},
        buffer::{NullBuffer, OffsetBuffer},
        datatypes::{DataType, Field},
    };
    use datafusion_common::{DFSchema, Result};
    use datafusion_expr::{
        Expr, col,
        execution_props::ExecutionProps,
        expr::{HigherOrderFunction, LambdaVariable},
        lambda, lit,
    };
    use datafusion_physical_expr::create_physical_expr;

    use crate::any_match::array_any_match_higher_order_function;

    /// Evaluates `array_any_match(list, x -> x > 2)` against a single-column
    /// batch holding `list` and returns the resulting array.
    fn run_any_match(
        list: impl arrow::array::Array + Clone + 'static,
    ) -> Result<ArrayRef> {
        let schema = DFSchema::from_unqualified_fields(
            vec![Field::new(
                "list",
                list.data_type().clone(),
                list.is_nullable(),
            )]
            .into(),
            HashMap::new(),
        )?;

        create_physical_expr(
            &Expr::HigherOrderFunction(HigherOrderFunction::new(
                array_any_match_higher_order_function(),
                vec![
                    col("list"),
                    lambda(
                        ["x"],
                        Expr::LambdaVariable(LambdaVariable::new(
                            "x".to_string(),
                            Some(Arc::new(Field::new("x", DataType::Int32, true))),
                        ))
                        .gt(lit(2i32)),
                    ),
                ],
            )),
            &schema,
            &ExecutionProps::new(),
        )?
        .evaluate(&RecordBatch::try_new(
            Arc::clone(schema.inner()),
            vec![Arc::new(list.clone())],
        )?)?
        .into_array(list.len())
    }

    /// Builds a list array without null rows or null elements.
    fn make_list(values: Vec<i32>, offsets: OffsetBuffer<i32>) -> ListArray {
        ListArray::new(
            Arc::new(Field::new_list_field(DataType::Int32, true)),
            offsets,
            Arc::new(Int32Array::from(values)),
            None,
        )
    }

    /// Runs the function on `list` and compares the output with `expected`.
    fn assert_any_match(list: ListArray, expected: Vec<Option<bool>>) -> Result<()> {
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(expected)
        );
        Ok(())
    }

    #[test]
    fn test_any_match_some_true() -> Result<()> {
        let list = make_list(vec![1, 2, 3], OffsetBuffer::from_lengths(vec![3]));
        assert_any_match(list, vec![Some(true)])
    }

    #[test]
    fn test_any_match_none_true() -> Result<()> {
        let list = make_list(vec![1, 2], OffsetBuffer::from_lengths(vec![2]));
        assert_any_match(list, vec![Some(false)])
    }

    #[test]
    fn test_any_match_empty_array() -> Result<()> {
        let list = make_list(vec![], OffsetBuffer::from_lengths(vec![0]));
        assert_any_match(list, vec![Some(false)])
    }

    #[test]
    fn test_any_match_multiple_rows() -> Result<()> {
        let list = make_list(vec![1, 2, 3, 1, 2], OffsetBuffer::from_lengths(vec![3, 2]));
        assert_any_match(list, vec![Some(true), Some(false)])
    }

    #[test]
    fn test_any_match_null_semantics() -> Result<()> {
        // Row 0: [1, NULL] -> no true value but one NULL predicate -> NULL.
        // Row 1: [3, NULL] -> a true value wins over the NULL       -> true.
        // Row 2: NULL list (its hidden value must be ignored)        -> NULL.
        let list = ListArray::new(
            Arc::new(Field::new_list_field(DataType::Int32, true)),
            OffsetBuffer::from_lengths(vec![2, 2, 1]),
            Arc::new(Int32Array::from(vec![
                Some(1),
                None,
                Some(3),
                None,
                Some(7),
            ])),
            Some(NullBuffer::from(vec![true, true, false])),
        );
        assert_any_match(list, vec![None, Some(true), None])
    }
}

datafusion/functions-nested/src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub mod macros;
4040
#[macro_use]
4141
pub mod macros_lambda;
4242

43+
pub mod any_match;
4344
pub mod array_compact;
4445
pub mod array_has;
4546
pub mod array_transform;
@@ -83,6 +84,7 @@ use std::sync::Arc;
8384

8485
/// Fluent-style API for creating `Expr`s
8586
pub mod expr_fn {
87+
pub use super::any_match::array_any_match;
8688
pub use super::array_compact::array_compact;
8789
pub use super::array_has::array_has;
8890
pub use super::array_has::array_has_all;
@@ -190,7 +192,10 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
190192
}
191193

192194
pub fn all_default_higher_order_functions() -> Vec<Arc<dyn HigherOrderUDF>> {
193-
vec![array_transform::array_transform_higher_order_function()]
195+
vec![
196+
any_match::array_any_match_higher_order_function(),
197+
array_transform::array_transform_higher_order_function(),
198+
]
194199
}
195200

196201
/// Registers all enabled packages with a [`FunctionRegistry`]

0 commit comments

Comments
 (0)