Skip to content

Commit 5d6bbdb

Browse files
committed
fix[scalar_fns]: add correct pre-condition checks for all custom scalar_fn pushdown
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 2a0d1b7 commit 5d6bbdb

7 files changed

Lines changed: 71 additions & 0 deletions

File tree

vortex-array/src/scalar_fn/fns/binary/boolean.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use vortex_error::VortexResult;
77
use vortex_error::vortex_err;
88

99
use crate::ArrayRef;
10+
use crate::Canonical;
1011
use crate::DynArray;
1112
use crate::IntoArray;
1213
use crate::arrays::Constant;
@@ -39,6 +40,10 @@ pub(crate) fn execute_boolean(
3940
rhs: &ArrayRef,
4041
op: Operator,
4142
) -> VortexResult<ArrayRef> {
43+
if lhs.is_empty() {
44+
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
45+
return Ok(Canonical::empty(&DType::Bool(nullable.into())).into_array());
46+
}
4247
if let Some(result) = constant_boolean(lhs, rhs, op)? {
4348
return Ok(result);
4449
}

vortex-array/src/scalar_fn/fns/binary/numeric.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use vortex_error::VortexResult;
55

66
use crate::ArrayRef;
7+
use crate::Canonical;
78
use crate::IntoArray;
89
use crate::arrays::Constant;
910
use crate::arrays::ConstantArray;
@@ -20,6 +21,10 @@ pub(crate) fn execute_numeric(
2021
rhs: &ArrayRef,
2122
op: NumericOperator,
2223
) -> VortexResult<ArrayRef> {
24+
if lhs.is_empty() {
25+
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
26+
return Ok(Canonical::empty(&lhs.dtype().with_nullability(nullable.into())).into_array());
27+
}
2328
if let Some(result) = constant_numeric(lhs, rhs, op)? {
2429
return Ok(result);
2530
}

vortex-array/src/scalar_fn/fns/cast/kernel.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use vortex_error::VortexResult;
55

66
use crate::ArrayRef;
7+
use crate::Canonical;
78
use crate::ExecutionCtx;
89
use crate::IntoArray;
910
use crate::arrays::scalar_fn::ExactScalarFn;
@@ -60,6 +61,9 @@ where
6061
if array.dtype() == dtype {
6162
return Ok(Some(array.clone().into_array()));
6263
}
64+
if array.len() == 0 {
65+
return Ok(Some(Canonical::empty(dtype).into_array()));
66+
}
6367
<V as CastReduce>::cast(array, dtype)
6468
}
6569
}
@@ -85,6 +89,9 @@ where
8589
if array.dtype() == dtype {
8690
return Ok(Some(array.clone().into_array()));
8791
}
92+
if array.len() == 0 {
93+
return Ok(Some(Canonical::empty(dtype).into_array()));
94+
}
8895
<V as CastKernel>::cast(array, dtype, ctx)
8996
}
9097
}

vortex-array/src/scalar_fn/fns/cast/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ use vortex_session::VortexSession;
1515

1616
use crate::AnyColumnar;
1717
use crate::ArrayRef;
18+
use crate::Canonical;
1819
use crate::CanonicalView;
1920
use crate::ColumnarView;
2021
use crate::ExecutionCtx;
22+
use crate::IntoArray;
2123
use crate::arrays::Bool;
2224
use crate::arrays::Constant;
2325
use crate::arrays::ConstantArray;
@@ -113,6 +115,10 @@ impl ScalarFnVTable for Cast {
113115
return input.execute::<ArrayRef>(ctx)?.cast(target_dtype.clone());
114116
};
115117

