Skip to content

Commit 967066b

Browse files
committed
Support returning row count in pruning aggregate expressions
This lets us effectively prune expressions like IsNotNull. Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 1838a7a commit 967066b

23 files changed

Lines changed: 889 additions & 180 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

java/testfiles/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust-toolchain.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[toolchain]
22
channel = "1.91.0"
33
components = ["rust-src", "rustfmt", "clippy", "rust-analyzer"]
4-
profile = "minimal"
4+
profile = "minimal"

vortex-array/public-api.lock

Lines changed: 198 additions & 0 deletions
Large diffs are not rendered by default.

vortex-array/src/aggregate_fn/fns/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ pub mod is_sorted;
88
pub mod last;
99
pub mod min_max;
1010
pub mod nan_count;
11+
pub mod row_count;
1112
pub mod sum;
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexExpect;
5+
use vortex_error::VortexResult;
6+
7+
use crate::ArrayRef;
8+
use crate::Columnar;
9+
use crate::ExecutionCtx;
10+
use crate::aggregate_fn::AggregateFnId;
11+
use crate::aggregate_fn::AggregateFnVTable;
12+
use crate::aggregate_fn::EmptyOptions;
13+
use crate::dtype::DType;
14+
use crate::dtype::Nullability;
15+
use crate::dtype::PType;
16+
use crate::scalar::Scalar;
17+
18+
/// Aggregate function that counts every row of its input, nulls included.
///
/// The result is a non-nullable `u64` for any input dtype, and the identity
/// (empty) value is zero.
///
/// This differs from [`Count`][crate::aggregate_fn::fns::count::Count] in that
/// null elements are counted too. Its main use is in pruning predicates, which
/// compare a stored statistic against the row count of the scope currently
/// being evaluated.
#[derive(Clone, Debug)]
pub struct RowCount;
28+
29+
impl AggregateFnVTable for RowCount {
30+
type Options = EmptyOptions;
31+
type Partial = u64;
32+
33+
fn id(&self) -> AggregateFnId {
34+
AggregateFnId::new("vortex.row_count")
35+
}
36+
37+
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
38+
unimplemented!("RowCount is not yet serializable");
39+
}
40+
41+
fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option<DType> {
42+
Some(DType::Primitive(PType::U64, Nullability::NonNullable))
43+
}
44+
45+
fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
46+
self.return_dtype(options, input_dtype)
47+
}
48+
49+
fn empty_partial(
50+
&self,
51+
_options: &Self::Options,
52+
_input_dtype: &DType,
53+
) -> VortexResult<Self::Partial> {
54+
Ok(0u64)
55+
}
56+
57+
fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
58+
let val = other
59+
.as_primitive()
60+
.typed_value::<u64>()
61+
.vortex_expect("row_count partial should not be null");
62+
*partial += val;
63+
Ok(())
64+
}
65+
66+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
67+
Ok(Scalar::primitive(*partial, Nullability::NonNullable))
68+
}
69+
70+
fn reset(&self, partial: &mut Self::Partial) {
71+
*partial = 0;
72+
}
73+
74+
#[inline]
75+
fn is_saturated(&self, _partial: &Self::Partial) -> bool {
76+
false
77+
}
78+
79+
fn try_accumulate(
80+
&self,
81+
state: &mut Self::Partial,
82+
batch: &ArrayRef,
83+
_ctx: &mut ExecutionCtx,
84+
) -> VortexResult<bool> {
85+
*state += batch.len() as u64;
86+
Ok(true)
87+
}
88+
89+
fn accumulate(
90+
&self,
91+
_partial: &mut Self::Partial,
92+
_batch: &Columnar,
93+
_ctx: &mut ExecutionCtx,
94+
) -> VortexResult<()> {
95+
unreachable!("RowCount::try_accumulate handles all arrays")
96+
}
97+
98+
fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
99+
Ok(partials)
100+
}
101+
102+
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
103+
self.to_scalar(partial)
104+
}
105+
}
106+
107+
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::row_count::RowCount;
    use crate::arrays::PrimitiveArray;
    use crate::session::ArraySession;

    // Session shared by every test in this module.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Feeds `array` through a fresh `RowCount` accumulator and returns the
    /// finished count.
    fn count_rows(array: crate::ArrayRef) -> VortexResult<Option<u64>> {
        let mut accumulator = Accumulator::try_new(RowCount, EmptyOptions, array.dtype().clone())?;
        let mut ctx = SESSION.create_execution_ctx();
        accumulator.accumulate(&array, &mut ctx)?;
        let total = accumulator.finish()?;
        Ok(total.as_primitive().typed_value::<u64>())
    }

    /// Five non-null values count as five rows.
    #[test]
    fn row_count_all_valid() -> VortexResult<()> {
        let values = buffer![1i32, 2, 3, 4, 5].into_array();
        assert_eq!(count_rows(values)?, Some(5));
        Ok(())
    }

    /// Null entries are counted just like valid ones.
    #[test]
    fn row_count_includes_nulls() -> VortexResult<()> {
        let values = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
            .into_array();
        assert_eq!(count_rows(values)?, Some(5));
        Ok(())
    }
}

