@@ -5,12 +5,18 @@ use rustc_hash::FxHashMap;
55use vortex_buffer:: Buffer ;
66use vortex_error:: VortexResult ;
77
8+ use crate :: ArrayRef ;
9+ use crate :: DynArray ;
10+ use crate :: ExecutionCtx ;
11+ use crate :: IntoArray ;
12+ use crate :: arrays:: Patched ;
13+ use crate :: arrays:: PrimitiveArray ;
814use crate :: arrays:: dict:: TakeExecute ;
915use crate :: arrays:: primitive:: PrimitiveArrayParts ;
10- use crate :: arrays :: { Patched , PrimitiveArray } ;
11- use crate :: dtype:: { IntegerPType , NativePType } ;
12- use crate :: { ArrayRef , DynArray , IntoArray , match_each_native_ptype} ;
13- use crate :: { ExecutionCtx , match_each_unsigned_integer_ptype} ;
16+ use crate :: dtype :: IntegerPType ;
17+ use crate :: dtype:: NativePType ;
18+ use crate :: match_each_native_ptype;
19+ use crate :: match_each_unsigned_integer_ptype;
1420
1521impl TakeExecute for Patched {
1622 fn take (
@@ -50,19 +56,20 @@ impl TakeExecute for Patched {
5056
5157 // SAFETY: output and validity still have same length after take_map returns.
5258 unsafe {
53- return Ok ( Some (
59+ Ok ( Some (
5460 PrimitiveArray :: new_unchecked( output. freeze( ) , validity) . into_array( ) ,
55- ) ) ;
61+ ) )
5662 }
5763 } )
58- } ) ;
64+ } )
5965 }
6066}
6167
6268/// Take patches for the given `indices` and apply them onto an `output` using a hash map.
6369///
6470/// First, builds a hashmap from index to patch value, then uses the hashmap in a loop to collect
6571/// the values.
72+ #[ allow( clippy:: too_many_arguments) ]
6673fn take_map < I : IntegerPType , V : NativePType > (
6774 output : & mut [ V ] ,
6875 indices : & [ I ] ,
@@ -75,10 +82,11 @@ fn take_map<I: IntegerPType, V: NativePType>(
7582 patch_value : & [ V ] ,
7683) {
7784 // Build a hashmap of patch_index -> values.
78- let mut index_map = FxHashMap :: with_capacity ( indices. len ( ) ) ;
85+ let mut index_map = FxHashMap :: with_capacity_and_hasher ( indices. len ( ) , Default :: default ( ) ) ;
7986 for chunk in 0 ..n_chunks {
8087 for lane in 0 ..n_lanes {
81- let [ lane_start, lane_end] = lane_offsets[ chunk * n_lanes + lane..] [ ..2 ] ;
88+ let lane_start = lane_offsets[ chunk * n_lanes + lane] ;
89+ let lane_end = lane_offsets[ chunk * n_lanes + lane + 1 ] ;
8290 for i in lane_start..lane_end {
8391 let patch_idx = patch_index[ i as usize ] ;
8492 let patch_value = patch_value[ i as usize ] ;
@@ -103,8 +111,120 @@ fn take_map<I: IntegerPType, V: NativePType>(
103111
104112#[ cfg( test) ]
105113mod tests {
114+ use vortex_buffer:: buffer;
115+ use vortex_error:: VortexResult ;
116+ use vortex_session:: VortexSession ;
117+
118+ use crate :: DynArray ;
119+ use crate :: ExecutionCtx ;
120+ use crate :: IntoArray ;
121+ use crate :: arrays:: PatchedArray ;
122+ use crate :: arrays:: PrimitiveArray ;
123+ use crate :: assert_arrays_eq;
124+ use crate :: patches:: Patches ;
125+
126+ fn make_patched_array (
127+ base : & [ u16 ] ,
128+ patch_indices : & [ u32 ] ,
129+ patch_values : & [ u16 ] ,
130+ ) -> VortexResult < PatchedArray > {
131+ let values = PrimitiveArray :: from_iter ( base. iter ( ) . copied ( ) ) . into_array ( ) ;
132+ let patches = Patches :: new (
133+ base. len ( ) ,
134+ 0 ,
135+ PrimitiveArray :: from_iter ( patch_indices. iter ( ) . copied ( ) ) . into_array ( ) ,
136+ PrimitiveArray :: from_iter ( patch_values. iter ( ) . copied ( ) ) . into_array ( ) ,
137+ None ,
138+ ) ?;
139+
140+ let session = VortexSession :: empty ( ) ;
141+ let mut ctx = ExecutionCtx :: new ( session) ;
142+
143+ PatchedArray :: from_array_and_patches ( values, & patches, & mut ctx)
144+ }
145+
146+ #[ test]
147+ fn test_take_basic ( ) -> VortexResult < ( ) > {
148+ // Array with base values [0, 0, 0, 0, 0] patched at indices [1, 3] with values [10, 30]
149+ let array = make_patched_array ( & [ 0 ; 5 ] , & [ 1 , 3 ] , & [ 10 , 30 ] ) ?. into_array ( ) ;
150+
151+ // Take indices [0, 1, 2, 3, 4] - should get [0, 10, 0, 30, 0]
152+ let indices = buffer ! [ 0u32 , 1 , 2 , 3 , 4 ] . into_array ( ) ;
153+ let result = array. take ( indices) ?;
154+
155+ let expected = PrimitiveArray :: from_iter ( [ 0u16 , 10 , 0 , 30 , 0 ] ) . into_array ( ) ;
156+ assert_arrays_eq ! ( expected, result) ;
157+
158+ Ok ( ( ) )
159+ }
160+
161+ #[ test]
162+ fn test_take_out_of_order ( ) -> VortexResult < ( ) > {
163+ // Array with base values [0, 0, 0, 0, 0] patched at indices [1, 3] with values [10, 30]
164+ let array = make_patched_array ( & [ 0 ; 5 ] , & [ 1 , 3 ] , & [ 10 , 30 ] ) ?. into_array ( ) ;
165+
166+ // Take indices in reverse order
167+ let indices = buffer ! [ 4u32 , 3 , 2 , 1 , 0 ] . into_array ( ) ;
168+ let result = array. take ( indices) ?;
169+
170+ let expected = PrimitiveArray :: from_iter ( [ 0u16 , 30 , 0 , 10 , 0 ] ) . into_array ( ) ;
171+ assert_arrays_eq ! ( expected, result) ;
172+
173+ Ok ( ( ) )
174+ }
175+
176+ #[ test]
177+ fn test_take_duplicates ( ) -> VortexResult < ( ) > {
178+ // Array with base values [0, 0, 0, 0, 0] patched at index [2] with value [99]
179+ let array = make_patched_array ( & [ 0 ; 5 ] , & [ 2 ] , & [ 99 ] ) ?. into_array ( ) ;
180+
181+ // Take the same patched index multiple times
182+ let indices = buffer ! [ 2u32 , 2 , 0 , 2 ] . into_array ( ) ;
183+ let result = array. take ( indices) ?;
184+
185+ let expected = PrimitiveArray :: from_iter ( [ 99u16 , 99 , 0 , 99 ] ) . into_array ( ) ;
186+ assert_arrays_eq ! ( expected, result) ;
187+
188+ Ok ( ( ) )
189+ }
190+
106191 #[ test]
107- fn test_take ( ) {
108- // Patch some values here instead.
192+ fn test_take_with_null_indices ( ) -> VortexResult < ( ) > {
193+ use crate :: arrays:: BoolArray ;
194+ use crate :: validity:: Validity ;
195+
196+ // Array: 10 elements, base value 0, patches at indices 2, 5, 8 with values 20, 50, 80
197+ let array = make_patched_array ( & [ 0 ; 10 ] , & [ 2 , 5 , 8 ] , & [ 20 , 50 , 80 ] ) ?. into_array ( ) ;
198+
199+ // Take 10 indices, with nulls at positions 1, 4, 7
200+ // Indices: [0, 2, 2, 5, 8, 0, 5, 8, 3, 1]
201+ // Nulls: [ , , N, , , N, , , N, ]
202+ // Position 2 (index=2, patched) is null
203+ // Position 5 (index=0, unpatched) is null
204+ // Position 8 (index=3, unpatched) is null
205+ let indices = PrimitiveArray :: new (
206+ buffer ! [ 0u32 , 2 , 2 , 5 , 8 , 0 , 5 , 8 , 3 , 1 ] ,
207+ Validity :: Array (
208+ BoolArray :: from_iter ( [
209+ true , true , false , true , true , false , true , true , false , true ,
210+ ] )
211+ . into_array ( ) ,
212+ ) ,
213+ ) ;
214+ let result = array. take ( indices. into_array ( ) ) ?;
215+
216+ // Expected: [0, 20, null, 50, 80, null, 50, 80, null, 0]
217+ let expected = PrimitiveArray :: new (
218+ buffer ! [ 0u16 , 20 , 0 , 50 , 80 , 0 , 50 , 80 , 0 , 0 ] ,
219+ Validity :: Array (
220+ BoolArray :: from_iter ( [
221+ true , true , false , true , true , false , true , true , false , true ,
222+ ] )
223+ . into_array ( ) ,
224+ ) ,
225+ ) ;
226+ assert_arrays_eq ! ( expected. into_array( ) , result) ;
227+
228+ Ok ( ( ) )
109229 }
110230}
0 commit comments