11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use itertools:: Itertools ;
54use vortex_error:: VortexResult ;
65use vortex_error:: vortex_bail;
76use vortex_session:: registry:: Id ;
7+ use vortex_utils:: aliases:: hash_set:: HashSet ;
88
99use crate :: ArrayRef ;
10- use crate :: session :: ArrayRegistry ;
10+ use crate :: ExecutionCtx ;
1111
1212/// Options for normalizing an array.
1313pub struct NormalizeOptions < ' a > {
1414 /// The set of allowed array encodings (in addition to the canonical ones) that are permitted
1515 /// in the normalized array.
16- pub allowed : & ' a ArrayRegistry ,
16+ pub allowed : & ' a HashSet < Id > ,
1717 /// The operation to perform when a non-allowed encoding is encountered.
18- pub operation : Operation ,
18+ pub operation : Operation < ' a > ,
1919}
2020
2121/// The operation to perform when a non-allowed encoding is encountered.
22- pub enum Operation {
22+ pub enum Operation < ' a > {
2323 Error ,
24- // TODO(joe): add into canonical variant
24+ Execute ( & ' a mut ExecutionCtx ) ,
2525}
2626
2727impl ArrayRef {
@@ -30,14 +30,18 @@ impl ArrayRef {
3030 /// This operation performs a recursive traversal of the array. Any non-allowed encoding is
3131 /// normalized per the configured operation.
3232 pub fn normalize ( self , options : & mut NormalizeOptions ) -> VortexResult < ArrayRef > {
33- let array_ids = options. allowed . ids ( ) . collect_vec ( ) ;
34- self . normalize_with_error ( & array_ids) ?;
35- // Note this takes ownership so we can at a later date remove non-allowed encodings.
36- Ok ( self )
33+ match & mut options. operation {
34+ Operation :: Error => {
35+ self . normalize_with_error ( options. allowed ) ?;
36+ // Note this takes ownership so we can at a later date remove non-allowed encodings.
37+ Ok ( self )
38+ }
39+ Operation :: Execute ( ctx) => self . normalize_with_execution ( options. allowed , ctx) ,
40+ }
3741 }
3842
39- fn normalize_with_error ( & self , allowed : & [ Id ] ) -> VortexResult < ( ) > {
40- if !allowed . contains ( & self . encoding_id ( ) ) {
43+ fn normalize_with_error ( & self , allowed : & HashSet < Id > ) -> VortexResult < ( ) > {
44+ if !self . is_allowed_encoding ( allowed ) {
4145 vortex_bail ! ( AssertionFailed : "normalize forbids encoding ({})" , self . encoding_id( ) )
4246 }
4347
@@ -46,4 +50,183 @@ impl ArrayRef {
4650 }
4751 Ok ( ( ) )
4852 }
53+
54+ fn normalize_with_execution (
55+ self ,
56+ allowed : & HashSet < Id > ,
57+ ctx : & mut ExecutionCtx ,
58+ ) -> VortexResult < ArrayRef > {
59+ let mut normalized = self ;
60+
61+ // Top-first execute the array tree while we hit non-allowed encodings.
62+ while !normalized. is_allowed_encoding ( allowed) {
63+ normalized = normalized. execute ( ctx) ?;
64+ }
65+
66+ // Now we've normalized the root, we need to ensure the children are normalized also.
67+ let slots = normalized. slots ( ) ;
68+ let mut normalized_slots = Vec :: with_capacity ( slots. len ( ) ) ;
69+ let mut any_slot_changed = false ;
70+
71+ for slot in slots {
72+ match slot {
73+ Some ( child) => {
74+ let normalized_child = child. clone ( ) . normalize ( & mut NormalizeOptions {
75+ allowed,
76+ operation : Operation :: Execute ( ctx) ,
77+ } ) ?;
78+ any_slot_changed |= !ArrayRef :: ptr_eq ( child, & normalized_child) ;
79+ normalized_slots. push ( Some ( normalized_child) ) ;
80+ }
81+ None => normalized_slots. push ( None ) ,
82+ }
83+ }
84+
85+ if any_slot_changed {
86+ normalized = normalized. with_slots ( normalized_slots) ?;
87+ }
88+
89+ Ok ( normalized)
90+ }
91+
92+ fn is_allowed_encoding ( & self , allowed : & HashSet < Id > ) -> bool {
93+ allowed. contains ( & self . encoding_id ( ) ) || self . is_canonical ( )
94+ }
95+ }
96+
97+ #[ cfg( test) ]
98+ mod tests {
99+ use vortex_error:: VortexResult ;
100+ use vortex_session:: VortexSession ;
101+ use vortex_utils:: aliases:: hash_set:: HashSet ;
102+
103+ use super :: NormalizeOptions ;
104+ use super :: Operation ;
105+ use crate :: ArrayRef ;
106+ use crate :: ExecutionCtx ;
107+ use crate :: IntoArray ;
108+ use crate :: arrays:: Dict ;
109+ use crate :: arrays:: DictArray ;
110+ use crate :: arrays:: Primitive ;
111+ use crate :: arrays:: PrimitiveArray ;
112+ use crate :: arrays:: Slice ;
113+ use crate :: arrays:: SliceArray ;
114+ use crate :: arrays:: StructArray ;
115+ use crate :: assert_arrays_eq;
116+ use crate :: validity:: Validity ;
117+
118+ #[ test]
119+ fn normalize_with_execution_keeps_parent_when_children_are_unchanged ( ) -> VortexResult < ( ) > {
120+ let field = PrimitiveArray :: from_iter ( 0i32 ..4 ) . into_array ( ) ;
121+ let array = StructArray :: try_new (
122+ [ "field" ] . into ( ) ,
123+ vec ! [ field. clone( ) ] ,
124+ field. len ( ) ,
125+ Validity :: NonNullable ,
126+ ) ?
127+ . into_array ( ) ;
128+ let allowed = HashSet :: from_iter ( [ array. encoding_id ( ) , field. encoding_id ( ) ] ) ;
129+ let mut ctx = ExecutionCtx :: new ( VortexSession :: empty ( ) ) ;
130+
131+ let normalized = array. clone ( ) . normalize ( & mut NormalizeOptions {
132+ allowed : & allowed,
133+ operation : Operation :: Execute ( & mut ctx) ,
134+ } ) ?;
135+
136+ assert ! ( ArrayRef :: ptr_eq( & array, & normalized) ) ;
137+ Ok ( ( ) )
138+ }
139+
140+ #[ test]
141+ fn normalize_with_error_allows_canonical_arrays ( ) -> VortexResult < ( ) > {
142+ let field = PrimitiveArray :: from_iter ( 0i32 ..4 ) . into_array ( ) ;
143+ let array = StructArray :: try_new (
144+ [ "field" ] . into ( ) ,
145+ vec ! [ field. clone( ) ] ,
146+ field. len ( ) ,
147+ Validity :: NonNullable ,
148+ ) ?
149+ . into_array ( ) ;
150+ let allowed = HashSet :: default ( ) ;
151+
152+ let normalized = array. clone ( ) . normalize ( & mut NormalizeOptions {
153+ allowed : & allowed,
154+ operation : Operation :: Error ,
155+ } ) ?;
156+
157+ assert ! ( ArrayRef :: ptr_eq( & array, & normalized) ) ;
158+ Ok ( ( ) )
159+ }
160+
161+ #[ test]
162+ fn normalize_with_execution_rebuilds_parent_when_a_child_changes ( ) -> VortexResult < ( ) > {
163+ let unchanged = PrimitiveArray :: from_iter ( 0i32 ..4 ) . into_array ( ) ;
164+ let sliced =
165+ SliceArray :: new ( PrimitiveArray :: from_iter ( 10i32 ..20 ) . into_array ( ) , 2 ..6 ) . into_array ( ) ;
166+ let array = StructArray :: try_new (
167+ [ "lhs" , "rhs" ] . into ( ) ,
168+ vec ! [ unchanged. clone( ) , sliced] ,
169+ unchanged. len ( ) ,
170+ Validity :: NonNullable ,
171+ ) ?
172+ . into_array ( ) ;
173+ let allowed = HashSet :: from_iter ( [ array. encoding_id ( ) , unchanged. encoding_id ( ) ] ) ;
174+ let mut ctx = ExecutionCtx :: new ( VortexSession :: empty ( ) ) ;
175+
176+ let normalized = array. clone ( ) . normalize ( & mut NormalizeOptions {
177+ allowed : & allowed,
178+ operation : Operation :: Execute ( & mut ctx) ,
179+ } ) ?;
180+
181+ assert ! ( !ArrayRef :: ptr_eq( & array, & normalized) ) ;
182+
183+ let original_children = array. children ( ) ;
184+ let normalized_children = normalized. children ( ) ;
185+ assert ! ( ArrayRef :: ptr_eq(
186+ & original_children[ 0 ] ,
187+ & normalized_children[ 0 ]
188+ ) ) ;
189+ assert ! ( !ArrayRef :: ptr_eq(
190+ & original_children[ 1 ] ,
191+ & normalized_children[ 1 ]
192+ ) ) ;
193+ assert_arrays_eq ! ( normalized_children[ 1 ] , PrimitiveArray :: from_iter( 12i32 ..16 ) ) ;
194+
195+ Ok ( ( ) )
196+ }
197+
198+ #[ test]
199+ fn normalize_slice_of_dict_returns_dict ( ) -> VortexResult < ( ) > {
200+ let codes = PrimitiveArray :: from_iter ( vec ! [ 0u32 , 1 , 0 , 1 , 2 ] ) . into_array ( ) ;
201+ let values = PrimitiveArray :: from_iter ( vec ! [ 10i32 , 20 , 30 ] ) . into_array ( ) ;
202+ let dict = DictArray :: try_new ( codes, values) ?. into_array ( ) ;
203+
204+ // Slice the dict array to get a SliceArray wrapping a DictArray.
205+ let sliced = SliceArray :: new ( dict, 1 ..4 ) . into_array ( ) ;
206+ assert_eq ! ( sliced. encoding_id( ) , Slice :: ID ) ;
207+
208+ let allowed = HashSet :: from_iter ( [ Dict :: ID , Primitive :: ID ] ) ;
209+ let mut ctx = ExecutionCtx :: new ( VortexSession :: empty ( ) ) ;
210+
211+ println ! ( "sliced {}" , sliced. display_tree( ) ) ;
212+
213+ let normalized = sliced. normalize ( & mut NormalizeOptions {
214+ allowed : & allowed,
215+ operation : Operation :: Execute ( & mut ctx) ,
216+ } ) ?;
217+
218+ println ! ( "after {}" , normalized. display_tree( ) ) ;
219+
220+ // The normalized result should be a DictArray, not a SliceArray.
221+ assert_eq ! ( normalized. encoding_id( ) , Dict :: ID ) ;
222+ assert_eq ! ( normalized. len( ) , 3 ) ;
223+
224+ // Verify the data: codes [1,0,1] -> values [20, 10, 20]
225+ assert_arrays_eq ! (
226+ normalized. to_canonical( ) ?,
227+ PrimitiveArray :: from_iter( vec![ 20i32 , 10 , 20 ] )
228+ ) ;
229+
230+ Ok ( ( ) )
231+ }
49232}
0 commit comments