Skip to content

Commit 033c73f

Browse files
committed
Centralize stat expression binding
Signed-off-by: "Nicholas Gates" <nick@nickgates.com>
1 parent 02319b3 commit 033c73f

5 files changed

Lines changed: 256 additions & 346 deletions

File tree

vortex-array/src/expr/pruning/pruning_expr.rs

Lines changed: 39 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@ use vortex_session::VortexSession;
1010
use vortex_utils::aliases::hash_map::HashMap;
1111

1212
use super::relation::Relation;
13-
use crate::aggregate_fn::fns::all_nan::AllNan;
14-
use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
15-
use crate::aggregate_fn::fns::all_non_null::AllNonNull;
16-
use crate::aggregate_fn::fns::all_null::AllNull;
17-
use crate::aggregate_fn::fns::nan_count::NanCount;
1813
use crate::dtype::DType;
1914
use crate::dtype::Field;
2015
use crate::dtype::FieldName;
@@ -23,18 +18,11 @@ use crate::dtype::FieldPathSet;
2318
use crate::expr::Expression;
2419
use crate::expr::StatsCatalog;
2520
use crate::expr::analysis::referenced_field_paths;
26-
use crate::expr::eq;
2721
use crate::expr::get_item;
28-
use crate::expr::lit;
2922
use crate::expr::root;
3023
use crate::expr::stats::Stat;
31-
use crate::expr::traversal::NodeExt;
32-
use crate::expr::traversal::Transformed;
33-
use crate::scalar::Scalar;
34-
use crate::scalar_fn::EmptyOptions;
35-
use crate::scalar_fn::ScalarFnVTableExt;
36-
use crate::scalar_fn::fns::stat::StatFn;
37-
use crate::scalar_fn::internal::row_count::RowCount;
24+
use crate::stats::bind::StatBinder;
25+
use crate::stats::bind::bind_stats;
3826

3927
pub type RequiredStats = Relation<FieldPath, Stat>;
4028

@@ -146,146 +134,54 @@ pub fn checked_pruning_expr_with_session(
146134
return Ok(None);
147135
};
148136

149-
lower_stat_fns(predicate, scope, available_stats)
150-
}
151-
152-
fn lower_stat_fns(
153-
predicate: Expression,
154-
scope: &DType,
155-
available_stats: &FieldPathSet,
156-
) -> VortexResult<Option<(Expression, RequiredStats)>> {
157-
let mut required_stats = Relation::new();
158-
let mut missing_stat = false;
159-
let lowered = predicate
160-
.transform_down(|expr| {
161-
if !expr.is::<StatFn>() {
162-
return Ok(Transformed::no(expr));
163-
}
164-
165-
if let Some(lowered) =
166-
lower_stat_fn(&expr, scope, available_stats, &mut required_stats)?
167-
{
168-
return Ok(Transformed::yes(lowered));
169-
}
170-
171-
missing_stat = true;
172-
let dtype = expr.return_dtype(scope)?;
173-
Ok(Transformed::yes(null_expr(dtype)))
174-
})?
175-
.into_inner();
176-
177-
if missing_stat {
137+
let mut binder = RequiredStatsBinder {
138+
scope,
139+
available_stats,
140+
required_stats: Relation::new(),
141+
};
142+
let Some(lowered) = bind_stats(predicate, &mut binder)? else {
178143
return Ok(None);
179-
}
144+
};
180145

181-
Ok(Some((lowered, required_stats)))
146+
Ok(Some((lowered, binder.required_stats)))
182147
}
183148

184-
fn lower_stat_fn(
185-
expr: &Expression,
186-
scope: &DType,
187-
available_stats: &FieldPathSet,
188-
required_stats: &mut RequiredStats,
189-
) -> VortexResult<Option<Expression>> {
190-
let options = expr.as_::<StatFn>();
191-
let aggregate_fn = options.aggregate_fn();
192-
let input = expr.child(0);
193-
let input_dtype = input.return_dtype(scope)?;
194-
195-
if aggregate_fn.is::<AllNan>() {
196-
if !has_nans(&input_dtype) {
197-
return Ok(Some(lit(false)));
198-
}
199-
return lower_stat_ref(
200-
input,
201-
Stat::NaNCount,
202-
scope,
203-
available_stats,
204-
required_stats,
205-
)
206-
.map(|stat| stat.map(|stat| eq(stat, row_count_expr())));
207-
}
208-
209-
if aggregate_fn.is::<AllNonNan>() {
210-
if !has_nans(&input_dtype) {
211-
return Ok(Some(lit(true)));
212-
}
213-
return lower_stat_ref(
214-
input,
215-
Stat::NaNCount,
216-
scope,
217-
available_stats,
218-
required_stats,
219-
)
220-
.map(|stat| stat.map(|stat| eq(stat, lit(0u64))));
221-
}
149+
struct RequiredStatsBinder<'a> {
150+
scope: &'a DType,
151+
available_stats: &'a FieldPathSet,
152+
required_stats: RequiredStats,
153+
}
222154

