@@ -14,52 +14,44 @@ use crate::arrays::Filter;
1414use crate :: arrays:: extension:: ExtensionArrayExt ;
1515use crate :: arrays:: filter:: FilterReduceAdaptor ;
1616use crate :: arrays:: slice:: SliceReduceAdaptor ;
17- use crate :: matcher:: AnyArray ;
1817use crate :: optimizer:: rules:: ArrayParentReduceRule ;
18+ use crate :: optimizer:: rules:: ArrayReduceRule ;
1919use crate :: optimizer:: rules:: ParentRuleSet ;
20+ use crate :: optimizer:: rules:: ReduceRuleSet ;
2021use crate :: scalar:: Scalar ;
2122use crate :: scalar_fn:: fns:: cast:: CastReduceAdaptor ;
2223use crate :: scalar_fn:: fns:: mask:: MaskReduceAdaptor ;
2324
24- pub ( crate ) const PARENT_RULES : ParentRuleSet < Extension > = ParentRuleSet :: new ( & [
25- ParentRuleSet :: lift ( & ExtensionConstantParentRule ) ,
26- ParentRuleSet :: lift ( & ExtensionFilterPushDownRule ) ,
27- ParentRuleSet :: lift ( & CastReduceAdaptor ( Extension ) ) ,
28- ParentRuleSet :: lift ( & FilterReduceAdaptor ( Extension ) ) ,
29- ParentRuleSet :: lift ( & MaskReduceAdaptor ( Extension ) ) ,
30- ParentRuleSet :: lift ( & SliceReduceAdaptor ( Extension ) ) ,
31- ] ) ;
25+ pub ( crate ) const RULES : ReduceRuleSet < Extension > = ReduceRuleSet :: new ( & [ & ExtensionConstantRule ] ) ;
3226
3327/// Normalize `Extension(Constant(storage))` children to `Constant(Extension(storage))`.
3428#[ derive( Debug ) ]
35- struct ExtensionConstantParentRule ;
36-
37- impl ArrayParentReduceRule < Extension > for ExtensionConstantParentRule {
38- type Parent = AnyArray ;
29+ struct ExtensionConstantRule ;
3930
40- fn reduce_parent (
41- & self ,
42- child : ArrayView < ' _ , Extension > ,
43- parent : & ArrayRef ,
44- child_idx : usize ,
45- ) -> VortexResult < Option < ArrayRef > > {
46- let Some ( const_array) = child. storage_array ( ) . as_opt :: < Constant > ( ) else {
31+ impl ArrayReduceRule < Extension > for ExtensionConstantRule {
32+ fn reduce ( & self , array : ArrayView < ' _ , Extension > ) -> VortexResult < Option < ArrayRef > > {
33+ let Some ( const_array) = array. storage_array ( ) . as_opt :: < Constant > ( ) else {
4734 return Ok ( None ) ;
4835 } ;
4936
5037 let storage_scalar = const_array. scalar ( ) . clone ( ) ;
51- let ext_scalar = Scalar :: extension_ref ( child . ext_dtype ( ) . clone ( ) , storage_scalar) ;
38+ let ext_scalar = Scalar :: extension_ref ( array . ext_dtype ( ) . clone ( ) , storage_scalar) ;
5239
5340 let constant_with_extension_scalar =
54- ConstantArray :: new ( ext_scalar, child . len ( ) ) . into_array ( ) ;
41+ ConstantArray :: new ( ext_scalar, array . len ( ) ) . into_array ( ) ;
5542
56- parent
57- . clone ( )
58- . with_slot ( child_idx, constant_with_extension_scalar)
59- . map ( Some )
43+ Ok ( Some ( constant_with_extension_scalar. into_array ( ) ) )
6044 }
6145}
6246
47+ pub ( crate ) const PARENT_RULES : ParentRuleSet < Extension > = ParentRuleSet :: new ( & [
48+ ParentRuleSet :: lift ( & ExtensionFilterPushDownRule ) ,
49+ ParentRuleSet :: lift ( & CastReduceAdaptor ( Extension ) ) ,
50+ ParentRuleSet :: lift ( & FilterReduceAdaptor ( Extension ) ) ,
51+ ParentRuleSet :: lift ( & MaskReduceAdaptor ( Extension ) ) ,
52+ ParentRuleSet :: lift ( & SliceReduceAdaptor ( Extension ) ) ,
53+ ] ) ;
54+
6355/// Push filter operations into the storage array of an extension array.
6456#[ derive( Debug ) ]
6557struct ExtensionFilterPushDownRule ;
@@ -99,6 +91,7 @@ mod tests {
9991 use crate :: arrays:: ExtensionArray ;
10092 use crate :: arrays:: FilterArray ;
10193 use crate :: arrays:: PrimitiveArray ;
94+ use crate :: arrays:: ScalarFnVTable ;
10295 use crate :: arrays:: extension:: ExtensionArrayExt ;
10396 use crate :: arrays:: scalar_fn:: ScalarFnArrayExt ;
10497 use crate :: arrays:: scalar_fn:: ScalarFnFactoryExt ;
@@ -227,8 +220,8 @@ mod tests {
227220 . try_new_array ( 3 , Operator :: Lt , [ constant_ext, ext_array] )
228221 . unwrap ( ) ;
229222
230- let optimized = scalar_fn_array. optimize ( ) . unwrap ( ) ;
231- let scalar_fn = optimized. as_opt :: < crate :: arrays :: ScalarFnVTable > ( ) . unwrap ( ) ;
223+ let optimized = scalar_fn_array. optimize_recursive ( ) . unwrap ( ) ;
224+ let scalar_fn = optimized. as_opt :: < ScalarFnVTable > ( ) . unwrap ( ) ;
232225 let children = scalar_fn. children ( ) ;
233226 let constant = children[ 0 ]
234227 . as_opt :: < Constant > ( )
@@ -291,7 +284,7 @@ mod tests {
291284 let optimized = scalar_fn_array. optimize ( ) . unwrap ( ) ;
292285
293286 // The first child should still be an ExtensionArray (no pushdown happened)
294- let scalar_fn = optimized. as_opt :: < crate :: arrays :: ScalarFnVTable > ( ) . unwrap ( ) ;
287+ let scalar_fn = optimized. as_opt :: < ScalarFnVTable > ( ) . unwrap ( ) ;
295288 assert ! (
296289 scalar_fn. children( ) [ 0 ] . as_opt:: <Extension >( ) . is_some( ) ,
297290 "Expected first child to remain ExtensionArray when ext types differ"
@@ -316,7 +309,7 @@ mod tests {
316309 let optimized = scalar_fn_array. optimize ( ) . unwrap ( ) ;
317310
318311 // No pushdown should happen because sibling is not a constant
319- let scalar_fn = optimized. as_opt :: < crate :: arrays :: ScalarFnVTable > ( ) . unwrap ( ) ;
312+ let scalar_fn = optimized. as_opt :: < ScalarFnVTable > ( ) . unwrap ( ) ;
320313 assert ! (
321314 scalar_fn. children( ) [ 0 ] . as_opt:: <Extension >( ) . is_some( ) ,
322315 "Expected first child to remain ExtensionArray when sibling is not constant"
@@ -339,7 +332,7 @@ mod tests {
339332 let optimized = scalar_fn_array. optimize ( ) . unwrap ( ) ;
340333
341334 // No pushdown should happen because constant is not an extension scalar
342- let scalar_fn = optimized. as_opt :: < crate :: arrays :: ScalarFnVTable > ( ) . unwrap ( ) ;
335+ let scalar_fn = optimized. as_opt :: < ScalarFnVTable > ( ) . unwrap ( ) ;
343336 assert ! (
344337 scalar_fn. children( ) [ 0 ] . as_opt:: <Extension >( ) . is_some( ) ,
345338 "Expected first child to remain ExtensionArray when constant is not extension"
0 commit comments