|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +//! [`HigherOrderUDF`] definitions for array_any_match function. |
| 19 | +
|
| 20 | +use arrow::{ |
| 21 | + array::{Array, AsArray, BooleanArray, new_null_array}, |
| 22 | + datatypes::{ArrowNativeType, DataType, Field, FieldRef}, |
| 23 | +}; |
| 24 | +use datafusion_common::{ |
| 25 | + Result, exec_err, plan_err, |
| 26 | + utils::{list_values, take_function_args}, |
| 27 | +}; |
| 28 | +use datafusion_expr::{ |
| 29 | + ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, |
| 30 | + HigherOrderSignature, HigherOrderUDF, LambdaParametersProgress, ValueOrLambda, |
| 31 | + Volatility, |
| 32 | +}; |
| 33 | +use datafusion_macros::user_doc; |
| 34 | +use std::{fmt::Debug, sync::Arc}; |
| 35 | + |
// Invokes `make_higher_order_function_expr_and_func!` for `ArrayAnyMatch`.
// The macro is defined outside this file; presumably it generates the
// `array_any_match(array, lambda)` expression builder and the
// `array_any_match_higher_order_function()` accessor that the tests at the
// bottom of this file import — confirm against the macro definition.
make_higher_order_function_expr_and_func!(
    ArrayAnyMatch,
    array_any_match,
    array lambda,
    "returns true if any element in the array satisfies the predicate",
    array_any_match_higher_order_function
);
| 43 | + |
| 44 | +#[user_doc( |
| 45 | + doc_section(label = "Array Functions"), |
| 46 | + description = "Returns whether any elements of an array match the given predicate. Returns true if one or more elements match, false if none match (including empty arrays), and null if the predicate returns null for some elements and false for all others.", |
| 47 | + syntax_example = "any_match(array, predicate)", |
| 48 | + sql_example = r#"```sql |
| 49 | +> select any_match([1, 2, 3], x -> x > 2); |
| 50 | ++----------------------------------+ |
| 51 | +| any_match([1, 2, 3], x -> x > 2) | |
| 52 | ++----------------------------------+ |
| 53 | +| true | |
| 54 | ++----------------------------------+ |
| 55 | +```"#, |
| 56 | + argument( |
| 57 | + name = "array", |
| 58 | + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." |
| 59 | + ), |
| 60 | + argument( |
| 61 | + name = "predicate", |
| 62 | + description = "Lambda predicate that returns a boolean" |
| 63 | + ) |
| 64 | +)] |
/// Higher-order function that reports whether any element of a list satisfies
/// a lambda predicate.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayAnyMatch {
    /// Signature returned by `HigherOrderUDF::signature`.
    signature: HigherOrderSignature,
    /// Alternative names returned by `HigherOrderUDF::aliases`.
    aliases: Vec<String>,
}
| 70 | + |
| 71 | +impl Default for ArrayAnyMatch { |
| 72 | + fn default() -> Self { |
| 73 | + Self::new() |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +impl ArrayAnyMatch { |
| 78 | + pub fn new() -> Self { |
| 79 | + Self { |
| 80 | + signature: HigherOrderSignature::user_defined(Volatility::Immutable), |
| 81 | + aliases: vec![String::from("any_match"), String::from("list_any_match")], |
| 82 | + } |
| 83 | + } |
| 84 | +} |
| 85 | + |
| 86 | +// Returns Some(true) if any element in [start, end) is true, |
| 87 | +// None if no element is true but some are null, |
| 88 | +// Some(false) if all are false or range is empty. |
| 89 | +fn any_match_for_range( |
| 90 | + predicate: &BooleanArray, |
| 91 | + start: usize, |
| 92 | + end: usize, |
| 93 | +) -> Option<bool> { |
| 94 | + let any_true = (start..end).any(|j| predicate.is_valid(j) && predicate.value(j)); |
| 95 | + if any_true { |
| 96 | + return Some(true); |
| 97 | + } |
| 98 | + let any_null = (start..end).any(|j| predicate.is_null(j)); |
| 99 | + if any_null { None } else { Some(false) } |
| 100 | +} |
| 101 | + |
| 102 | +impl HigherOrderUDF for ArrayAnyMatch { |
| 103 | + fn name(&self) -> &str { |
| 104 | + "array_any_match" |
| 105 | + } |
| 106 | + |
| 107 | + fn aliases(&self) -> &[String] { |
| 108 | + &self.aliases |
| 109 | + } |
| 110 | + |
| 111 | + fn signature(&self) -> &HigherOrderSignature { |
| 112 | + &self.signature |
| 113 | + } |
| 114 | + |
| 115 | + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
| 116 | + let list = if arg_types.len() == 1 { |
| 117 | + &arg_types[0] |
| 118 | + } else { |
| 119 | + return plan_err!( |
| 120 | + "{} function requires 1 value argument, got {}", |
| 121 | + self.name(), |
| 122 | + arg_types.len() |
| 123 | + ); |
| 124 | + }; |
| 125 | + |
| 126 | + let coerced = match list { |
| 127 | + DataType::List(_) | DataType::LargeList(_) => list.clone(), |
| 128 | + DataType::ListView(field) | DataType::FixedSizeList(field, _) => { |
| 129 | + DataType::List(Arc::clone(field)) |
| 130 | + } |
| 131 | + DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)), |
| 132 | + _ => { |
| 133 | + return plan_err!( |
| 134 | + "{} expected a list as first argument, got {}", |
| 135 | + self.name(), |
| 136 | + list |
| 137 | + ); |
| 138 | + } |
| 139 | + }; |
| 140 | + |
| 141 | + Ok(vec![coerced]) |
| 142 | + } |
| 143 | + |
| 144 | + fn lambda_parameters( |
| 145 | + &self, |
| 146 | + _step: usize, |
| 147 | + fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>], |
| 148 | + ) -> Result<LambdaParametersProgress> { |
| 149 | + let [list, _lambda] = take_function_args(self.name(), fields)?; |
| 150 | + |
| 151 | + let field = match list { |
| 152 | + ValueOrLambda::Value(f) => match f.data_type() { |
| 153 | + DataType::List(field) => field, |
| 154 | + DataType::LargeList(field) => field, |
| 155 | + other => return plan_err!("expected list, got {other}"), |
| 156 | + }, |
| 157 | + _ => return plan_err!("{} expected a value as first argument", self.name()), |
| 158 | + }; |
| 159 | + |
| 160 | + Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone( |
| 161 | + field, |
| 162 | + )]])) |
| 163 | + } |
| 164 | + |
| 165 | + fn return_field_from_args( |
| 166 | + &self, |
| 167 | + args: HigherOrderReturnFieldArgs, |
| 168 | + ) -> Result<Arc<Field>> { |
| 169 | + let [list, _lambda] = take_function_args(self.name(), args.arg_fields)?; |
| 170 | + let nullable = matches!(list, ValueOrLambda::Value(f) if f.is_nullable()); |
| 171 | + Ok(Arc::new(Field::new("", DataType::Boolean, nullable))) |
| 172 | + } |
| 173 | + |
| 174 | + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> { |
| 175 | + let [list, lambda] = take_function_args(self.name(), &args.args)?; |
| 176 | + |
| 177 | + let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda) |
| 178 | + else { |
| 179 | + return exec_err!("{} expects a value followed by a lambda", self.name()); |
| 180 | + }; |
| 181 | + |
| 182 | + let list_array = list.to_array(args.number_rows)?; |
| 183 | + |
| 184 | + // fast path: fully null input — also required for FixedSizeList which can't be |
| 185 | + // handled by clear_null_values when fully null |
| 186 | + if list_array.null_count() == list_array.len() { |
| 187 | + return Ok(ColumnarValue::Array(new_null_array( |
| 188 | + args.return_type(), |
| 189 | + list_array.len(), |
| 190 | + ))); |
| 191 | + } |
| 192 | + |
| 193 | + let list_values = list_values(&list_array)?; |
| 194 | + |
| 195 | + let values_param = || Ok(Arc::clone(&list_values)); |
| 196 | + |
| 197 | + let predicate_results = lambda |
| 198 | + .evaluate(&[&values_param])? |
| 199 | + .into_array(list_values.len())?; |
| 200 | + |
| 201 | + let predicate_bool = predicate_results |
| 202 | + .as_any() |
| 203 | + .downcast_ref::<BooleanArray>() |
| 204 | + .ok_or_else(|| { |
| 205 | + datafusion_common::DataFusionError::Execution(format!( |
| 206 | + "{} predicate must return boolean array", |
| 207 | + self.name() |
| 208 | + )) |
| 209 | + })?; |
| 210 | + |
| 211 | + let mut results = Vec::with_capacity(list_array.len()); |
| 212 | + |
| 213 | + macro_rules! process_list { |
| 214 | + ($list_typed:expr) => {{ |
| 215 | + let offsets = $list_typed.offsets(); |
| 216 | + for i in 0..$list_typed.len() { |
| 217 | + if $list_typed.is_null(i) { |
| 218 | + results.push(None); |
| 219 | + continue; |
| 220 | + } |
| 221 | + let start = offsets[i].as_usize(); |
| 222 | + let end = offsets[i + 1].as_usize(); |
| 223 | + results.push(any_match_for_range(predicate_bool, start, end)); |
| 224 | + } |
| 225 | + }}; |
| 226 | + } |
| 227 | + |
| 228 | + match list_array.data_type() { |
| 229 | + DataType::List(_) => { |
| 230 | + process_list!(list_array.as_list::<i32>()); |
| 231 | + } |
| 232 | + DataType::LargeList(_) => { |
| 233 | + process_list!(list_array.as_list::<i64>()); |
| 234 | + } |
| 235 | + other => return exec_err!("expected list, got {other}"), |
| 236 | + } |
| 237 | + |
| 238 | + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(results)))) |
| 239 | + } |
| 240 | + |
| 241 | + fn documentation(&self) -> Option<&Documentation> { |
| 242 | + self.doc() |
| 243 | + } |
| 244 | +} |
| 245 | + |
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow::{
        array::{ArrayRef, BooleanArray, Int32Array, ListArray, RecordBatch},
        buffer::{NullBuffer, OffsetBuffer},
        datatypes::{DataType, Field},
    };
    use datafusion_common::{DFSchema, Result};
    use datafusion_expr::{
        Expr, col,
        execution_props::ExecutionProps,
        expr::{HigherOrderFunction, LambdaVariable},
        lambda, lit,
    };
    use datafusion_physical_expr::create_physical_expr;

    use crate::any_match::array_any_match_higher_order_function;

    /// Evaluates `array_any_match(list, x -> x > 2)` against a single-column
    /// batch holding `list` and returns the resulting array.
    fn run_any_match(
        list: impl arrow::array::Array + Clone + 'static,
    ) -> Result<ArrayRef> {
        let schema = DFSchema::from_unqualified_fields(
            vec![Field::new(
                "list",
                list.data_type().clone(),
                list.is_nullable(),
            )]
            .into(),
            HashMap::new(),
        )?;

        create_physical_expr(
            &Expr::HigherOrderFunction(HigherOrderFunction::new(
                array_any_match_higher_order_function(),
                vec![
                    col("list"),
                    lambda(
                        ["x"],
                        Expr::LambdaVariable(LambdaVariable::new(
                            "x".to_string(),
                            Some(Arc::new(Field::new("x", DataType::Int32, true))),
                        ))
                        .gt(lit(2i32)),
                    ),
                ],
            )),
            &schema,
            &ExecutionProps::new(),
        )?
        .evaluate(&RecordBatch::try_new(
            Arc::clone(schema.inner()),
            vec![Arc::new(list.clone())],
        )?)?
        .into_array(list.len())
    }

    /// Builds a list array of `Int32` elements; `row_validity` marks null rows.
    fn make_nullable_list(
        values: Vec<Option<i32>>,
        offsets: OffsetBuffer<i32>,
        row_validity: Option<NullBuffer>,
    ) -> ListArray {
        ListArray::new(
            Arc::new(Field::new_list_field(DataType::Int32, true)),
            offsets,
            Arc::new(Int32Array::from(values)),
            row_validity,
        )
    }

    /// Builds a list array without null rows or null elements.
    fn make_list(values: Vec<i32>, offsets: OffsetBuffer<i32>) -> ListArray {
        make_nullable_list(values.into_iter().map(Some).collect(), offsets, None)
    }

    /// Runs the function over `list` and compares the output with `expected`.
    fn assert_any_match(list: ListArray, expected: Vec<Option<bool>>) -> Result<()> {
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(expected)
        );
        Ok(())
    }

    #[test]
    fn test_any_match_some_true() -> Result<()> {
        let list = make_list(vec![1, 2, 3], OffsetBuffer::from_lengths(vec![3]));
        assert_any_match(list, vec![Some(true)])
    }

    #[test]
    fn test_any_match_none_true() -> Result<()> {
        let list = make_list(vec![1, 2], OffsetBuffer::from_lengths(vec![2]));
        assert_any_match(list, vec![Some(false)])
    }

    #[test]
    fn test_any_match_empty_array() -> Result<()> {
        let list = make_list(vec![], OffsetBuffer::from_lengths(vec![0]));
        assert_any_match(list, vec![Some(false)])
    }

    #[test]
    fn test_any_match_multiple_rows() -> Result<()> {
        let list = make_list(vec![1, 2, 3, 1, 2], OffsetBuffer::from_lengths(vec![3, 2]));
        assert_any_match(list, vec![Some(true), Some(false)])
    }

    // A null element makes the predicate null; with no `true` element the
    // row result is null even though the list row itself is not null.
    #[test]
    fn test_any_match_null_predicate_result() -> Result<()> {
        let list = make_nullable_list(
            vec![Some(1), None],
            OffsetBuffer::from_lengths(vec![2]),
            None,
        );
        assert_any_match(list, vec![None])
    }

    // A `true` element takes precedence over null predicate results.
    #[test]
    fn test_any_match_true_wins_over_null() -> Result<()> {
        let list = make_nullable_list(
            vec![None, Some(3)],
            OffsetBuffer::from_lengths(vec![2]),
            None,
        );
        assert_any_match(list, vec![Some(true)])
    }

    // A null list row produces a null result while other rows are unaffected.
    #[test]
    fn test_any_match_null_list_row() -> Result<()> {
        let list = make_nullable_list(
            vec![Some(1), Some(2), Some(3)],
            OffsetBuffer::from_lengths(vec![3, 0]),
            Some(NullBuffer::from(vec![true, false])),
        );
        assert_any_match(list, vec![Some(true), None])
    }
}
0 commit comments