223-
if aggregate_fn.is::<NanCount>() && !has_nans(&input_dtype) {
224-
return Ok(Some(lit(0u64)));
155+
impl StatBinder for RequiredStatsBinder<'_> {
156+
fn scope(&self) -> &DType {
157+
self.scope
225158
}
226159

227-
if aggregate_fn.is::<AllNull>() {
228-
return lower_stat_ref(
229-
input,
230-
Stat::NullCount,
231-
scope,
232-
available_stats,
233-
required_stats,
234-
)
235-
.map(|stat| stat.map(|stat| eq(stat, row_count_expr())));
236-
}
160+
fn bind_stat(
161+
&mut self,
162+
input: &Expression,
163+
stat: Stat,
164+
_stat_dtype: &DType,
165+
) -> VortexResult<Option<Expression>> {
166+
let field_paths = referenced_field_paths(input, self.scope)?;
167+
let Some(field_path) = field_paths.iter().exactly_one().ok() else {
168+
return Ok(None);
169+
};
170+
let stat_path = field_path.clone().push(stat.name());
171+
if !self.available_stats.contains(&stat_path) {
172+
return Ok(None);
173+
}
237174

238-
if aggregate_fn.is::<AllNonNull>() {
239-
return lower_stat_ref(
240-
input,
241-
Stat::NullCount,
242-
scope,
243-
available_stats,
244-
required_stats,
245-
)
246-
.map(|stat| stat.map(|stat| eq(stat, lit(0u64))));
175+
self.required_stats.insert(field_path.clone(), stat);
176+
Ok(Some(get_item(
177+
field_path_stat_field_name(field_path, stat),
178+
root(),
179+
)))
247180
}
248181

249-
let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else {
250-
return Ok(None);
251-
};
252-
253-
lower_stat_ref(input, stat, scope, available_stats, required_stats)
254-
}
255-
256-
fn lower_stat_ref(
257-
input: &Expression,
258-
stat: Stat,
259-
scope: &DType,
260-
available_stats: &FieldPathSet,
261-
required_stats: &mut RequiredStats,
262-
) -> VortexResult<Option<Expression>> {
263-
let field_paths = referenced_field_paths(input, scope)?;
264-
let Some(field_path) = field_paths.iter().exactly_one().ok() else {
265-
return Ok(None);
266-
};
267-
let stat_path = field_path.clone().push(stat.name());
268-
if !available_stats.contains(&stat_path) {
269-
return Ok(None);
182+
fn missing_stat(&mut self, _dtype: DType) -> VortexResult<Option<Expression>> {
183+
Ok(None)
270184
}
271-
272-
required_stats.insert(field_path.clone(), stat);
273-
Ok(Some(get_item(
274-
field_path_stat_field_name(field_path, stat),
275-
root(),
276-
)))
277-
}
278-
279-
fn row_count_expr() -> Expression {
280-
RowCount.new_expr(EmptyOptions, [])
281-
}
282-
283-
fn null_expr(dtype: DType) -> Expression {
284-
lit(Scalar::null(dtype.as_nullable()))
285-
}
286-
287-
fn has_nans(dtype: &DType) -> bool {
288-
matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
289185
}
290186

291187
#[cfg(test)]

vortex-array/src/stats/bind.rs

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Bind abstract `vortex.stat` expressions to a concrete stats representation.
5+
6+
use vortex_error::VortexResult;
7+
8+
use crate::aggregate_fn::fns::all_nan::AllNan;
9+
use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
10+
use crate::aggregate_fn::fns::all_non_null::AllNonNull;
11+
use crate::aggregate_fn::fns::all_null::AllNull;
12+
use crate::aggregate_fn::fns::nan_count::NanCount;
13+
use crate::dtype::DType;
14+
use crate::expr::Expression;
15+
use crate::expr::eq;
16+
use crate::expr::lit;
17+
use crate::expr::stats::Stat;
18+
use crate::expr::traversal::NodeExt;
19+
use crate::expr::traversal::Transformed;
20+
use crate::scalar::Scalar;
21+
use crate::scalar_fn::EmptyOptions;
22+
use crate::scalar_fn::ScalarFnVTableExt;
23+
use crate::scalar_fn::fns::stat::StatFn;
24+
use crate::scalar_fn::internal::row_count::RowCount;
25+
26+
/// A target that can bind abstract statistics to concrete expressions.
27+
pub trait StatBinder {
28+
/// The dtype scope used to type-check expressions before stats are bound.
29+
fn scope(&self) -> &DType;
30+
31+
/// Bind `stat(input)` to a concrete expression.
32+
///
33+
/// Returning `Ok(None)` marks the stat as unavailable. [`bind_stats`] will
34+
/// then call [`Self::missing_stat`] with the dtype expected from the
35+
/// original `vortex.stat` expression.
36+
fn bind_stat(
37+
&mut self,
38+
input: &Expression,
39+
stat: Stat,
40+
stat_dtype: &DType,
41+
) -> VortexResult<Option<Expression>>;
42+
43+
/// Expression to use when a stat is unavailable.
44+
///
45+
/// The default is a nullable null literal, which preserves three-valued
46+
/// pruning semantics for stats-table execution. Catalog-like binders can
47+
/// return `Ok(None)` to reject expressions that require unavailable stats.
48+
fn missing_stat(&mut self, dtype: DType) -> VortexResult<Option<Expression>> {
49+
Ok(Some(null_expr(dtype)))
50+
}
51+
}
52+
53+
/// Bind all `vortex.stat` expressions in `predicate`.
54+
///
55+
/// The predicate is usually the output of a stats rewrite rule. This function
56+
/// centralizes the legacy aggregate/stat mapping: `all_null` and `all_nan`
57+
/// style aggregate expressions are expanded through exact count stats, while
58+
/// direct aggregate stats are delegated to the supplied binder.
59+
pub fn bind_stats(
60+
predicate: Expression,
61+
binder: &mut impl StatBinder,
62+
) -> VortexResult<Option<Expression>> {
63+
let scope = binder.scope().clone();
64+
let mut missing_stat = false;
65+
let lowered = predicate
66+
.transform_down(|expr| {
67+
if !expr.is::<StatFn>() {
68+
return Ok(Transformed::no(expr));
69+
}
70+
71+
match bind_stat_fn(&expr, &scope, binder)? {
72+
Some(bound) => Ok(Transformed::yes(bound)),
73+
None => {
74+
let dtype = expr.return_dtype(&scope)?;
75+
match binder.missing_stat(dtype.clone())? {
76+
Some(missing) => Ok(Transformed::yes(missing)),
77+
None => {
78+
missing_stat = true;
79+
Ok(Transformed::yes(null_expr(dtype)))
80+
}
81+
}
82+
}
83+
}
84+
})?
85+
.into_inner();
86+
87+
if missing_stat {
88+
return Ok(None);
89+
}
90+
91+
Ok(Some(lowered))
92+
}
93+
94+
fn bind_stat_fn(
95+
expr: &Expression,
96+
scope: &DType,
97+
binder: &mut impl StatBinder,
98+
) -> VortexResult<Option<Expression>> {
99+
let options = expr.as_::<StatFn>();
100+
let aggregate_fn = options.aggregate_fn();
101+
let input = expr.child(0);
102+
let input_dtype = input.return_dtype(scope)?;
103+
104+
if aggregate_fn.is::<AllNan>() {
105+
if !has_nans(&input_dtype) {
106+
return Ok(Some(lit(false)));
107+
}
108+
let stat_dtype = expr.return_dtype(scope)?;
109+
return Ok(binder
110+
.bind_stat(input, Stat::NaNCount, &stat_dtype)?
111+
.map(|stat| eq(stat, row_count_expr())));
112+
}
113+
114+
if aggregate_fn.is::<AllNonNan>() {
115+
if !has_nans(&input_dtype) {
116+
return Ok(Some(lit(true)));
117+
}
118+
let stat_dtype = expr.return_dtype(scope)?;
119+
return Ok(binder
120+
.bind_stat(input, Stat::NaNCount, &stat_dtype)?
121+
.map(|stat| eq(stat, lit(0u64))));
122+
}
123+
124+
if aggregate_fn.is::<NanCount>() && !has_nans(&input_dtype) {
125+
return Ok(Some(lit(0u64)));
126+
}
127+
128+
if aggregate_fn.is::<AllNull>() {
129+
let stat_dtype = expr.return_dtype(scope)?;
130+
return Ok(binder
131+
.bind_stat(input, Stat::NullCount, &stat_dtype)?
132+
.map(|stat| eq(stat, row_count_expr())));
133+
}
134+
135+
if aggregate_fn.is::<AllNonNull>() {
136+
let stat_dtype = expr.return_dtype(scope)?;
137+
return Ok(binder
138+
.bind_stat(input, Stat::NullCount, &stat_dtype)?
139+
.map(|stat| eq(stat, lit(0u64))));
140+
}
141+
142+
let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else {
143+
return Ok(None);
144+
};
145+
146+
let stat_dtype = expr.return_dtype(scope)?;
147+
binder.bind_stat(input, stat, &stat_dtype)
148+
}
149+
150+
fn row_count_expr() -> Expression {
151+
RowCount.new_expr(EmptyOptions, [])
152+
}
153+
154+
fn null_expr(dtype: DType) -> Expression {
155+
lit(Scalar::null(dtype.as_nullable()))
156+
}
157+
158+
fn has_nans(dtype: &DType) -> bool {
159+
matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
160+
}

vortex-array/src/stats/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub use expr::sum;
1919
pub use stats_set::*;
2020

2121
mod array;
22+
pub mod bind;
2223
pub mod expr;
2324
pub mod flatbuffers;
2425
pub mod rewrite;

0 commit comments

Comments
 (0)