Skip to content

Commit 02319b3

Browse files
committed
Make stats rewrite rules public
Port file pruning to session stats rewrites

Signed-off-by: "Nicholas Gates" <nick@nickgates.com>
1 parent 0a41704 commit 02319b3

7 files changed

Lines changed: 299 additions & 18 deletions

File tree

vortex-array/src/expr/pruning/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod relation;
66

77
pub use pruning_expr::RequiredStats;
88
pub use pruning_expr::checked_pruning_expr;
9+
pub use pruning_expr::checked_pruning_expr_with_session;
910
pub use pruning_expr::field_path_stat_field_name;
1011
pub use relation::Relation;
1112

vortex-array/src/expr/pruning/pruning_expr.rs

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,36 @@ use std::cell::RefCell;
55
use std::iter;
66

77
use itertools::Itertools;
8+
use vortex_error::VortexResult;
9+
use vortex_session::VortexSession;
810
use vortex_utils::aliases::hash_map::HashMap;
911

1012
use super::relation::Relation;
13+
use crate::aggregate_fn::fns::all_nan::AllNan;
14+
use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
15+
use crate::aggregate_fn::fns::all_non_null::AllNonNull;
16+
use crate::aggregate_fn::fns::all_null::AllNull;
17+
use crate::aggregate_fn::fns::nan_count::NanCount;
18+
use crate::dtype::DType;
1119
use crate::dtype::Field;
1220
use crate::dtype::FieldName;
1321
use crate::dtype::FieldPath;
1422
use crate::dtype::FieldPathSet;
1523
use crate::expr::Expression;
1624
use crate::expr::StatsCatalog;
25+
use crate::expr::analysis::referenced_field_paths;
26+
use crate::expr::eq;
1727
use crate::expr::get_item;
28+
use crate::expr::lit;
1829
use crate::expr::root;
1930
use crate::expr::stats::Stat;
31+
use crate::expr::traversal::NodeExt;
32+
use crate::expr::traversal::Transformed;
33+
use crate::scalar::Scalar;
34+
use crate::scalar_fn::EmptyOptions;
35+
use crate::scalar_fn::ScalarFnVTableExt;
36+
use crate::scalar_fn::fns::stat::StatFn;
37+
use crate::scalar_fn::internal::row_count::RowCount;
2038

2139
pub type RequiredStats = Relation<FieldPath, Stat>;
2240

@@ -113,6 +131,163 @@ pub fn checked_pruning_expr(
113131
Some((expr, relation))
114132
}
115133

