@@ -239,8 +239,9 @@ fn constant_list_scalar_contains(
239239 let elements = list_scalar. elements ( ) . vortex_expect ( "non null" ) ;
240240
241241 let len = values. len ( ) ;
242- let mut result: Option < ArrayRef > = None ;
243242 let false_scalar = Scalar :: bool ( false , nullability) ;
243+ let values = values. to_array ( ) ;
244+ let mut partials = Vec :: with_capacity ( elements. len ( ) ) ;
244245
245246 for element in elements {
246247 let res = Binary
@@ -249,17 +250,39 @@ fn constant_list_scalar_contains(
249250 Operator :: Eq ,
250251 [
251252 ConstantArray :: new ( element, len) . into_array ( ) ,
252- values. to_array ( ) ,
253+ values. clone ( ) ,
253254 ] ,
254255 ) ?
255256 . fill_null ( false_scalar. clone ( ) ) ?;
256- if let Some ( acc) = result {
257- result = Some ( acc. binary ( res, Operator :: Or ) ?)
258- } else {
259- result = Some ( res) ;
257+ partials. push ( res) ;
258+ }
259+
260+ if partials. is_empty ( ) {
261+ return Ok ( ConstantArray :: new ( false_scalar, len) . to_array ( ) ) ;
262+ }
263+
264+ or_arrays_balanced ( partials)
265+ }
266+
267+ fn or_arrays_balanced ( mut arrays : Vec < ArrayRef > ) -> VortexResult < ArrayRef > {
268+ debug_assert ! ( !arrays. is_empty( ) ) ;
269+
270+ while arrays. len ( ) > 1 {
271+ let mut next = Vec :: with_capacity ( arrays. len ( ) . div_ceil ( 2 ) ) ;
272+ let mut i = 0 ;
273+ while i + 1 < arrays. len ( ) {
274+ next. push ( arrays[ i] . binary ( arrays[ i + 1 ] . clone ( ) , Operator :: Or ) ?) ;
275+ i += 2 ;
260276 }
277+ if i < arrays. len ( ) {
278+ next. push ( arrays[ i] . clone ( ) ) ;
279+ }
280+ arrays = next;
261281 }
262- Ok ( result. unwrap_or_else ( || ConstantArray :: new ( false_scalar, len) . to_array ( ) ) )
282+
283+ Ok ( arrays
284+ . pop ( )
285+ . expect ( "or_arrays_balanced must be called with at least one array" ) )
263286}
264287
265288/// Returns a [`BoolArray`] where each bit represents if a list contains the scalar.
@@ -429,6 +452,8 @@ fn list_is_not_empty(
429452mod tests {
430453 use std:: sync:: Arc ;
431454
455+ use super :: or_arrays_balanced;
456+
432457 use itertools:: Itertools ;
433458 use rstest:: rstest;
434459 use vortex_buffer:: BitBuffer ;
@@ -789,6 +814,44 @@ mod tests {
789814 assert_arrays_eq ! ( contains, expected) ;
790815 }
791816
817+ fn array_depth ( array : & dyn Array ) -> usize {
818+ 1 + ( 0 ..array. nchildren ( ) )
819+ . filter_map ( |idx| array. nth_child ( idx) )
820+ . map ( |child| array_depth ( child. as_ref ( ) ) )
821+ . max ( )
822+ . unwrap_or ( 0 )
823+ }
824+
/// ORing five leaf arrays must produce a tree of depth 4, as measured by `array_depth`.
#[test]
fn test_or_arrays_balanced_depth() {
    let rows = [
        [true, false],
        [false, true],
        [false, false],
        [true, true],
        [true, false],
    ];
    let inputs: Vec<_> = rows
        .into_iter()
        .map(|bits| BoolArray::from_iter(bits).into_array())
        .collect();

    let combined = or_arrays_balanced(inputs).unwrap();

    assert_eq!(array_depth(combined.as_ref()), 4);
}
838+
/// Regression test: a constant list of 2048 elements checked against the same
/// 2048 values must report every value as contained.
#[test]
fn test_constant_list_large_regression() {
    let candidates = 0i32..2048;

    // The probed values and the list literal are built from the same range.
    let probe = PrimitiveArray::from_iter(candidates.clone()).into_array();
    let haystack = Scalar::list(
        Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
        candidates.map(Into::into).collect(),
        Nullability::NonNullable,
    );

    let membership = probe
        .apply(&list_contains(lit(haystack), root()))
        .unwrap();

    let all_true = BoolArray::from_iter(std::iter::repeat_n(true, 2048));
    assert_arrays_eq!(membership, all_true);
}
854+
792855 #[ test]
793856 fn test_all_nulls ( ) {
794857 let list_array = ConstantArray :: new (
0 commit comments