@@ -6,11 +6,16 @@ use vortex_error::vortex_err;
66
77use crate :: ArrayRef ;
88use crate :: ExecutionCtx ;
9+ use crate :: IntoArray ;
910use crate :: arrays:: Bool ;
11+ use crate :: arrays:: Constant ;
12+ use crate :: arrays:: ConstantArray ;
1013use crate :: arrays:: scalar_fn:: ExactScalarFn ;
1114use crate :: arrays:: scalar_fn:: ScalarFnArrayView ;
15+ use crate :: builtins:: ArrayBuiltins ;
1216use crate :: kernel:: ExecuteParentKernel ;
1317use crate :: optimizer:: rules:: ArrayParentReduceRule ;
18+ use crate :: scalar:: Scalar ;
1419use crate :: scalar_fn:: fns:: mask:: Mask as MaskExpr ;
1520use crate :: vtable:: VTable ;
1621
@@ -49,6 +54,26 @@ pub trait MaskKernel: VTable {
4954 ) -> VortexResult < Option < ArrayRef > > ;
5055}
5156
57+ /// If the mask is a constant boolean, handle the trivial cases and return `Some`.
58+ /// Returns `None` if the mask is not a constant.
59+ fn handle_constant_mask (
60+ array : & dyn crate :: array:: DynArray ,
61+ mask : & ArrayRef ,
62+ ) -> VortexResult < Option < ArrayRef > > {
63+ if let Some ( constant_mask) = mask. as_opt :: < Constant > ( ) {
64+ let mask_value = constant_mask. scalar ( ) . as_bool ( ) . value ( ) . unwrap_or ( false ) ;
65+ return if mask_value {
66+ array. to_array ( ) . cast ( array. dtype ( ) . as_nullable ( ) ) . map ( Some )
67+ } else {
68+ Ok ( Some (
69+ ConstantArray :: new ( Scalar :: null ( array. dtype ( ) . as_nullable ( ) ) , array. len ( ) )
70+ . into_array ( ) ,
71+ ) )
72+ } ;
73+ }
74+ Ok ( None )
75+ }
76+
5277/// Adaptor that wraps a [`MaskReduce`] impl as an [`ArrayParentReduceRule`].
5378#[ derive( Default , Debug ) ]
5479pub struct MaskReduceAdaptor < V > ( pub V ) ;
7499 let mask_child = parent
75100 . nth_child ( 1 )
76101 . ok_or_else ( || vortex_err ! ( "Mask expression must have 2 children" ) ) ?;
102+
103+ // Handle trivial constant mask cases before dispatching to the encoding.
104+ if let Some ( result) = handle_constant_mask ( & * * array, & mask_child) ? {
105+ return Ok ( Some ( result) ) ;
106+ }
107+
77108 if mask_child. as_opt :: < Bool > ( ) . is_none ( ) {
78109 return Ok ( None ) ;
79110 } ;
@@ -105,6 +136,12 @@ where
105136 let mask_child = parent
106137 . nth_child ( 1 )
107138 . ok_or_else ( || vortex_err ! ( "Mask expression must have 2 children" ) ) ?;
139+
140+ // Handle trivial constant mask cases before dispatching to the encoding.
141+ if let Some ( result) = handle_constant_mask ( & * * array, & mask_child) ? {
142+ return Ok ( Some ( result) ) ;
143+ }
144+
108145 <V as MaskKernel >:: mask ( array, & mask_child, ctx)
109146 }
110147}
0 commit comments