134+
/// Build a pruning expression using session-registered stats rewrite rules.
135+
///
136+
/// The returned expression is lowered to the same stats-table field references as
137+
/// [`checked_pruning_expr`]. If a rewrite asks for a stat that is not present in
138+
/// `available_stats`, this returns `Ok(None)`.
139+
pub fn checked_pruning_expr_with_session(
140+
expr: &Expression,
141+
scope: &DType,
142+
available_stats: &FieldPathSet,
143+
session: &VortexSession,
144+
) -> VortexResult<Option<(Expression, RequiredStats)>> {
145+
let Some(predicate) = expr.falsify(scope, session)? else {
146+
return Ok(None);
147+
};
148+
149+
lower_stat_fns(predicate, scope, available_stats)
150+
}
151+
152+
fn lower_stat_fns(
153+
predicate: Expression,
154+
scope: &DType,
155+
available_stats: &FieldPathSet,
156+
) -> VortexResult<Option<(Expression, RequiredStats)>> {
157+
let mut required_stats = Relation::new();
158+
let mut missing_stat = false;
159+
let lowered = predicate
160+
.transform_down(|expr| {
161+
if !expr.is::<StatFn>() {
162+
return Ok(Transformed::no(expr));
163+
}
164+
165+
if let Some(lowered) =
166+
lower_stat_fn(&expr, scope, available_stats, &mut required_stats)?
167+
{
168+
return Ok(Transformed::yes(lowered));
169+
}
170+
171+
missing_stat = true;
172+
let dtype = expr.return_dtype(scope)?;
173+
Ok(Transformed::yes(null_expr(dtype)))
174+
})?
175+
.into_inner();
176+
177+
if missing_stat {
178+
return Ok(None);
179+
}
180+
181+
Ok(Some((lowered, required_stats)))
182+
}
183+
184+
fn lower_stat_fn(
185+
expr: &Expression,
186+
scope: &DType,
187+
available_stats: &FieldPathSet,
188+
required_stats: &mut RequiredStats,
189+
) -> VortexResult<Option<Expression>> {
190+
let options = expr.as_::<StatFn>();
191+
let aggregate_fn = options.aggregate_fn();
192+
let input = expr.child(0);
193+
let input_dtype = input.return_dtype(scope)?;
194+
195+
if aggregate_fn.is::<AllNan>() {
196+
if !has_nans(&input_dtype) {
197+
return Ok(Some(lit(false)));
198+
}
199+
return lower_stat_ref(
200+
input,
201+
Stat::NaNCount,
202+
scope,
203+
available_stats,
204+
required_stats,
205+
)
206+
.map(|stat| stat.map(|stat| eq(stat, row_count_expr())));
207+
}
208+
209+
if aggregate_fn.is::<AllNonNan>() {
210+
if !has_nans(&input_dtype) {
211+
return Ok(Some(lit(true)));
212+
}
213+
return lower_stat_ref(
214+
input,
215+
Stat::NaNCount,
216+
scope,
217+
available_stats,
218+
required_stats,
219+
)
220+
.map(|stat| stat.map(|stat| eq(stat, lit(0u64))));
221+
}
222+
223+
if aggregate_fn.is::<NanCount>() && !has_nans(&input_dtype) {
224+
return Ok(Some(lit(0u64)));
225+
}
226+
227+
if aggregate_fn.is::<AllNull>() {
228+
return lower_stat_ref(
229+
input,
230+
Stat::NullCount,
231+
scope,
232+
available_stats,
233+
required_stats,
234+
)
235+
.map(|stat| stat.map(|stat| eq(stat, row_count_expr())));
236+
}
237+
238+
if aggregate_fn.is::<AllNonNull>() {
239+
return lower_stat_ref(
240+
input,
241+
Stat::NullCount,
242+
scope,
243+
available_stats,
244+
required_stats,
245+
)
246+
.map(|stat| stat.map(|stat| eq(stat, lit(0u64))));
247+
}
248+
249+
let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else {
250+
return Ok(None);
251+
};
252+
253+
lower_stat_ref(input, stat, scope, available_stats, required_stats)
254+
}
255+
256+
fn lower_stat_ref(
257+
input: &Expression,
258+
stat: Stat,
259+
scope: &DType,
260+
available_stats: &FieldPathSet,
261+
required_stats: &mut RequiredStats,
262+
) -> VortexResult<Option<Expression>> {
263+
let field_paths = referenced_field_paths(input, scope)?;
264+
let Some(field_path) = field_paths.iter().exactly_one().ok() else {
265+
return Ok(None);
266+
};
267+
let stat_path = field_path.clone().push(stat.name());
268+
if !available_stats.contains(&stat_path) {
269+
return Ok(None);
270+
}
271+
272+
required_stats.insert(field_path.clone(), stat);
273+
Ok(Some(get_item(
274+
field_path_stat_field_name(field_path, stat),
275+
root(),
276+
)))
277+
}
278+
279+
fn row_count_expr() -> Expression {
280+
RowCount.new_expr(EmptyOptions, [])
281+
}
282+
283+
fn null_expr(dtype: DType) -> Expression {
284+
lit(Scalar::null(dtype.as_nullable()))
285+
}
286+
287+
fn has_nans(dtype: &DType) -> bool {
288+
matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
289+
}
290+
116291
#[cfg(test)]
117292
mod tests {
118293
use rstest::fixture;

vortex-array/src/stats/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub use stats_set::*;
2121
mod array;
2222
pub mod expr;
2323
pub mod flatbuffers;
24-
pub(crate) mod rewrite;
24+
pub mod rewrite;
2525
pub mod session;
2626
mod stats_set;
2727

vortex-array/src/stats/rewrite.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,15 @@ mod builtins;
2121
pub(crate) use builtins::register_builtins;
2222

2323
/// Shared reference to a stats rewrite rule.
24-
pub(crate) type StatsRewriteRuleRef = Arc<dyn StatsRewriteRule>;
24+
pub type StatsRewriteRuleRef = Arc<dyn StatsRewriteRule>;
2525

2626
/// A plugin-provided rule that rewrites predicates into stats-backed proof expressions.
2727
///
2828
/// A falsifier evaluates to `true` only when the original predicate is definitely false for the
2929
/// current stats scope. A satisfier evaluates to `true` only when the original predicate is
3030
/// definitely true for the current stats scope. Returning `None` means the rule cannot prove
3131
/// anything for the expression.
32-
#[allow(dead_code)]
33-
pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static {
32+
pub trait StatsRewriteRule: Debug + Send + Sync + 'static {
3433
/// The scalar function ID this rule applies to.
3534
fn scalar_fn_id(&self) -> ScalarFnId;
3635

@@ -58,35 +57,35 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static {
5857
}
5958

6059
/// Context passed to stats rewrite rules.
61-
pub(crate) struct StatsRewriteCtx<'a> {
60+
pub struct StatsRewriteCtx<'a> {
6261
session: &'a VortexSession,
6362
scope: &'a DType,
6463
}
6564

6665
impl<'a> StatsRewriteCtx<'a> {
6766
/// Create a rewrite context for `session`.
68-
pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self {
67+
pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self {
6968
Self { session, scope }
7069
}
7170

7271
/// Returns the session that owns the rewrite registry.
73-
pub(crate) fn session(&self) -> &'a VortexSession {
72+
pub fn session(&self) -> &'a VortexSession {
7473
self.session
7574
}
7675

7776
/// Return the dtype of `expr` within this rewrite scope.
78-
pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
77+
pub fn return_dtype(&self, expr: &Expression) -> VortexResult<DType> {
7978
expr.return_dtype(self.scope)
8079
}
8180

8281
/// Rewrite `expr` into a stats-backed falsifier.
83-
pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
82+
pub fn falsify(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
8483
self.ensure_predicate(expr)?;
8584
rewrite(expr, self, StatsRewriteRule::falsify)
8685
}
8786

8887
/// Rewrite `expr` into a stats-backed satisfier.
89-
pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
88+
pub fn satisfy(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
9089
self.ensure_predicate(expr)?;
9190
rewrite(expr, self, StatsRewriteRule::satisfy)
9291
}

vortex-array/src/stats/session.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ impl Default for StatsSession {
3737

3838
impl StatsSession {
3939
/// Register a stats rewrite rule.
40-
#[allow(dead_code)]
41-
pub(crate) fn register_rewrite<R: StatsRewriteRule>(&self, rule: R) {
40+
pub fn register_rewrite<R: StatsRewriteRule>(&self, rule: R) {
4241
self.register_rewrite_ref(Arc::new(rule));
4342
}
4443

4544
/// Register a shared stats rewrite rule.
46-
#[allow(dead_code)]
47-
pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) {
45+
pub fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) {
4846
let mut rules = self.rewrite_rules.write();
4947
let rule_id = rule.scalar_fn_id();
5048
let mut updated_rules = rules
@@ -75,7 +73,7 @@ impl SessionVar for StatsSession {
7573
}
7674

7775
/// Extension trait for accessing stats session data.
78-
pub(crate) trait StatsSessionExt: SessionExt {
76+
pub trait StatsSessionExt: SessionExt {
7977
/// Returns the stats session state.
8078
fn stats(&self) -> Ref<'_, StatsSession> {
8179
self.get::<StatsSession>()

vortex-file/src/file.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use vortex_array::dtype::FieldMask;
2222
use vortex_array::dtype::FieldPath;
2323
use vortex_array::dtype::FieldPathSet;
2424
use vortex_array::expr::Expression;
25-
use vortex_array::expr::pruning::checked_pruning_expr;
25+
use vortex_array::expr::pruning::checked_pruning_expr_with_session;
2626
use vortex_array::scalar_fn::internal::row_count::substitute_row_count;
2727
use vortex_error::VortexResult;
2828
use vortex_layout::LayoutReader;
@@ -217,7 +217,9 @@ impl VortexFile {
217217
}),
218218
);
219219

220-
let Some((predicate, required_stats)) = checked_pruning_expr(filter, &set) else {
220+
let Some((predicate, required_stats)) =
221+
checked_pruning_expr_with_session(filter, self.footer.dtype(), &set, &self.session)?
222+
else {
221223
return Ok(false);
222224
};
223225

0 commit comments

Comments
 (0)