@@ -6,16 +6,17 @@ use vortex_buffer::BufferMut;
66use vortex_error:: VortexExpect ;
77use vortex_error:: VortexResult ;
88use vortex_error:: vortex_panic;
9+ use vortex_mask:: Mask ;
910
1011use crate :: ArrayRef ;
1112use crate :: IntoArray ;
12- use crate :: LEGACY_SESSION ;
13- use crate :: VortexSessionExecute ;
13+ use crate :: ToCanonical ;
1414use crate :: array:: ArrayView ;
1515use crate :: arrays:: FixedSizeList ;
1616use crate :: arrays:: FixedSizeListArray ;
1717use crate :: arrays:: Primitive ;
1818use crate :: arrays:: PrimitiveArray ;
19+ use crate :: arrays:: bool:: BoolArrayExt ;
1920use crate :: arrays:: dict:: TakeExecute ;
2021use crate :: arrays:: fixed_size_list:: FixedSizeListArrayExt ;
2122use crate :: dtype:: IntegerPType ;
@@ -83,7 +84,7 @@ fn take_with_indices<I: IntegerPType, E: IntegerPType>(
8384 // The result's nullability is the union of the input nullabilities.
8485 if array. dtype ( ) . is_nullable ( ) || indices_array. dtype ( ) . is_nullable ( ) {
8586 let indices_array = indices_array. as_view ( ) ;
86- take_nullable_fsl :: < I , E > ( array, indices_array)
87+ take_nullable_fsl :: < I , E > ( array, indices_array, ctx )
8788 } else {
8889 let indices_array = indices_array. as_view ( ) ;
8990 take_non_nullable_fsl :: < I , E > ( array, indices_array)
@@ -144,20 +145,25 @@ fn take_non_nullable_fsl<I: IntegerPType, E: IntegerPType>(
144145fn take_nullable_fsl < I : IntegerPType , E : IntegerPType > (
145146 array : ArrayView < ' _ , FixedSizeList > ,
146147 indices_array : ArrayView < ' _ , Primitive > ,
148+ ctx : & mut ExecutionCtx ,
147149) -> VortexResult < ArrayRef > {
148150 let list_size = array. list_size ( ) as usize ;
149151 let indices: & [ I ] = indices_array. as_slice :: < I > ( ) ;
150152 let new_len = indices. len ( ) ;
151153
152- let array_validity = array. fixed_size_list_validity_mask ( ) ;
153- let indices_validity = indices_array
154+ let array_validity = array
155+ . fixed_size_list_validity ( )
156+ . to_mask ( array. as_ref ( ) . len ( ) , ctx)
157+ . vortex_expect ( "Failed to compute validity mask" ) ;
158+ let indices_len = indices_array. as_ref ( ) . len ( ) ;
159+ let indices_validity = match indices_array
154160 . validity ( )
155161 . vortex_expect ( "Failed to compute validity mask" )
156- . to_mask (
157- indices_array . as_ref ( ) . len ( ) ,
158- & mut LEGACY_SESSION . create_execution_ctx ( ) ,
159- )
160- . vortex_expect ( "Failed to compute validity mask" ) ;
162+ {
163+ Validity :: NonNullable | Validity :: AllValid => Mask :: new_true ( indices_len ) ,
164+ Validity :: AllInvalid => Mask :: new_false ( indices_len ) ,
165+ Validity :: Array ( a ) => a . to_bool ( ) . to_mask ( ) ,
166+ } ;
161167
162168 // We must use placeholder zeros for null lists to maintain the array length without
163169 // propagating nullability to the element array's take operation.
0 commit comments