@@ -14,52 +14,46 @@ use crate::arrays::Filter;
1414use crate :: arrays:: extension:: ExtensionArrayExt ;
1515use crate :: arrays:: filter:: FilterReduceAdaptor ;
1616use crate :: arrays:: slice:: SliceReduceAdaptor ;
17- use crate :: matcher:: AnyArray ;
1817use crate :: optimizer:: rules:: ArrayParentReduceRule ;
18+ use crate :: optimizer:: rules:: ArrayReduceRule ;
1919use crate :: optimizer:: rules:: ParentRuleSet ;
20+ use crate :: optimizer:: rules:: ReduceRuleSet ;
2021use crate :: scalar:: Scalar ;
2122use crate :: scalar_fn:: fns:: cast:: CastReduceAdaptor ;
2223use crate :: scalar_fn:: fns:: mask:: MaskReduceAdaptor ;
2324
24- pub ( crate ) const PARENT_RULES : ParentRuleSet < Extension > = ParentRuleSet :: new ( & [
25- ParentRuleSet :: lift ( & ExtensionConstantParentRule ) ,
26- ParentRuleSet :: lift ( & ExtensionFilterPushDownRule ) ,
27- ParentRuleSet :: lift ( & CastReduceAdaptor ( Extension ) ) ,
28- ParentRuleSet :: lift ( & FilterReduceAdaptor ( Extension ) ) ,
29- ParentRuleSet :: lift ( & MaskReduceAdaptor ( Extension ) ) ,
30- ParentRuleSet :: lift ( & SliceReduceAdaptor ( Extension ) ) ,
31- ] ) ;
25+ pub ( crate ) const RULES : ReduceRuleSet < Extension > = ReduceRuleSet :: new ( & [ & ExtensionConstantRule ] ) ;
3226
3327/// Normalize `Extension(Constant(storage))` children to `Constant(Extension(storage))`.
3428#[ derive( Debug ) ]
35- struct ExtensionConstantParentRule ;
36-
37- impl ArrayParentReduceRule < Extension > for ExtensionConstantParentRule {
38- type Parent = AnyArray ;
29+ struct ExtensionConstantRule ;
3930
40- fn reduce_parent (
41- & self ,
42- child : ArrayView < ' _ , Extension > ,
43- parent : & ArrayRef ,
44- child_idx : usize ,
45- ) -> VortexResult < Option < ArrayRef > > {
46- let Some ( const_array) = child. storage_array ( ) . as_opt :: < Constant > ( ) else {
31+ impl ArrayReduceRule < Extension > for ExtensionConstantRule {
32+ fn reduce ( & self , array : ArrayView < ' _ , Extension > ) -> VortexResult < Option < ArrayRef > > {
33+ let Some ( const_array) = array. storage_array ( ) . as_opt :: < Constant > ( ) else {
34+ println ! ( "not constant" ) ;
4735 return Ok ( None ) ;
4836 } ;
37+ println ! ( "reducing" ) ;
4938
5039 let storage_scalar = const_array. scalar ( ) . clone ( ) ;
51- let ext_scalar = Scalar :: extension_ref ( child . ext_dtype ( ) . clone ( ) , storage_scalar) ;
40+ let ext_scalar = Scalar :: extension_ref ( array . ext_dtype ( ) . clone ( ) , storage_scalar) ;
5241
5342 let constant_with_extension_scalar =
54- ConstantArray :: new ( ext_scalar, child . len ( ) ) . into_array ( ) ;
43+ ConstantArray :: new ( ext_scalar, array . len ( ) ) . into_array ( ) ;
5544
56- parent
57- . clone ( )
58- . with_slot ( child_idx, constant_with_extension_scalar)
59- . map ( Some )
45+ Ok ( Some ( constant_with_extension_scalar. into_array ( ) ) )
6046 }
6147}
6248
49+ pub ( crate ) const PARENT_RULES : ParentRuleSet < Extension > = ParentRuleSet :: new ( & [
50+ ParentRuleSet :: lift ( & ExtensionFilterPushDownRule ) ,
51+ ParentRuleSet :: lift ( & CastReduceAdaptor ( Extension ) ) ,
52+ ParentRuleSet :: lift ( & FilterReduceAdaptor ( Extension ) ) ,
53+ ParentRuleSet :: lift ( & MaskReduceAdaptor ( Extension ) ) ,
54+ ParentRuleSet :: lift ( & SliceReduceAdaptor ( Extension ) ) ,
55+ ] ) ;
56+
6357/// Push filter operations into the storage array of an extension array.
6458#[ derive( Debug ) ]
6559struct ExtensionFilterPushDownRule ;
@@ -99,6 +93,7 @@ mod tests {
9993 use crate :: arrays:: ExtensionArray ;
10094 use crate :: arrays:: FilterArray ;
10195 use crate :: arrays:: PrimitiveArray ;
96+ use crate :: arrays:: ScalarFnVTable ;
10297 use crate :: arrays:: extension:: ExtensionArrayExt ;
10398 use crate :: arrays:: scalar_fn:: ScalarFnArrayExt ;
10499 use crate :: arrays:: scalar_fn:: ScalarFnFactoryExt ;
@@ -227,8 +222,8 @@ mod tests {
227222 . try_new_array ( 3 , Operator :: Lt , [ constant_ext, ext_array] )
228223 . unwrap ( ) ;
229224
230- let optimized = scalar_fn_array. optimize ( ) . unwrap ( ) ;
231- let scalar_fn = optimized. as_opt :: < crate :: arrays :: ScalarFnVTable > ( ) . unwrap ( ) ;
225+ let optimized = scalar_fn_array. optimize_recursive ( ) . unwrap ( ) ;
226+ let scalar_fn = optimized. as_opt :: < ScalarFnVTable > ( ) . unwrap ( ) ;
232227 let children = scalar_fn. children ( ) ;
233228 let constant = children[ 0 ]
234229 . as_opt :: < Constant > ( )
@@ -291,7 +286,7 @@ mod tests {
291286 let optimized = scalar_fn_array. optimize ( ) . unwrap ( ) ;
292287
293288 // The first child should still be an ExtensionArray (no pushdown happened)
294- let scalar_fn = optimized. as_opt :: < crate :: arrays :: ScalarFnVTable > ( ) . unwrap ( ) ;
289+ let scalar_fn = optimized. as_opt :: < ScalarFnVTable > ( ) . unwrap ( ) ;
295290 assert ! (
296291 scalar_fn. children( ) [ 0 ] . as_opt:: <Extension >( ) . is_some( ) ,
297292 "Expected first child to remain ExtensionArray when ext types differ"
@@ -316,7 +311,7 @@ mod tests {
316311 let optimized = scalar_fn_array. optimize ( ) . unwrap ( ) ;
317312
318313 // No pushdown should happen because sibling is not a constant
319- let scalar_fn = optimized. as_opt :: < crate :: arrays :: ScalarFnVTable > ( ) . unwrap ( ) ;
314+ let scalar_fn = optimized. as_opt :: < ScalarFnVTable > ( ) . unwrap ( ) ;
320315 assert ! (
321316 scalar_fn. children( ) [ 0 ] . as_opt:: <Extension >( ) . is_some( ) ,
322317 "Expected first child to remain ExtensionArray when sibling is not constant"
@@ -339,7 +334,7 @@ mod tests {
339334 let optimized = scalar_fn_array. optimize ( ) . unwrap ( ) ;
340335
341336 // No pushdown should happen because constant is not an extension scalar
342- let scalar_fn = optimized. as_opt :: < crate :: arrays :: ScalarFnVTable > ( ) . unwrap ( ) ;
337+ let scalar_fn = optimized. as_opt :: < ScalarFnVTable > ( ) . unwrap ( ) ;
343338 assert ! (
344339 scalar_fn. children( ) [ 0 ] . as_opt:: <Extension >( ) . is_some( ) ,
345340 "Expected first child to remain ExtensionArray when constant is not extension"
0 commit comments