118+
if columnar.as_ref().is_empty() {
119+
return Ok(Canonical::empty(target_dtype).into_array());
120+
}
121+
116122
match columnar {
117123
ColumnarView::Canonical(canonical) => {
118124
match cast_canonical(canonical.clone(), target_dtype, ctx)? {

vortex-array/src/scalar_fn/fns/fill_null/kernel.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
use vortex_error::VortexExpect;
55
use vortex_error::VortexResult;
6+
use vortex_error::vortex_ensure;
67

78
use crate::ArrayRef;
89
use crate::ExecutionCtx;
@@ -57,6 +58,11 @@ pub(super) fn precondition(
5758
array: &ArrayRef,
5859
fill_value: &Scalar,
5960
) -> VortexResult<Option<ArrayRef>> {
61+
vortex_ensure!(
62+
!fill_value.is_null(),
63+
"fill_null requires a non-null fill value"
64+
);
65+
6066
// If the array has no nulls, fill_null is a no-op (just cast for nullability).
6167
if !array.dtype().is_nullable() || array.all_valid()? {
6268
return array.to_array().cast(fill_value.dtype().clone()).map(Some);

vortex-array/src/scalar_fn/fns/fill_null/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ impl ScalarFnVTable for FillNull {
105105
.as_constant()
106106
.ok_or_else(|| vortex_err!("fill_null fill_value must be a constant/scalar"))?;
107107

108+
vortex_ensure!(
109+
!fill_scalar.is_null(),
110+
"fill_null requires a non-null fill value"
111+
);
112+
108113
let Some(columnar) = input.as_opt::<AnyColumnar>() else {
109114
return input.execute::<ArrayRef>(ctx)?.fill_null(fill_scalar);
110115
};

vortex-array/src/scalar_fn/fns/mask/kernel.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@ use vortex_error::vortex_err;
66

77
use crate::ArrayRef;
88
use crate::ExecutionCtx;
9+
use crate::IntoArray;
910
use crate::arrays::Bool;
11+
use crate::arrays::Constant;
12+
use crate::arrays::ConstantArray;
1013
use crate::arrays::scalar_fn::ExactScalarFn;
1114
use crate::arrays::scalar_fn::ScalarFnArrayView;
15+
use crate::builtins::ArrayBuiltins;
1216
use crate::kernel::ExecuteParentKernel;
1317
use crate::optimizer::rules::ArrayParentReduceRule;
18+
use crate::scalar::Scalar;
1419
use crate::scalar_fn::fns::mask::Mask as MaskExpr;
1520
use crate::vtable::VTable;
1621

@@ -49,6 +54,26 @@ pub trait MaskKernel: VTable {
4954
) -> VortexResult<Option<ArrayRef>>;
5055
}
5156

57+
/// If the mask is a constant boolean, handle the trivial cases and return `Some`.
58+
/// Returns `None` if the mask is not a constant.
59+
fn handle_constant_mask(
60+
array: &dyn crate::array::DynArray,
61+
mask: &ArrayRef,
62+
) -> VortexResult<Option<ArrayRef>> {
63+
if let Some(constant_mask) = mask.as_opt::<Constant>() {
64+
let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
65+
return if mask_value {
66+
array.to_array().cast(array.dtype().as_nullable()).map(Some)
67+
} else {
68+
Ok(Some(
69+
ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.len())
70+
.into_array(),
71+
))
72+
};
73+
}
74+
Ok(None)
75+
}
76+
5277
/// Adaptor that wraps a [`MaskReduce`] impl as an [`ArrayParentReduceRule`].
5378
#[derive(Default, Debug)]
5479
pub struct MaskReduceAdaptor<V>(pub V);
@@ -74,6 +99,12 @@ where
7499
let mask_child = parent
75100
.nth_child(1)
76101
.ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
102+
103+
// Handle trivial constant mask cases before dispatching to the encoding.
104+
if let Some(result) = handle_constant_mask(&**array, &mask_child)? {
105+
return Ok(Some(result));
106+
}
107+
77108
if mask_child.as_opt::<Bool>().is_none() {
78109
return Ok(None);
79110
};
@@ -105,6 +136,12 @@ where
105136
let mask_child = parent
106137
.nth_child(1)
107138
.ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
139+
140+
// Handle trivial constant mask cases before dispatching to the encoding.
141+
if let Some(result) = handle_constant_mask(&**array, &mask_child)? {
142+
return Ok(Some(result));
143+
}
144+
108145
<V as MaskKernel>::mask(array, &mask_child, ctx)
109146
}
110147
}

0 commit comments

Comments
 (0)