11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use itertools:: Itertools ;
54use vortex_error:: VortexResult ;
65use vortex_error:: vortex_bail;
76use vortex_session:: registry:: Id ;
7+ use vortex_utils:: aliases:: hash_set:: HashSet ;
88
99use crate :: ArrayRef ;
10- use crate :: session :: ArrayRegistry ;
10+ use crate :: ExecutionCtx ;
1111
1212/// Options for normalizing an array.
1313pub struct NormalizeOptions < ' a > {
1414 /// The set of allowed array encodings (in addition to the canonical ones) that are permitted
1515 /// in the normalized array.
16- pub allowed : & ' a ArrayRegistry ,
16+ pub allowed : & ' a HashSet < Id > ,
1717 /// The operation to perform when a non-allowed encoding is encountered.
18- pub operation : Operation ,
18+ pub operation : Operation < ' a > ,
1919}
2020
2121/// The operation to perform when a non-allowed encoding is encountered.
22- pub enum Operation {
22+ pub enum Operation < ' a > {
2323 Error ,
24- // TODO(joe): add into canonical variant
24+ Execute ( & ' a mut ExecutionCtx ) ,
2525}
2626
2727impl ArrayRef {
@@ -30,13 +30,17 @@ impl ArrayRef {
3030 /// This operation performs a recursive traversal of the array. Any non-allowed encoding is
3131 /// normalized per the configured operation.
3232 pub fn normalize ( self , options : & mut NormalizeOptions ) -> VortexResult < ArrayRef > {
33- let array_ids = options. allowed . ids ( ) . collect_vec ( ) ;
34- self . normalize_with_error ( & array_ids) ?;
35- // Note this takes ownership so we can at a later date remove non-allowed encodings.
36- Ok ( self )
33+ match & mut options. operation {
34+ Operation :: Error => {
35+ self . normalize_with_error ( options. allowed ) ?;
36+ // Note this takes ownership so we can at a later date remove non-allowed encodings.
37+ Ok ( self )
38+ }
39+ Operation :: Execute ( ctx) => self . normalize_with_execution ( options. allowed , * ctx) ,
40+ }
3741 }
3842
39- fn normalize_with_error ( & self , allowed : & [ Id ] ) -> VortexResult < ( ) > {
43+ fn normalize_with_error ( & self , allowed : & HashSet < Id > ) -> VortexResult < ( ) > {
4044 if !allowed. contains ( & self . encoding_id ( ) ) {
4145 vortex_bail ! ( AssertionFailed : "normalize forbids encoding ({})" , self . encoding_id( ) )
4246 }
@@ -46,4 +50,118 @@ impl ArrayRef {
4650 }
4751 Ok ( ( ) )
4852 }
53+
54+ fn normalize_with_execution (
55+ self ,
56+ allowed : & HashSet < Id > ,
57+ ctx : & mut ExecutionCtx ,
58+ ) -> VortexResult < ArrayRef > {
59+ let mut normalized = self ;
60+
61+ // Top-first execute the array tree while we hit non-allowed encodings.
62+ while !allowed. contains ( & normalized. encoding_id ( ) ) {
63+ normalized = normalized. execute ( ctx) ?;
64+ }
65+
66+ // Now we've normalized the root, we need to ensure the children are normalized also.
67+ let slots = normalized. slots ( ) ;
68+ let mut normalized_slots = Vec :: with_capacity ( slots. len ( ) ) ;
69+ let mut any_slot_changed = false ;
70+
71+ for slot in slots {
72+ match slot {
73+ Some ( child) => {
74+ let normalized_child = child. clone ( ) . normalize ( & mut NormalizeOptions {
75+ allowed,
76+ operation : Operation :: Execute ( ctx) ,
77+ } ) ?;
78+ any_slot_changed |= !ArrayRef :: ptr_eq ( & child, & normalized_child) ;
79+ normalized_slots. push ( Some ( normalized_child) ) ;
80+ }
81+ None => normalized_slots. push ( None ) ,
82+ }
83+ }
84+
85+ if any_slot_changed {
86+ normalized = normalized. with_slots ( normalized_slots) ?;
87+ }
88+
89+ Ok ( normalized)
90+ }
91+ }
92+
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;
    // `HashSet` is not inherited from the parent module's imports, so bring it in explicitly.
    use vortex_utils::aliases::hash_set::HashSet;

    use super::NormalizeOptions;
    use super::Operation;
    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::SliceArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::validity::Validity;

    /// When every encoding in the tree is already allowed, the very same array must come back.
    #[test]
    fn normalize_with_execution_keeps_parent_when_children_are_unchanged() -> VortexResult<()> {
        let field = PrimitiveArray::from_iter(0i32..4).into_array();
        let array = StructArray::try_new(
            ["field"].into(),
            vec![field.clone()],
            field.len(),
            Validity::NonNullable,
        )?
        .into_array();
        let allowed = HashSet::from_iter([array.encoding_id(), field.encoding_id()]);
        let mut ctx = ExecutionCtx::new(VortexSession::empty());

        let normalized = array.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Execute(&mut ctx),
        })?;

        assert!(ArrayRef::ptr_eq(&array, &normalized));
        Ok(())
    }

    /// A non-allowed child (the slice) must be executed, and only that child and the parent
    /// should be replaced; the already-allowed sibling keeps its identity.
    #[test]
    fn normalize_with_execution_rebuilds_parent_when_a_child_changes() -> VortexResult<()> {
        let unchanged = PrimitiveArray::from_iter(0i32..4).into_array();
        let sliced =
            SliceArray::new(PrimitiveArray::from_iter(10i32..20).into_array(), 2..6).into_array();
        let array = StructArray::try_new(
            ["lhs", "rhs"].into(),
            vec![unchanged.clone(), sliced],
            unchanged.len(),
            Validity::NonNullable,
        )?
        .into_array();
        // The slice encoding is deliberately left out of the allowed set.
        let allowed = HashSet::from_iter([array.encoding_id(), unchanged.encoding_id()]);
        let mut ctx = ExecutionCtx::new(VortexSession::empty());

        let normalized = array.clone().normalize(&mut NormalizeOptions {
            allowed: &allowed,
            operation: Operation::Execute(&mut ctx),
        })?;

        assert!(!ArrayRef::ptr_eq(&array, &normalized));

        let original_children = array.children();
        let normalized_children = normalized.children();
        assert!(ArrayRef::ptr_eq(
            &original_children[0],
            &normalized_children[0]
        ));
        assert!(!ArrayRef::ptr_eq(
            &original_children[1],
            &normalized_children[1]
        ));
        // Elements 2..6 of 10..20 are 12, 13, 14, 15.
        assert_arrays_eq!(normalized_children[1], PrimitiveArray::from_iter(12i32..16));

        Ok(())
    }
}
0 commit comments