Skip to content

Commit 56a16a8

Browse files
committed
Push mask through scalar function
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent a30de02 commit 56a16a8

1 file changed

Lines changed: 143 additions & 0 deletions

File tree

  • vortex-array/src/scalar_fn/fns/mask

vortex-array/src/scalar_fn/fns/mask/mod.rs

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
mod kernel;
55
use std::fmt::Formatter;
6+
use std::sync::Arc;
67

78
pub use kernel::*;
89
use vortex_error::VortexExpect;
@@ -30,8 +31,12 @@ use crate::scalar_fn::Arity;
3031
use crate::scalar_fn::ChildName;
3132
use crate::scalar_fn::EmptyOptions;
3233
use crate::scalar_fn::ExecutionArgs;
34+
use crate::scalar_fn::ReduceCtx;
35+
use crate::scalar_fn::ReduceNode;
36+
use crate::scalar_fn::ReduceNodeRef;
3337
use crate::scalar_fn::ScalarFnId;
3438
use crate::scalar_fn::ScalarFnVTable;
39+
use crate::scalar_fn::ScalarFnVTableExt;
3540
use crate::scalar_fn::SimplifyCtx;
3641
use crate::scalar_fn::fns::literal::Literal;
3742

@@ -111,6 +116,43 @@ impl ScalarFnVTable for Mask {
111116
execute_canonical(input, mask_array, ctx)
112117
}
113118

119+
fn reduce(
120+
&self,
121+
options: &Self::Options,
122+
node: &dyn ReduceNode,
123+
ctx: &dyn ReduceCtx,
124+
) -> VortexResult<Option<ReduceNodeRef>> {
125+
_ = options;
126+
let input = node.child(0);
127+
let Some(input_scalar_fn) = input.scalar_fn() else {
128+
return Ok(None);
129+
};
130+
131+
// The null-sensitivity property is exactly whether this rewrite is valid.
132+
if input_scalar_fn.signature().is_null_sensitive() {
133+
return Ok(None);
134+
}
135+
136+
// Zero-arity scalar functions (e.g. literals) have no children to push the mask into.
137+
if input.child_count() == 0 {
138+
return Ok(None);
139+
}
140+
141+
let mask = node.child(1);
142+
let mut masked_children = Vec::with_capacity(input.child_count());
143+
for child_idx in 0..input.child_count() {
144+
let masked_child = ctx.new_node(
145+
Mask.bind(EmptyOptions),
146+
&[input.child(child_idx), Arc::clone(&mask)],
147+
)?;
148+
masked_children.push(masked_child);
149+
}
150+
151+
Ok(Some(
152+
ctx.new_node(input_scalar_fn.clone(), &masked_children)?,
153+
))
154+
}
155+
114156
fn simplify(
115157
&self,
116158
_options: &Self::Options,
@@ -193,12 +235,27 @@ fn execute_canonical(
193235
mod test {
194236
use vortex_error::VortexExpect;
195237

238+
use super::Mask;
239+
use crate::IntoArray;
240+
use crate::arrays::BoolArray;
241+
use crate::arrays::ScalarFnVTable;
242+
use crate::arrays::scalar_fn::ScalarFnArrayExt;
243+
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
196244
use crate::dtype::DType;
245+
use crate::dtype::Nullability;
197246
use crate::dtype::Nullability::Nullable;
198247
use crate::dtype::PType;
248+
use crate::dtype::StructFields;
249+
use crate::expr::col;
250+
use crate::expr::is_null;
199251
use crate::expr::lit;
200252
use crate::expr::mask;
253+
use crate::expr::not;
254+
use crate::optimizer::ArrayOptimizer;
201255
use crate::scalar::Scalar;
256+
use crate::scalar_fn::EmptyOptions;
257+
use crate::scalar_fn::fns::is_null::IsNull;
258+
use crate::scalar_fn::fns::not::Not;
202259

203260
#[test]
204261
fn test_simplify() {
@@ -219,4 +276,90 @@ mod test {
219276
let expected_null_expr = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
220277
assert_eq!(&simplified_false, &expected_null_expr);
221278
}
279+
280+
#[test]
281+
fn test_reduce_pushdown_expression_not() {
282+
let scope = DType::Struct(
283+
StructFields::new(
284+
["bool1", "m"].into(),
285+
vec![
286+
DType::Bool(Nullability::NonNullable),
287+
DType::Bool(Nullability::NonNullable),
288+
],
289+
),
290+
Nullability::NonNullable,
291+
);
292+
293+
let expr = mask(not(col("bool1")), col("m"));
294+
let reduced = expr.optimize(&scope).vortex_expect("optimize");
295+
296+
let expected = not(mask(col("bool1"), col("m")));
297+
assert_eq!(reduced, expected);
298+
}
299+
300+
#[test]
301+
fn test_reduce_no_pushdown_expression_null_sensitive() {
302+
let scope = DType::Struct(
303+
StructFields::new(
304+
["bool1", "m"].into(),
305+
vec![
306+
DType::Bool(Nullability::NonNullable),
307+
DType::Bool(Nullability::NonNullable),
308+
],
309+
),
310+
Nullability::NonNullable,
311+
);
312+
313+
let expr = mask(is_null(col("bool1")), col("m"));
314+
let reduced = expr.optimize(&scope).vortex_expect("optimize");
315+
assert_eq!(reduced, expr);
316+
}
317+
318+
#[test]
319+
fn test_reduce_pushdown_array_not() {
320+
let values = BoolArray::from_iter([true, false, true]).into_array();
321+
let mask_values = BoolArray::from_iter([true, false, true]).into_array();
322+
323+
let not_array = Not
324+
.try_new_array(values.len(), EmptyOptions, [values])
325+
.vortex_expect("not array");
326+
let mask_array = Mask
327+
.try_new_array(mask_values.len(), EmptyOptions, [not_array, mask_values])
328+
.vortex_expect("mask array");
329+
330+
let reduced = mask_array.optimize().vortex_expect("optimize");
331+
let reduced_sfn = reduced
332+
.as_opt::<ScalarFnVTable>()
333+
.vortex_expect("expected scalar fn root");
334+
assert!(reduced_sfn.scalar_fn().is::<Not>());
335+
336+
let child = reduced_sfn.child_at(0);
337+
let child_sfn = child
338+
.as_opt::<ScalarFnVTable>()
339+
.vortex_expect("expected masked child");
340+
assert!(child_sfn.scalar_fn().is::<Mask>());
341+
}
342+
343+
#[test]
344+
fn test_reduce_no_pushdown_array_null_sensitive() {
345+
let values = BoolArray::from_iter([true, false, true]).into_array();
346+
let mask_values = BoolArray::from_iter([true, false, true]).into_array();
347+
348+
let is_null_array = IsNull
349+
.try_new_array(values.len(), EmptyOptions, [values])
350+
.vortex_expect("is_null array");
351+
let mask_array = Mask
352+
.try_new_array(
353+
mask_values.len(),
354+
EmptyOptions,
355+
[is_null_array, mask_values],
356+
)
357+
.vortex_expect("mask array");
358+
359+
let reduced = mask_array.optimize().vortex_expect("optimize");
360+
let reduced_sfn = reduced
361+
.as_opt::<ScalarFnVTable>()
362+
.vortex_expect("expected scalar fn root");
363+
assert!(reduced_sfn.scalar_fn().is::<Mask>());
364+
}
222365
}

0 commit comments

Comments
 (0)