vortex-array/src/aggregate_fn/session.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::aggregate_fn::fns::is_sorted::IsSorted;
1818
use crate::aggregate_fn::fns::last::Last;
1919
use crate::aggregate_fn::fns::min_max::MinMax;
2020
use crate::aggregate_fn::fns::nan_count::NanCount;
21+
use crate::aggregate_fn::fns::row_count::RowCount;
2122
use crate::aggregate_fn::fns::sum::Sum;
2223
use crate::aggregate_fn::kernels::DynAggregateKernel;
2324
use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
@@ -59,6 +60,7 @@ impl Default for AggregateFnSession {
5960
this.register(Last);
6061
this.register(MinMax);
6162
this.register(NanCount);
63+
this.register(RowCount);
6264
this.register(Sum);
6365

6466
// Register the built-in aggregate kernels.

vortex-array/src/expr/exprs.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ use vortex_error::VortexExpect;
99
use vortex_error::vortex_panic;
1010
use vortex_utils::iter::ReduceBalancedIterExt;
1111

12+
use crate::aggregate_fn::AggregateFnRef;
13+
use crate::aggregate_fn::AggregateFnVTableExt;
14+
use crate::aggregate_fn::fns::row_count::RowCount;
1215
use crate::dtype::DType;
1316
use crate::dtype::FieldName;
1417
use crate::dtype::FieldNames;
@@ -46,6 +49,7 @@ use crate::scalar_fn::fns::pack::PackOptions;
4649
use crate::scalar_fn::fns::root::Root;
4750
use crate::scalar_fn::fns::select::FieldSelection;
4851
use crate::scalar_fn::fns::select::Select;
52+
use crate::scalar_fn::fns::stats_expression::StatsExpression;
4953
use crate::scalar_fn::fns::zip::Zip;
5054

5155
// ---- Root ----
@@ -701,3 +705,21 @@ pub fn dynamic(
701705
pub fn list_contains(list: Expression, value: Expression) -> Expression {
702706
ListContains.new_expr(EmptyOptions, [list, value])
703707
}
708+
709+
// ---- StatsExpression ----
710+
711+
/// Creates a placeholder expression for an aggregate-derived scope statistic.
712+
///
713+
/// The caller that owns the evaluation scope must substitute the placeholder
714+
/// before execution; see [`StatsExpression`].
715+
pub fn stats_expression(agg: AggregateFnRef) -> Expression {
716+
StatsExpression.new_expr(agg, [])
717+
}
718+
719+
/// Creates a placeholder for the current evaluation scope's row count.
720+
///
721+
/// This is used by pruning rewrites that need to compare a stored statistic with
722+
/// the number of rows in the file, zone, or other scope being evaluated.
723+
pub fn row_count() -> Expression {
724+
stats_expression(RowCount.bind(crate::aggregate_fn::EmptyOptions))
725+
}

vortex-array/src/expr/pruning/pruning_expr.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa
8686
/// cannot hold, and false if it cannot be determined from stats alone whether the positions can
8787
/// be pruned.
8888
///
89+
/// Some rewrites, such as `is_not_null(...)`, emit
90+
/// [`row_count`][crate::expr::row_count] placeholders. The evaluation layer must
91+
/// replace those placeholders with the row count for its current scope before
92+
/// executing the returned expression.
93+
///
8994
/// If the falsification logic attempts to access an unknown stat,
9095
/// this function will return `None`.
9196
pub fn checked_pruning_expr(

vortex-array/src/scalar_fn/fns/is_not_null.rs

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@ use crate::dtype::DType;
1414
use crate::dtype::Nullability;
1515
use crate::expr::Expression;
1616
use crate::expr::StatsCatalog;
17-
use crate::expr::and;
1817
use crate::expr::eq;
19-
use crate::expr::gt;
20-
use crate::expr::lit;
18+
use crate::expr::row_count;
2119
use crate::expr::stats::Stat;
2220
use crate::scalar_fn::Arity;
2321
use crate::scalar_fn::ChildName;
@@ -106,20 +104,10 @@ impl ScalarFnVTable for IsNotNull {
106104
expr: &Expression,
107105
catalog: &dyn StatsCatalog,
108106
) -> Option<Expression> {
109-
// is_not_null is falsified when ALL values are null, i.e. null_count == len.
110-
// Since there is no len stat in the zone map, we approximate using IsConstant:
111-
// if the zone is constant and has any nulls, then all values must be null.
112-
//
113-
// TODO(#7187): Add a len stat to enable the more general falsification:
114-
// null_count == len => is_not_null is all false.
115-
let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?;
116-
let is_constant_expr = expr.child(0).stat_expression(Stat::IsConstant, catalog)?;
117-
// If the zone is constant (is_constant == true) and has nulls (null_count > 0),
118-
// then all values must be null, so is_not_null is all false.
119-
Some(and(
120-
eq(is_constant_expr, lit(true)),
121-
gt(null_count_expr, lit(0u64)),
122-
))
107+
// is_not_null is falsified when ALL values are null, i.e. null_count == row_count.
108+
let child = expr.child(0);
109+
let null_count_expr = child.stat_expression(Stat::NullCount, catalog)?;
110+
Some(eq(null_count_expr, row_count()))
123111
}
124112
}
125113

@@ -267,38 +255,27 @@ mod tests {
267255
use crate::dtype::Field;
268256
use crate::dtype::FieldPath;
269257
use crate::dtype::FieldPathSet;
270-
use crate::expr::and;
271258
use crate::expr::col;
272259
use crate::expr::eq;
273-
use crate::expr::gt;
274-
use crate::expr::lit;
275260
use crate::expr::pruning::checked_pruning_expr;
261+
use crate::expr::row_count;
276262
use crate::expr::stats::Stat;
277263

278264
let expr = is_not_null(col("a"));
279265

280266
let (pruning_expr, st) = checked_pruning_expr(
281267
&expr,
282-
&FieldPathSet::from_iter([
283-
FieldPath::from_iter([Field::Name("a".into()), Field::Name("null_count".into())]),
284-
FieldPath::from_iter([Field::Name("a".into()), Field::Name("is_constant".into())]),
285-
]),
268+
&FieldPathSet::from_iter([FieldPath::from_iter([
269+
Field::Name("a".into()),
270+
Field::Name("null_count".into()),
271+
])]),
286272
)
287273
.unwrap();
288274

289-
assert_eq!(
290-
&pruning_expr,
291-
&and(
292-
eq(col("a_is_constant"), lit(true)),
293-
gt(col("a_null_count"), lit(0u64)),
294-
)
295-
);
275+
assert_eq!(&pruning_expr, &eq(col("a_null_count"), row_count()));
296276
assert_eq!(
297277
st.map(),
298-
&HashMap::from_iter([(
299-
FieldPath::from_name("a"),
300-
HashSet::from([Stat::NullCount, Stat::IsConstant])
301-
)])
278+
&HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))])
302279
);
303280
}
304281
}

0 commit comments

Comments
 (0)