Skip to content

Commit d694abd

Browse files
fix: balance list_contains OR tree for large IN-list filters (#37)
1 parent 8a09bc7 commit d694abd

1 file changed

Lines changed: 70 additions & 7 deletions

File tree

  • vortex-array/src/scalar_fn/fns/list_contains

vortex-array/src/scalar_fn/fns/list_contains/mod.rs

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,9 @@ fn constant_list_scalar_contains(
239239
let elements = list_scalar.elements().vortex_expect("non null");
240240

241241
let len = values.len();
242-
let mut result: Option<ArrayRef> = None;
243242
let false_scalar = Scalar::bool(false, nullability);
243+
let values = values.to_array();
244+
let mut partials = Vec::with_capacity(elements.len());
244245

245246
for element in elements {
246247
let res = Binary
@@ -249,17 +250,39 @@ fn constant_list_scalar_contains(
249250
Operator::Eq,
250251
[
251252
ConstantArray::new(element, len).into_array(),
252-
values.to_array(),
253+
values.clone(),
253254
],
254255
)?
255256
.fill_null(false_scalar.clone())?;
256-
if let Some(acc) = result {
257-
result = Some(acc.binary(res, Operator::Or)?)
258-
} else {
259-
result = Some(res);
257+
partials.push(res);
258+
}
259+
260+
if partials.is_empty() {
261+
return Ok(ConstantArray::new(false_scalar, len).to_array());
262+
}
263+
264+
or_arrays_balanced(partials)
265+
}
266+
267+
fn or_arrays_balanced(mut arrays: Vec<ArrayRef>) -> VortexResult<ArrayRef> {
268+
debug_assert!(!arrays.is_empty());
269+
270+
while arrays.len() > 1 {
271+
let mut next = Vec::with_capacity(arrays.len().div_ceil(2));
272+
let mut i = 0;
273+
while i + 1 < arrays.len() {
274+
next.push(arrays[i].binary(arrays[i + 1].clone(), Operator::Or)?);
275+
i += 2;
260276
}
277+
if i < arrays.len() {
278+
next.push(arrays[i].clone());
279+
}
280+
arrays = next;
261281
}
262-
Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
282+
283+
Ok(arrays
284+
.pop()
285+
.expect("or_arrays_balanced must be called with at least one array"))
263286
}
264287

265288
/// Returns a [`BoolArray`] where each bit represents if a list contains the scalar.
@@ -429,6 +452,8 @@ fn list_is_not_empty(
429452
mod tests {
430453
use std::sync::Arc;
431454

455+
use super::or_arrays_balanced;
456+
432457
use itertools::Itertools;
433458
use rstest::rstest;
434459
use vortex_buffer::BitBuffer;
@@ -789,6 +814,44 @@ mod tests {
789814
assert_arrays_eq!(contains, expected);
790815
}
791816

817+
fn array_depth(array: &dyn Array) -> usize {
818+
1 + (0..array.nchildren())
819+
.filter_map(|idx| array.nth_child(idx))
820+
.map(|child| array_depth(child.as_ref()))
821+
.max()
822+
.unwrap_or(0)
823+
}
824+
825+
#[test]
fn test_or_arrays_balanced_depth() {
    // Five inputs collapse over three pairing rounds (5 -> 3 -> 2 -> 1), so the
    // expected depth is three OR levels plus the leaf level.
    let rows: [[bool; 2]; 5] = [
        [true, false],
        [false, true],
        [false, false],
        [true, true],
        [true, false],
    ];
    let inputs: Vec<_> = rows
        .into_iter()
        .map(|bits| BoolArray::from_iter(bits).into_array())
        .collect();

    let combined = or_arrays_balanced(inputs).unwrap();

    assert_eq!(array_depth(combined.as_ref()), 4);
}
838+
839+
#[test]
fn test_constant_list_large_regression() {
    // Regression check for a large constant list: every value in 0..2048 is
    // looked up in a list holding exactly 0..2048, so each row must be `true`.
    let needles = Scalar::list(
        Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
        (0i32..2048).map(Into::into).collect(),
        Nullability::NonNullable,
    );
    let membership = list_contains(lit(needles), root());

    let haystack = PrimitiveArray::from_iter(0i32..2048).into_array();
    let actual = haystack.apply(&membership).unwrap();

    let all_true = BoolArray::from_iter(std::iter::repeat_n(true, 2048));
    assert_arrays_eq!(actual, all_true);
}
854+
792855
#[test]
793856
fn test_all_nulls() {
794857
let list_array = ConstantArray::new(

0 commit comments

Comments
 (0)