33
44mod kernel;
55use std:: fmt:: Formatter ;
6+ use std:: sync:: Arc ;
67
78pub use kernel:: * ;
89use vortex_error:: VortexExpect ;
@@ -30,8 +31,12 @@ use crate::scalar_fn::Arity;
3031use crate :: scalar_fn:: ChildName ;
3132use crate :: scalar_fn:: EmptyOptions ;
3233use crate :: scalar_fn:: ExecutionArgs ;
34+ use crate :: scalar_fn:: ReduceCtx ;
35+ use crate :: scalar_fn:: ReduceNode ;
36+ use crate :: scalar_fn:: ReduceNodeRef ;
3337use crate :: scalar_fn:: ScalarFnId ;
3438use crate :: scalar_fn:: ScalarFnVTable ;
39+ use crate :: scalar_fn:: ScalarFnVTableExt ;
3540use crate :: scalar_fn:: SimplifyCtx ;
3641use crate :: scalar_fn:: fns:: literal:: Literal ;
3742
@@ -111,6 +116,43 @@ impl ScalarFnVTable for Mask {
111116 execute_canonical ( input, mask_array, ctx)
112117 }
113118
119+ fn reduce (
120+ & self ,
121+ options : & Self :: Options ,
122+ node : & dyn ReduceNode ,
123+ ctx : & dyn ReduceCtx ,
124+ ) -> VortexResult < Option < ReduceNodeRef > > {
125+ _ = options;
126+ let input = node. child ( 0 ) ;
127+ let Some ( input_scalar_fn) = input. scalar_fn ( ) else {
128+ return Ok ( None ) ;
129+ } ;
130+
131+ // The null-sensitivity property is exactly whether this rewrite is valid.
132+ if input_scalar_fn. signature ( ) . is_null_sensitive ( ) {
133+ return Ok ( None ) ;
134+ }
135+
136+ // Zero-arity scalar functions (e.g. literals) have no children to push the mask into.
137+ if input. child_count ( ) == 0 {
138+ return Ok ( None ) ;
139+ }
140+
141+ let mask = node. child ( 1 ) ;
142+ let mut masked_children = Vec :: with_capacity ( input. child_count ( ) ) ;
143+ for child_idx in 0 ..input. child_count ( ) {
144+ let masked_child = ctx. new_node (
145+ Mask . bind ( EmptyOptions ) ,
146+ & [ input. child ( child_idx) , Arc :: clone ( & mask) ] ,
147+ ) ?;
148+ masked_children. push ( masked_child) ;
149+ }
150+
151+ Ok ( Some (
152+ ctx. new_node ( input_scalar_fn. clone ( ) , & masked_children) ?,
153+ ) )
154+ }
155+
114156 fn simplify (
115157 & self ,
116158 _options : & Self :: Options ,
@@ -193,12 +235,27 @@ fn execute_canonical(
193235mod test {
194236 use vortex_error:: VortexExpect ;
195237
238+ use super :: Mask ;
239+ use crate :: IntoArray ;
240+ use crate :: arrays:: BoolArray ;
241+ use crate :: arrays:: ScalarFnVTable ;
242+ use crate :: arrays:: scalar_fn:: ScalarFnArrayExt ;
243+ use crate :: arrays:: scalar_fn:: ScalarFnFactoryExt ;
196244 use crate :: dtype:: DType ;
245+ use crate :: dtype:: Nullability ;
197246 use crate :: dtype:: Nullability :: Nullable ;
198247 use crate :: dtype:: PType ;
248+ use crate :: dtype:: StructFields ;
249+ use crate :: expr:: col;
250+ use crate :: expr:: is_null;
199251 use crate :: expr:: lit;
200252 use crate :: expr:: mask;
253+ use crate :: expr:: not;
254+ use crate :: optimizer:: ArrayOptimizer ;
201255 use crate :: scalar:: Scalar ;
256+ use crate :: scalar_fn:: EmptyOptions ;
257+ use crate :: scalar_fn:: fns:: is_null:: IsNull ;
258+ use crate :: scalar_fn:: fns:: not:: Not ;
202259
203260 #[ test]
204261 fn test_simplify ( ) {
@@ -219,4 +276,90 @@ mod test {
219276 let expected_null_expr = lit ( Scalar :: null ( DType :: Primitive ( PType :: U32 , Nullable ) ) ) ;
220277 assert_eq ! ( & simplified_false, & expected_null_expr) ;
221278 }
279+
280+ #[ test]
281+ fn test_reduce_pushdown_expression_not ( ) {
282+ let scope = DType :: Struct (
283+ StructFields :: new (
284+ [ "bool1" , "m" ] . into ( ) ,
285+ vec ! [
286+ DType :: Bool ( Nullability :: NonNullable ) ,
287+ DType :: Bool ( Nullability :: NonNullable ) ,
288+ ] ,
289+ ) ,
290+ Nullability :: NonNullable ,
291+ ) ;
292+
293+ let expr = mask ( not ( col ( "bool1" ) ) , col ( "m" ) ) ;
294+ let reduced = expr. optimize ( & scope) . vortex_expect ( "optimize" ) ;
295+
296+ let expected = not ( mask ( col ( "bool1" ) , col ( "m" ) ) ) ;
297+ assert_eq ! ( reduced, expected) ;
298+ }
299+
300+ #[ test]
301+ fn test_reduce_no_pushdown_expression_null_sensitive ( ) {
302+ let scope = DType :: Struct (
303+ StructFields :: new (
304+ [ "bool1" , "m" ] . into ( ) ,
305+ vec ! [
306+ DType :: Bool ( Nullability :: NonNullable ) ,
307+ DType :: Bool ( Nullability :: NonNullable ) ,
308+ ] ,
309+ ) ,
310+ Nullability :: NonNullable ,
311+ ) ;
312+
313+ let expr = mask ( is_null ( col ( "bool1" ) ) , col ( "m" ) ) ;
314+ let reduced = expr. optimize ( & scope) . vortex_expect ( "optimize" ) ;
315+ assert_eq ! ( reduced, expr) ;
316+ }
317+
318+ #[ test]
319+ fn test_reduce_pushdown_array_not ( ) {
320+ let values = BoolArray :: from_iter ( [ true , false , true ] ) . into_array ( ) ;
321+ let mask_values = BoolArray :: from_iter ( [ true , false , true ] ) . into_array ( ) ;
322+
323+ let not_array = Not
324+ . try_new_array ( values. len ( ) , EmptyOptions , [ values] )
325+ . vortex_expect ( "not array" ) ;
326+ let mask_array = Mask
327+ . try_new_array ( mask_values. len ( ) , EmptyOptions , [ not_array, mask_values] )
328+ . vortex_expect ( "mask array" ) ;
329+
330+ let reduced = mask_array. optimize ( ) . vortex_expect ( "optimize" ) ;
331+ let reduced_sfn = reduced
332+ . as_opt :: < ScalarFnVTable > ( )
333+ . vortex_expect ( "expected scalar fn root" ) ;
334+ assert ! ( reduced_sfn. scalar_fn( ) . is:: <Not >( ) ) ;
335+
336+ let child = reduced_sfn. child_at ( 0 ) ;
337+ let child_sfn = child
338+ . as_opt :: < ScalarFnVTable > ( )
339+ . vortex_expect ( "expected masked child" ) ;
340+ assert ! ( child_sfn. scalar_fn( ) . is:: <Mask >( ) ) ;
341+ }
342+
343+ #[ test]
344+ fn test_reduce_no_pushdown_array_null_sensitive ( ) {
345+ let values = BoolArray :: from_iter ( [ true , false , true ] ) . into_array ( ) ;
346+ let mask_values = BoolArray :: from_iter ( [ true , false , true ] ) . into_array ( ) ;
347+
348+ let is_null_array = IsNull
349+ . try_new_array ( values. len ( ) , EmptyOptions , [ values] )
350+ . vortex_expect ( "is_null array" ) ;
351+ let mask_array = Mask
352+ . try_new_array (
353+ mask_values. len ( ) ,
354+ EmptyOptions ,
355+ [ is_null_array, mask_values] ,
356+ )
357+ . vortex_expect ( "mask array" ) ;
358+
359+ let reduced = mask_array. optimize ( ) . vortex_expect ( "optimize" ) ;
360+ let reduced_sfn = reduced
361+ . as_opt :: < ScalarFnVTable > ( )
362+ . vortex_expect ( "expected scalar fn root" ) ;
363+ assert ! ( reduced_sfn. scalar_fn( ) . is:: <Mask >( ) ) ;
364+ }
222365}
0 commit comments