From 5b6a86224d49e0378484d32588a27cf9d19e7a36 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 10 Jun 2026 14:49:56 -0700 Subject: [PATCH 1/8] Make stats rewrite rules public Port file pruning to session stats rewrites Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-array/src/expr/pruning/mod.rs | 1 + vortex-array/src/expr/pruning/pruning_expr.rs | 175 ++++++++++++++++++ vortex-array/src/stats/mod.rs | 2 +- vortex-array/src/stats/rewrite.rs | 16 +- vortex-array/src/stats/session.rs | 8 +- vortex-file/src/file.rs | 6 +- vortex-file/src/v2/file_stats_reader.rs | 108 ++++++++++- 7 files changed, 299 insertions(+), 17 deletions(-) diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs index 7c20508b7a8..bbcfa5942a0 100644 --- a/vortex-array/src/expr/pruning/mod.rs +++ b/vortex-array/src/expr/pruning/mod.rs @@ -6,6 +6,7 @@ mod relation; pub use pruning_expr::RequiredStats; pub use pruning_expr::checked_pruning_expr; +pub use pruning_expr::checked_pruning_expr_with_session; pub use pruning_expr::field_path_stat_field_name; pub use relation::Relation; diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 00d29fbcf99..54c9666c283 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -5,18 +5,36 @@ use std::cell::RefCell; use std::iter; use itertools::Itertools; +use vortex_error::VortexResult; +use vortex_session::VortexSession; use vortex_utils::aliases::hash_map::HashMap; use super::relation::Relation; +use crate::aggregate_fn::fns::all_nan::AllNan; +use crate::aggregate_fn::fns::all_non_nan::AllNonNan; +use crate::aggregate_fn::fns::all_non_null::AllNonNull; +use crate::aggregate_fn::fns::all_null::AllNull; +use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::dtype::DType; use crate::dtype::Field; use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::expr::Expression; use crate::expr::StatsCatalog; +use crate::expr::analysis::referenced_field_paths; +use crate::expr::eq; use crate::expr::get_item; +use crate::expr::lit; use crate::expr::root; use crate::expr::stats::Stat; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; +use crate::scalar::Scalar; +use crate::scalar_fn::EmptyOptions; +use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::fns::stat::StatFn; +use crate::scalar_fn::internal::row_count::RowCount; pub type RequiredStats = Relation; @@ -113,6 +131,163 @@ pub fn checked_pruning_expr( Some((expr, relation)) } +/// Build a pruning expression using session-registered stats rewrite rules. +/// +/// The returned expression is lowered to the same stats-table field references as +/// [`checked_pruning_expr`]. If a rewrite asks for a stat that is not present in +/// `available_stats`, this returns `Ok(None)`. +pub fn checked_pruning_expr_with_session( + expr: &Expression, + scope: &DType, + available_stats: &FieldPathSet, + session: &VortexSession, +) -> VortexResult> { + let Some(predicate) = expr.falsify(scope, session)? else { + return Ok(None); + }; + + lower_stat_fns(predicate, scope, available_stats) +} + +fn lower_stat_fns( + predicate: Expression, + scope: &DType, + available_stats: &FieldPathSet, +) -> VortexResult> { + let mut required_stats = Relation::new(); + let mut missing_stat = false; + let lowered = predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); + } + + if let Some(lowered) = + lower_stat_fn(&expr, scope, available_stats, &mut required_stats)? + { + return Ok(Transformed::yes(lowered)); + } + + missing_stat = true; + let dtype = expr.return_dtype(scope)?; + Ok(Transformed::yes(null_expr(dtype))) + })? + .into_inner(); + + if missing_stat { + return Ok(None); + } + + Ok(Some((lowered, required_stats))) +} + +fn lower_stat_fn( + expr: &Expression, + scope: &DType, + available_stats: &FieldPathSet, + required_stats: &mut RequiredStats, +) -> VortexResult> { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + let input_dtype = input.return_dtype(scope)?; + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(false))); + } + return lower_stat_ref( + input, + Stat::NaNCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(true))); + } + return lower_stat_ref( + input, + Stat::NaNCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); + } + + if aggregate_fn.is::() && !has_nans(&input_dtype) { + return Ok(Some(lit(0u64))); + } + + if aggregate_fn.is::() { + return lower_stat_ref( + input, + Stat::NullCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + return lower_stat_ref( + input, + Stat::NullCount, + scope, + available_stats, + required_stats, + ) + .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); + } + + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + + lower_stat_ref(input, stat, scope, available_stats, required_stats) +} + +fn lower_stat_ref( + input: &Expression, + stat: Stat, + scope: &DType, + available_stats: &FieldPathSet, + required_stats: &mut RequiredStats, +) -> VortexResult> { + let field_paths = referenced_field_paths(input, scope)?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + let stat_path = field_path.clone().push(stat.name()); + if !available_stats.contains(&stat_path) { + return Ok(None); + } + + required_stats.insert(field_path.clone(), stat); + Ok(Some(get_item( + field_path_stat_field_name(field_path, stat), + root(), + ))) +} + +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} + #[cfg(test)] mod tests { use rstest::fixture; diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index ceb085e0815..3d4cfeb6111 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -21,7 +21,7 @@ pub use stats_set::*; mod array; pub mod expr; pub mod flatbuffers; -pub(crate) mod rewrite; +pub mod rewrite; pub mod session; mod stats_set; diff --git a/vortex-array/src/stats/rewrite.rs b/vortex-array/src/stats/rewrite.rs index 52d354df1a0..dfba62ded40 100644 --- a/vortex-array/src/stats/rewrite.rs +++ b/vortex-array/src/stats/rewrite.rs @@ -21,7 +21,7 @@ mod builtins; pub(crate) use builtins::register_builtins; /// Shared reference to a stats rewrite rule. -pub(crate) type StatsRewriteRuleRef = Arc; +pub type StatsRewriteRuleRef = Arc; /// A plugin-provided rule for predicates whose root scalar function matches this rule. /// @@ -40,7 +40,7 @@ pub(crate) type StatsRewriteRuleRef = Arc; /// `expr` is the full predicate expression whose root scalar function id is /// [`Self::scalar_fn_id`]. Use [`StatsRewriteCtx`] to resolve dtypes and recursively rewrite child /// predicates. -pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { +pub trait StatsRewriteRule: Debug + Send + Sync + 'static { /// Returns the scalar function id handled by this rule. fn scalar_fn_id(&self) -> ScalarFnId; @@ -83,35 +83,35 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { } /// Context passed to stats rewrite rules. -pub(crate) struct StatsRewriteCtx<'a> { +pub struct StatsRewriteCtx<'a> { session: &'a VortexSession, scope: &'a DType, } impl<'a> StatsRewriteCtx<'a> { /// Create a rewrite context for `session`. - pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self { + pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self { Self { session, scope } } /// Returns the session that owns the rewrite registry. - pub(crate) fn session(&self) -> &'a VortexSession { + pub fn session(&self) -> &'a VortexSession { self.session } /// Return the dtype of `expr` within this rewrite scope. - pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult { + pub fn return_dtype(&self, expr: &Expression) -> VortexResult { expr.return_dtype(self.scope) } /// Rewrite `expr` into a stats-backed falsifier. - pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult> { + pub fn falsify(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::falsify) } /// Rewrite `expr` into a stats-backed satisfier. - pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult> { + pub fn satisfy(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::satisfy) } diff --git a/vortex-array/src/stats/session.rs b/vortex-array/src/stats/session.rs index 2d4325b2cd7..91eae4a4fa9 100644 --- a/vortex-array/src/stats/session.rs +++ b/vortex-array/src/stats/session.rs @@ -37,14 +37,12 @@ impl Default for StatsSession { impl StatsSession { /// Register a stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite(&self, rule: R) { + pub fn register_rewrite(&self, rule: R) { self.register_rewrite_ref(Arc::new(rule)); } /// Register a shared stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { + pub fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { let mut rules = self.rewrite_rules.write(); let rule_id = rule.scalar_fn_id(); let mut updated_rules = rules @@ -75,7 +73,7 @@ impl SessionVar for StatsSession { } /// Extension trait for accessing stats session data. -pub(crate) trait StatsSessionExt: SessionExt { +pub trait StatsSessionExt: SessionExt { /// Returns the stats session state. fn stats(&self) -> Ref<'_, StatsSession> { self.get::() diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index ded986f6210..225d18b561a 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -22,7 +22,7 @@ use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; use vortex_array::dtype::FieldPathSet; use vortex_array::expr::Expression; -use vortex_array::expr::pruning::checked_pruning_expr; +use vortex_array::expr::pruning::checked_pruning_expr_with_session; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::LayoutReader; @@ -217,7 +217,9 @@ impl VortexFile { }), ); - let Some((predicate, required_stats)) = checked_pruning_expr(filter, &set) else { + let Some((predicate, required_stats)) = + checked_pruning_expr_with_session(filter, self.footer.dtype(), &set, &self.session)? + else { return Ok(false); }; diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 0121c12b07d..22c92e817d6 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -10,22 +10,37 @@ use std::ops::Range; use std::sync::Arc; +use itertools::Itertools; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::fns::all_nan::AllNan; +use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; +use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; +use vortex_array::aggregate_fn::fns::all_null::AllNull; +use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; +use vortex_array::dtype::Nullability; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; use vortex_array::expr::StatsCatalog; +use vortex_array::expr::analysis::referenced_field_paths; +use vortex_array::expr::eq; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; +use vortex_array::expr::traversal::NodeExt; +use vortex_array::expr::traversal::Transformed; use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::EmptyOptions; +use vortex_array::scalar_fn::ScalarFnVTableExt; use vortex_array::scalar_fn::fns::literal::Literal; +use vortex_array::scalar_fn::fns::stat::StatFn; +use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::ArrayFuture; @@ -83,10 +98,11 @@ impl FileStatsLayoutReader { /// Row-count placeholders are resolved against the full file row count, /// independent of the requested row range. fn evaluate_file_stats(&self, expr: &Expression) -> VortexResult { - let Some(pruning_expr) = expr.stat_falsification(self) else { + let Some(pruning_expr) = expr.falsify(self.child.dtype(), &self.session)? else { // If there is no pruning expression, we can't prune. return Ok(false); }; + let pruning_expr = self.lower_stats(pruning_expr)?; // Given how we implemented the StatsCatalog, we know the expression must be all literals // or row_count placeholders. We can therefore optimize with a null scope since there are @@ -115,11 +131,101 @@ impl FileStatsLayoutReader { Ok(result.as_bool().value() == Some(true)) } + fn lower_stats(&self, predicate: Expression) -> VortexResult { + predicate + .transform_down(|expr| { + if expr.is::() { + return self.lower_stat_fn(expr).map(Transformed::yes); + } + + Ok(Transformed::no(expr)) + }) + .map(Transformed::into_inner) + } + + fn lower_stat_fn(&self, expr: Expression) -> VortexResult { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + let input_dtype = input.return_dtype(self.child.dtype())?; + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(lit(false)); + } + return Ok(self + .stat_ref(input, Stat::NaNCount)? + .map(|stat| eq(stat, row_count_expr())) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(lit(true)); + } + return Ok(self + .stat_ref(input, Stat::NaNCount)? + .map(|stat| eq(stat, lit(0u64))) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() && !has_nans(&input_dtype) { + return Ok(lit(0u64)); + } + + if aggregate_fn.is::() { + return Ok(self + .stat_ref(input, Stat::NullCount)? + .map(|stat| eq(stat, row_count_expr())) + .unwrap_or_else(null_bool_expr)); + } + + if aggregate_fn.is::() { + return Ok(self + .stat_ref(input, Stat::NullCount)? + .map(|stat| eq(stat, lit(0u64))) + .unwrap_or_else(null_bool_expr)); + } + + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(null_expr(expr.return_dtype(self.child.dtype())?)); + }; + + let return_dtype = expr.return_dtype(self.child.dtype())?; + Ok(self + .stat_ref(input, stat)? + .unwrap_or_else(|| null_expr(return_dtype))) + } + + fn stat_ref(&self, input: &Expression, stat: Stat) -> VortexResult> { + let field_paths = referenced_field_paths(input, self.child.dtype())?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + Ok(self.stats_ref(field_path, stat)) + } + pub fn file_stats(&self) -> &FileStatistics { &self.file_stats } } +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn null_bool_expr() -> Expression { + null_expr(DType::Bool(Nullability::NonNullable)) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} + /// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. impl StatsCatalog for FileStatsLayoutReader { fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { From e8dd011748a1138750b3946a0af0d30a92fdcadb Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 11:18:48 -0400 Subject: [PATCH 2/8] Centralize stat expression binding Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-array/src/expr/pruning/pruning_expr.rs | 182 ++++-------------- vortex-array/src/stats/bind.rs | 160 +++++++++++++++ vortex-array/src/stats/mod.rs | 1 + vortex-file/src/v2/file_stats_reader.rs | 119 +++--------- vortex-layout/src/layouts/zoned/zone_map.rs | 140 +++----------- 5 files changed, 256 insertions(+), 346 deletions(-) create mode 100644 vortex-array/src/stats/bind.rs diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 54c9666c283..5208b2efada 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -10,11 +10,6 @@ use vortex_session::VortexSession; use vortex_utils::aliases::hash_map::HashMap; use super::relation::Relation; -use crate::aggregate_fn::fns::all_nan::AllNan; -use crate::aggregate_fn::fns::all_non_nan::AllNonNan; -use crate::aggregate_fn::fns::all_non_null::AllNonNull; -use crate::aggregate_fn::fns::all_null::AllNull; -use crate::aggregate_fn::fns::nan_count::NanCount; use crate::dtype::DType; use crate::dtype::Field; use crate::dtype::FieldName; @@ -23,18 +18,11 @@ use crate::dtype::FieldPathSet; use crate::expr::Expression; use crate::expr::StatsCatalog; use crate::expr::analysis::referenced_field_paths; -use crate::expr::eq; use crate::expr::get_item; -use crate::expr::lit; use crate::expr::root; use crate::expr::stats::Stat; -use crate::expr::traversal::NodeExt; -use crate::expr::traversal::Transformed; -use crate::scalar::Scalar; -use crate::scalar_fn::EmptyOptions; -use crate::scalar_fn::ScalarFnVTableExt; -use crate::scalar_fn::fns::stat::StatFn; -use crate::scalar_fn::internal::row_count::RowCount; +use crate::stats::bind::StatBinder; +use crate::stats::bind::bind_stats; pub type RequiredStats = Relation; @@ -146,146 +134,54 @@ pub fn checked_pruning_expr_with_session( return Ok(None); }; - lower_stat_fns(predicate, scope, available_stats) -} - -fn lower_stat_fns( - predicate: Expression, - scope: &DType, - available_stats: &FieldPathSet, -) -> VortexResult> { - let mut required_stats = Relation::new(); - let mut missing_stat = false; - let lowered = predicate - .transform_down(|expr| { - if !expr.is::() { - return Ok(Transformed::no(expr)); - } - - if let Some(lowered) = - lower_stat_fn(&expr, scope, available_stats, &mut required_stats)? - { - return Ok(Transformed::yes(lowered)); - } - - missing_stat = true; - let dtype = expr.return_dtype(scope)?; - Ok(Transformed::yes(null_expr(dtype))) - })? - .into_inner(); - - if missing_stat { + let mut binder = RequiredStatsBinder { + scope, + available_stats, + required_stats: Relation::new(), + }; + let Some(lowered) = bind_stats(predicate, &mut binder)? else { return Ok(None); - } + }; - Ok(Some((lowered, required_stats))) + Ok(Some((lowered, binder.required_stats))) } -fn lower_stat_fn( - expr: &Expression, - scope: &DType, - available_stats: &FieldPathSet, - required_stats: &mut RequiredStats, -) -> VortexResult> { - let options = expr.as_::(); - let aggregate_fn = options.aggregate_fn(); - let input = expr.child(0); - let input_dtype = input.return_dtype(scope)?; - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(Some(lit(false))); - } - return lower_stat_ref( - input, - Stat::NaNCount, - scope, - available_stats, - required_stats, - ) - .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); - } - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(Some(lit(true))); - } - return lower_stat_ref( - input, - Stat::NaNCount, - scope, - available_stats, - required_stats, - ) - .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); - } +struct RequiredStatsBinder<'a> { + scope: &'a DType, + available_stats: &'a FieldPathSet, + required_stats: RequiredStats, +} - if aggregate_fn.is::() && !has_nans(&input_dtype) { - return Ok(Some(lit(0u64))); +impl StatBinder for RequiredStatsBinder<'_> { + fn scope(&self) -> &DType { + self.scope } - if aggregate_fn.is::() { - return lower_stat_ref( - input, - Stat::NullCount, - scope, - available_stats, - required_stats, - ) - .map(|stat| stat.map(|stat| eq(stat, row_count_expr()))); - } + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + let field_paths = referenced_field_paths(input, self.scope)?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + let stat_path = field_path.clone().push(stat.name()); + if !self.available_stats.contains(&stat_path) { + return Ok(None); + } - if aggregate_fn.is::() { - return lower_stat_ref( - input, - Stat::NullCount, - scope, - available_stats, - required_stats, - ) - .map(|stat| stat.map(|stat| eq(stat, lit(0u64)))); + self.required_stats.insert(field_path.clone(), stat); + Ok(Some(get_item( + field_path_stat_field_name(field_path, stat), + root(), + ))) } - let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { - return Ok(None); - }; - - lower_stat_ref(input, stat, scope, available_stats, required_stats) -} - -fn lower_stat_ref( - input: &Expression, - stat: Stat, - scope: &DType, - available_stats: &FieldPathSet, - required_stats: &mut RequiredStats, -) -> VortexResult> { - let field_paths = referenced_field_paths(input, scope)?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { - return Ok(None); - }; - let stat_path = field_path.clone().push(stat.name()); - if !available_stats.contains(&stat_path) { - return Ok(None); + fn missing_stat(&mut self, _dtype: DType) -> VortexResult> { + Ok(None) } - - required_stats.insert(field_path.clone(), stat); - Ok(Some(get_item( - field_path_stat_field_name(field_path, stat), - root(), - ))) -} - -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) -} - -fn null_expr(dtype: DType) -> Expression { - lit(Scalar::null(dtype.as_nullable())) -} - -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) } #[cfg(test)] diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs new file mode 100644 index 00000000000..714404fdceb --- /dev/null +++ b/vortex-array/src/stats/bind.rs @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Bind abstract `vortex.stat` expressions to a concrete stats representation. + +use vortex_error::VortexResult; + +use crate::aggregate_fn::fns::all_nan::AllNan; +use crate::aggregate_fn::fns::all_non_nan::AllNonNan; +use crate::aggregate_fn::fns::all_non_null::AllNonNull; +use crate::aggregate_fn::fns::all_null::AllNull; +use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::dtype::DType; +use crate::expr::Expression; +use crate::expr::eq; +use crate::expr::lit; +use crate::expr::stats::Stat; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; +use crate::scalar::Scalar; +use crate::scalar_fn::EmptyOptions; +use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::fns::stat::StatFn; +use crate::scalar_fn::internal::row_count::RowCount; + +/// A target that can bind abstract statistics to concrete expressions. +pub trait StatBinder { + /// The dtype scope used to type-check expressions before stats are bound. + fn scope(&self) -> &DType; + + /// Bind `stat(input)` to a concrete expression. + /// + /// Returning `Ok(None)` marks the stat as unavailable. [`bind_stats`] will + /// then call [`Self::missing_stat`] with the dtype expected from the + /// original `vortex.stat` expression. + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + stat_dtype: &DType, + ) -> VortexResult>; + + /// Expression to use when a stat is unavailable. + /// + /// The default is a nullable null literal, which preserves three-valued + /// pruning semantics for stats-table execution. Catalog-like binders can + /// return `Ok(None)` to reject expressions that require unavailable stats. + fn missing_stat(&mut self, dtype: DType) -> VortexResult> { + Ok(Some(null_expr(dtype))) + } +} + +/// Bind all `vortex.stat` expressions in `predicate`. +/// +/// The predicate is usually the output of a stats rewrite rule. This function +/// centralizes the legacy aggregate/stat mapping: `all_null` and `all_nan` +/// style aggregate expressions are expanded through exact count stats, while +/// direct aggregate stats are delegated to the supplied binder. +pub fn bind_stats( + predicate: Expression, + binder: &mut impl StatBinder, +) -> VortexResult> { + let scope = binder.scope().clone(); + let mut missing_stat = false; + let lowered = predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); + } + + match bind_stat_fn(&expr, &scope, binder)? { + Some(bound) => Ok(Transformed::yes(bound)), + None => { + let dtype = expr.return_dtype(&scope)?; + match binder.missing_stat(dtype.clone())? { + Some(missing) => Ok(Transformed::yes(missing)), + None => { + missing_stat = true; + Ok(Transformed::yes(null_expr(dtype))) + } + } + } + } + })? + .into_inner(); + + if missing_stat { + return Ok(None); + } + + Ok(Some(lowered)) +} + +fn bind_stat_fn( + expr: &Expression, + scope: &DType, + binder: &mut impl StatBinder, +) -> VortexResult> { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + let input_dtype = input.return_dtype(scope)?; + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(false))); + } + let stat_dtype = expr.return_dtype(scope)?; + return Ok(binder + .bind_stat(input, Stat::NaNCount, &stat_dtype)? + .map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + if !has_nans(&input_dtype) { + return Ok(Some(lit(true))); + } + let stat_dtype = expr.return_dtype(scope)?; + return Ok(binder + .bind_stat(input, Stat::NaNCount, &stat_dtype)? + .map(|stat| eq(stat, lit(0u64)))); + } + + if aggregate_fn.is::() && !has_nans(&input_dtype) { + return Ok(Some(lit(0u64))); + } + + if aggregate_fn.is::() { + let stat_dtype = expr.return_dtype(scope)?; + return Ok(binder + .bind_stat(input, Stat::NullCount, &stat_dtype)? + .map(|stat| eq(stat, row_count_expr()))); + } + + if aggregate_fn.is::() { + let stat_dtype = expr.return_dtype(scope)?; + return Ok(binder + .bind_stat(input, Stat::NullCount, &stat_dtype)? + .map(|stat| eq(stat, lit(0u64)))); + } + + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + + let stat_dtype = expr.return_dtype(scope)?; + binder.bind_stat(input, stat, &stat_dtype) +} + +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index 3d4cfeb6111..5f5684dbde2 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -19,6 +19,7 @@ pub use expr::sum; pub use stats_set::*; mod array; +pub mod bind; pub mod expr; pub mod flatbuffers; pub mod rewrite; diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 22c92e817d6..be2efce92a2 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -15,34 +15,24 @@ use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::fns::all_nan::AllNan; -use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; -use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; -use vortex_array::aggregate_fn::fns::all_null::AllNull; -use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::NullArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; -use vortex_array::dtype::Nullability; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; use vortex_array::expr::StatsCatalog; use vortex_array::expr::analysis::referenced_field_paths; -use vortex_array::expr::eq; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; -use vortex_array::expr::traversal::NodeExt; -use vortex_array::expr::traversal::Transformed; use vortex_array::scalar::Scalar; -use vortex_array::scalar_fn::EmptyOptions; -use vortex_array::scalar_fn::ScalarFnVTableExt; use vortex_array::scalar_fn::fns::literal::Literal; -use vortex_array::scalar_fn::fns::stat::StatFn; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_layout::ArrayFuture; use vortex_layout::LayoutReader; use vortex_layout::LayoutReaderRef; @@ -132,77 +122,11 @@ impl FileStatsLayoutReader { } fn lower_stats(&self, predicate: Expression) -> VortexResult { - predicate - .transform_down(|expr| { - if expr.is::() { - return self.lower_stat_fn(expr).map(Transformed::yes); - } - - Ok(Transformed::no(expr)) - }) - .map(Transformed::into_inner) - } - - fn lower_stat_fn(&self, expr: Expression) -> VortexResult { - let options = expr.as_::(); - let aggregate_fn = options.aggregate_fn(); - let input = expr.child(0); - let input_dtype = input.return_dtype(self.child.dtype())?; - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(lit(false)); - } - return Ok(self - .stat_ref(input, Stat::NaNCount)? - .map(|stat| eq(stat, row_count_expr())) - .unwrap_or_else(null_bool_expr)); - } - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(lit(true)); - } - return Ok(self - .stat_ref(input, Stat::NaNCount)? - .map(|stat| eq(stat, lit(0u64))) - .unwrap_or_else(null_bool_expr)); - } - - if aggregate_fn.is::() && !has_nans(&input_dtype) { - return Ok(lit(0u64)); - } - - if aggregate_fn.is::() { - return Ok(self - .stat_ref(input, Stat::NullCount)? - .map(|stat| eq(stat, row_count_expr())) - .unwrap_or_else(null_bool_expr)); - } - - if aggregate_fn.is::() { - return Ok(self - .stat_ref(input, Stat::NullCount)? - .map(|stat| eq(stat, lit(0u64))) - .unwrap_or_else(null_bool_expr)); - } - - let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { - return Ok(null_expr(expr.return_dtype(self.child.dtype())?)); + let mut binder = FileStatsBinder { reader: self }; + let Some(predicate) = bind_stats(predicate, &mut binder)? else { + vortex_bail!("missing stats should lower to null literals"); }; - - let return_dtype = expr.return_dtype(self.child.dtype())?; - Ok(self - .stat_ref(input, stat)? - .unwrap_or_else(|| null_expr(return_dtype))) - } - - fn stat_ref(&self, input: &Expression, stat: Stat) -> VortexResult> { - let field_paths = referenced_field_paths(input, self.child.dtype())?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { - return Ok(None); - }; - Ok(self.stats_ref(field_path, stat)) + Ok(predicate) } pub fn file_stats(&self) -> &FileStatistics { @@ -210,20 +134,27 @@ impl FileStatsLayoutReader { } } -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) -} - -fn null_expr(dtype: DType) -> Expression { - lit(Scalar::null(dtype.as_nullable())) +struct FileStatsBinder<'a> { + reader: &'a FileStatsLayoutReader, } -fn null_bool_expr() -> Expression { - null_expr(DType::Bool(Nullability::NonNullable)) -} +impl StatBinder for FileStatsBinder<'_> { + fn scope(&self) -> &DType { + self.reader.child.dtype() + } -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + let field_paths = referenced_field_paths(input, self.scope())?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + Ok(self.reader.stats_ref(field_path, stat)) + } } /// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 96154e69571..5c9fdcdad32 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,33 +8,20 @@ use std::sync::Arc; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::fns::all_nan::AllNan; -use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; -use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; -use vortex_array::aggregate_fn::fns::all_null::AllNull; -use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; use vortex_array::expr::Expression; -use vortex_array::expr::eq; use vortex_array::expr::get_item; use vortex_array::expr::is_root; -use vortex_array::expr::lit; use vortex_array::expr::root; use vortex_array::expr::stats::Stat; -use vortex_array::expr::traversal::NodeExt; -use vortex_array::expr::traversal::Transformed; -use vortex_array::scalar::Scalar; -use vortex_array::scalar_fn::EmptyOptions; -use vortex_array::scalar_fn::ScalarFnVTableExt; -use vortex_array::scalar_fn::fns::stat::StatFn; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_stats; use vortex_array::validity::Validity; use vortex_buffer::buffer; use vortex_error::VortexResult; @@ -132,107 +119,42 @@ impl ZoneMap { } fn lower_stats(&self, predicate: Expression) -> VortexResult { - // Rewritten predicates are evaluated against the stats table, not the data - // column. Lower each StatFn before execution so unavailable stats become - // nullable "unknown" constants rather than prune signals. - predicate - .transform_down(|expr| { - if expr.is::() { - return self.lower_stat_fn(expr).map(Transformed::yes); - } - - Ok(Transformed::no(expr)) - }) - .map(Transformed::into_inner) - } - - fn lower_stat_fn(&self, expr: Expression) -> VortexResult { - // This is the bridge from aggregate-backed bound expressions to the legacy - // zoned stats columns. Exact NullCount and NanCount can prove richer - // all-* aggregates; non-root or missing stats lower to nullable unknowns. - let options = expr.as_::(); - let input = expr.child(0); - let input_dtype = input.return_dtype(&self.column_dtype)?; - let input_is_root = is_root(input); - - if options.aggregate_fn().is::() { - if !has_nans(&input_dtype) { - return Ok(lit(false)); - } - if !input_is_root { - return Ok(null_expr(DType::Bool(Nullability::NonNullable))); - } - return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, row_count_expr())); - } - - if options.aggregate_fn().is::() { - if !has_nans(&input_dtype) { - return Ok(lit(true)); - } - if !input_is_root { - return Ok(null_expr(DType::Bool(Nullability::NonNullable))); - } - return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, lit(0u64))); - } - - if options.aggregate_fn().is::() && !has_nans(&input_dtype) { - return Ok(lit(0u64)); - } - - let return_dtype = match options.aggregate_fn().return_dtype(&input_dtype) { - Some(return_dtype) => return_dtype, - None => vortex_bail!( - "Aggregate function {} does not support input dtype {}", - options.aggregate_fn(), - input_dtype - ), + let mut binder = ZoneMapStatsBinder { zone_map: self }; + let Some(predicate) = bind_stats(predicate, &mut binder)? else { + vortex_bail!("missing stats should lower to null literals"); }; - - if !input_is_root { - return Ok(null_expr(return_dtype)); - } - - if options.aggregate_fn().is::() { - return Ok(eq(self.stat_field_expr(Stat::NullCount)?, row_count_expr())); - } - - if options.aggregate_fn().is::() { - return Ok(eq(self.stat_field_expr(Stat::NullCount)?, lit(0u64))); - } - - let Some(stat) = Stat::from_aggregate_fn(options.aggregate_fn()) else { - return Ok(null_expr(return_dtype)); - }; - - self.stat_field_expr(stat) - } - - fn stat_field_expr(&self, stat: Stat) -> VortexResult { - if self.array.unmasked_field_by_name_opt(stat.name()).is_some() { - return Ok(get_item(stat.name(), root())); - } - - let Some(dtype) = stat.dtype(&self.column_dtype) else { - vortex_bail!( - "Stat {} does not support column dtype {}", - stat, - self.column_dtype - ); - }; - Ok(null_expr(dtype)) + Ok(predicate) } } -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) +struct ZoneMapStatsBinder<'a> { + zone_map: &'a ZoneMap, } -fn null_expr(dtype: DType) -> Expression { - lit(Scalar::null(dtype.as_nullable())) -} +impl StatBinder for ZoneMapStatsBinder<'_> { + fn scope(&self) -> &DType { + &self.zone_map.column_dtype + } -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + if !is_root(input) { + return Ok(None); + } + if self + .zone_map + .array + .unmasked_field_by_name_opt(stat.name()) + .is_none() + { + return Ok(None); + } + Ok(Some(get_item(stat.name(), root()))) + } } /// Build per-zone row counts for a zone map. From c1d6d945d39ed9bafbb60456b84f313e76304d1a Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 14:29:34 -0400 Subject: [PATCH 3/8] Fix file stats binding for computed expressions Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-file/src/v2/file_stats_reader.rs | 56 ++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index be2efce92a2..7d9bd0d66cf 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -10,7 +10,6 @@ use std::ops::Range; use std::sync::Arc; -use itertools::Itertools; use vortex_array::Canonical; use vortex_array::IntoArray; use vortex_array::MaskFuture; @@ -23,10 +22,11 @@ use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; use vortex_array::expr::StatsCatalog; -use vortex_array::expr::analysis::referenced_field_paths; +use vortex_array::expr::is_root; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::fns::get_item::GetItem; use vortex_array::scalar_fn::fns::literal::Literal; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::stats::bind::StatBinder; @@ -149,14 +149,22 @@ impl StatBinder for FileStatsBinder<'_> { stat: Stat, _stat_dtype: &DType, ) -> VortexResult> { - let field_paths = referenced_field_paths(input, self.scope())?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { + let Some(field_path) = direct_field_path(input) else { return Ok(None); }; - Ok(self.reader.stats_ref(field_path, stat)) + Ok(self.reader.stats_ref(&field_path, stat)) } } +fn direct_field_path(expr: &Expression) -> Option { + if is_root(expr) { + return Some(FieldPath::root()); + } + + let field_name = expr.as_opt::()?; + direct_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) +} + /// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. impl StatsCatalog for FileStatsLayoutReader { fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { @@ -261,6 +269,7 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; + use vortex_array::expr::checked_add; use vortex_array::expr::get_item; use vortex_array::expr::gt; use vortex_array::expr::is_not_null; @@ -402,6 +411,43 @@ mod tests { }) } + #[test] + fn no_pruning_for_computed_expression_stats() -> VortexResult<()> { + block_on(|handle| async { + let session = SESSION.clone().with_handle(handle); + let ctx = ArrayContext::empty(); + let segments = Arc::new(TestSegments::default()); + let (ptr, eof) = SequenceId::root().split(); + let struct_array = + StructArray::from_fields([("col", buffer![0i32, 100].into_array())].as_slice())?; + let strategy = TableStrategy::new( + Arc::new(FlatLayoutStrategy::default()), + Arc::new(FlatLayoutStrategy::default()), + ); + let layout = strategy + .write_stream( + ctx, + Arc::::clone(&segments), + struct_array.into_array().to_array_stream().sequenced(ptr), + eof, + &session, + ) + .await?; + + let child = layout.new_reader("".into(), segments, &SESSION, &Default::default())?; + let reader = + FileStatsLayoutReader::new(child, test_file_stats(0, 100), SESSION.clone()); + + let expr = gt(checked_add(get_item("col", root()), lit(5i32)), lit(102i32)); + let mask = Mask::new_true(2); + let result = reader.pruning_evaluation(&(0..2), &expr, mask)?.await?; + + assert_eq!(result, Mask::new_true(2)); + + Ok(()) + }) + } + /// Regression test: `IS NULL` on a nullable timestamp column must not fail with a /// dtype mismatch. The bug was that `stats_ref` used the *field* dtype (timestamp) /// for the `NullCount` stat scalar instead of the stat's own dtype (u64). From 468ebfda10c2cfccb3cacb96aac55439b00ef5d4 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 15:34:21 -0400 Subject: [PATCH 4/8] Fuse checked pruning stats rewrites Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-array/src/expr/pruning/mod.rs | 1 - vortex-array/src/expr/pruning/pruning_expr.rs | 218 +++++++++--------- vortex-array/src/scalar_fn/fns/is_not_null.rs | 3 + vortex-array/src/scalar_fn/fns/is_null.rs | 3 + .../src/scalar_fn/fns/list_contains/mod.rs | 10 + vortex-array/src/stats/bind.rs | 192 +++++++-------- vortex-file/src/file.rs | 4 +- vortex-layout/src/layouts/zoned/zone_map.rs | 28 +-- 8 files changed, 247 insertions(+), 212 deletions(-) diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs index bbcfa5942a0..7c20508b7a8 100644 --- a/vortex-array/src/expr/pruning/mod.rs +++ b/vortex-array/src/expr/pruning/mod.rs @@ -6,7 +6,6 @@ mod relation; pub use pruning_expr::RequiredStats; pub use pruning_expr::checked_pruning_expr; -pub use pruning_expr::checked_pruning_expr_with_session; pub use pruning_expr::field_path_stat_field_name; pub use relation::Relation; diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 5208b2efada..724dd6a1408 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -1,12 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +#[cfg(test)] use std::cell::RefCell; use std::iter; use itertools::Itertools; use vortex_error::VortexResult; use vortex_session::VortexSession; +#[cfg(test)] use vortex_utils::aliases::hash_map::HashMap; use super::relation::Relation; @@ -16,49 +18,29 @@ use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::expr::Expression; +#[cfg(test)] use crate::expr::StatsCatalog; use crate::expr::analysis::referenced_field_paths; use crate::expr::get_item; +use crate::expr::is_root; use crate::expr::root; use crate::expr::stats::Stat; +use crate::scalar_fn::fns::cast::Cast; +use crate::scalar_fn::fns::get_item::GetItem; use crate::stats::bind::StatBinder; use crate::stats::bind::bind_stats; pub type RequiredStats = Relation; -// A catalog that return a stat column whenever it is required, tracking all accessed +// A catalog that returns a stat column whenever it is required, tracking all accessed // stats and returning them later. +#[cfg(test)] #[derive(Default)] pub(crate) struct TrackingStatsCatalog { usage: RefCell>, } -impl TrackingStatsCatalog { - /// Consume the catalog, yielding a map of field statistics that were required - /// for each expression. - fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> { - self.usage.into_inner() - } -} - -// A catalog that return a stat column if it exists in the given scope. -struct ScopeStatsCatalog<'a> { - inner: TrackingStatsCatalog, - available_stats: &'a FieldPathSet, -} - -impl StatsCatalog for ScopeStatsCatalog<'_> { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - let stat_path = field_path.clone().push(stat.name()); - - if self.available_stats.contains(&stat_path) { - self.inner.stats_ref(field_path, stat) - } else { - None - } - } -} - +#[cfg(test)] impl StatsCatalog for TrackingStatsCatalog { fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { let mut expr = root(); @@ -85,8 +67,7 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa .into() } -/// Build a pruning expr mask, using an existing set of stats. -/// The available stats are provided as a set of [`FieldPath`]. +/// Build a pruning expression using session-registered stats rewrite rules. /// /// A pruning expression is one that returns `true` for all positions where the original expression /// cannot hold, and false if it cannot be determined from stats alone whether the positions can @@ -97,34 +78,10 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa /// replace those placeholders with the row count for its current scope before /// executing the returned expression. /// -/// If the falsification logic attempts to access an unknown stat, -/// this function will return `None`. +/// The returned expression is lowered to stats-table field references. Proof branches that require +/// stats not present in `available_stats` are discarded; this returns `Ok(None)` if no usable proof +/// remains. pub fn checked_pruning_expr( - expr: &Expression, - available_stats: &FieldPathSet, -) -> Option<(Expression, RequiredStats)> { - let catalog = ScopeStatsCatalog { - inner: Default::default(), - available_stats, - }; - - let expr = expr.stat_falsification(&catalog)?; - - // TODO(joe): filter access by used exprs - let mut relation: Relation = Relation::new(); - for ((field_path, stat), _) in catalog.inner.into_usages() { - relation.insert(field_path, stat) - } - - Some((expr, relation)) -} - -/// Build a pruning expression using session-registered stats rewrite rules. -/// -/// The returned expression is lowered to the same stats-table field references as -/// [`checked_pruning_expr`]. If a rewrite asks for a stat that is not present in -/// `available_stats`, this returns `Ok(None)`. -pub fn checked_pruning_expr_with_session( expr: &Expression, scope: &DType, available_stats: &FieldPathSet, @@ -163,9 +120,15 @@ impl StatBinder for RequiredStatsBinder<'_> { stat: Stat, _stat_dtype: &DType, ) -> VortexResult> { - let field_paths = referenced_field_paths(input, self.scope)?; - let Some(field_path) = field_paths.iter().exactly_one().ok() else { - return Ok(None); + let field_path = match direct_stat_field_path(input) { + Some(field_path) => field_path, + None => { + let field_paths = referenced_field_paths(input, self.scope)?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + field_path.clone() + } }; let stat_path = field_path.clone().push(stat.name()); if !self.available_stats.contains(&stat_path) { @@ -174,7 +137,7 @@ impl StatBinder for RequiredStatsBinder<'_> { self.required_stats.insert(field_path.clone(), stat); Ok(Some(get_item( - field_path_stat_field_name(field_path, stat), + field_path_stat_field_name(&field_path, stat), root(), ))) } @@ -182,22 +145,54 @@ impl StatBinder for RequiredStatsBinder<'_> { fn missing_stat(&mut self, _dtype: DType) -> VortexResult> { Ok(None) } + + fn bind_branch(&mut self, bind: F) -> VortexResult> + where + Self: Sized, + F: FnOnce(&mut Self) -> VortexResult>, + { + let required_stats = self.required_stats.clone(); + let bound = bind(self)?; + if bound.is_none() { + self.required_stats = required_stats; + } + Ok(bound) + } +} + +fn direct_stat_field_path(expr: &Expression) -> Option { + if is_root(expr) { + return Some(FieldPath::root()); + } + + if expr.is::() { + return direct_stat_field_path(expr.child(0)); + } + + let field_name = expr.as_opt::()?; + direct_stat_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) } #[cfg(test)] mod tests { + use std::sync::LazyLock; + use rstest::fixture; use rstest::rstest; + use vortex_session::VortexSession; + use vortex_utils::aliases::hash_map::HashMap; use vortex_utils::aliases::hash_set::HashSet; - use super::HashMap; + use super::RequiredStats; use crate::dtype::DType; use crate::dtype::FieldName; use crate::dtype::FieldNames; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::dtype::Nullability; + use crate::dtype::PType; use crate::dtype::StructFields; + use crate::expr::Expression; use crate::expr::and; use crate::expr::between; use crate::expr::cast; @@ -217,6 +212,38 @@ mod tests { use crate::expr::stats::Stat; use crate::scalar_fn::fns::between::BetweenOptions; use crate::scalar_fn::fns::between::StrictComparison; + use crate::stats::session::StatsSession; + + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + fn test_scope() -> DType { + DType::Struct( + StructFields::from_iter([ + ("a", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("b", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("x", DType::Bool(Nullability::NonNullable)), + ("y", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("z", DType::Primitive(PType::I32, Nullability::NonNullable)), + ( + "float_col", + DType::Primitive(PType::F32, Nullability::NonNullable), + ), + ( + "int_col", + DType::Primitive(PType::I32, Nullability::NonNullable), + ), + ]), + Nullability::NonNullable, + ) + } + + fn checked( + expr: &Expression, + available_stats: &FieldPathSet, + ) -> Option<(Expression, RequiredStats)> { + checked_pruning_expr(expr, &test_scope(), available_stats, &SESSION).unwrap() + } // Implement some checked pruning expressions. #[fixture] @@ -237,7 +264,7 @@ mod tests { let name = FieldName::from("a"); let literal_eq = lit(42); let eq_expr = eq(get_item("a", root()), literal_eq.clone()); - let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap(); + let (converted, _refs) = checked(&eq_expr, &available_stats).unwrap(); let expected_expr = or( gt( get_item( @@ -263,7 +290,7 @@ mod tests { let other_col = FieldName::from("b"); let eq_expr = eq(col(column.clone()), col(other_col.clone())); - let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(&eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -308,7 +335,7 @@ mod tests { let other_col = FieldName::from("b"); let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone())); - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -355,7 +382,7 @@ mod tests { let other_expr = col(other_col.clone()); let not_eq_expr = gt(col(column.clone()), other_expr); - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -388,7 +415,7 @@ mod tests { let other_col = lit(42); let not_eq_expr = gt(col(column.clone()), other_col.clone()); - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( @@ -413,7 +440,7 @@ mod tests { let other_expr = col(other_col.clone()); let not_eq_expr = lt(col(column.clone()), other_expr); - let (converted, refs) = checked_pruning_expr(¬_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(¬_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -446,7 +473,7 @@ mod tests { // pruning expr => a.min >= 42 let expr = lt(col("a"), lit(42)); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))]) @@ -458,7 +485,7 @@ mod tests { fn pruning_identity(available_stats: FieldPathSet) { let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50))); - let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (predicate, _) = checked(&expr, &available_stats).unwrap(); let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50))); assert_eq!(&predicate.to_string(), &expected_expr.to_string()); @@ -468,7 +495,7 @@ mod tests { // Test case: a > 10 AND a < 50 let column = FieldName::from("a"); let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50))); - let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap(); + let (predicate, _) = checked(&and_expr, &available_stats).unwrap(); // Expected: a_max <= 10 OR a_min >= 50 assert_eq!( @@ -507,7 +534,7 @@ mod tests { // True > False // True let expr = gt_eq(col("x"), gt(col("y"), col("z"))); - assert!(checked_pruning_expr(&expr, &available_stats).is_none()); + assert!(checked(&expr, &available_stats).is_none()); // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)` } @@ -530,46 +557,27 @@ mod tests { #[rstest] fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) { let expr = gt_eq(col("float_col"), lit(f32::NAN)); - let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap(); - assert_eq!( - &converted, - &and( - and( - eq(col("float_col_nan_count"), lit(0u64)), - // NaNCount of NaN is 1 - eq(lit(1u64), lit(0u64)), - ), - // This is the standard conversion of the >= operator. Comparing NAN to a max - // stat is nonsensical, as min/max stats ignore NaNs, but this should be short-circuited - // by the previous check for nan_count anyway. - lt(col("float_col_max"), lit(f32::NAN)), - ) - ); + assert!(checked(&expr, &available_stats_with_nans).is_none()); - // One half of the expression requires NAN count check, the other half does not. + // One half of the expression requires an all-non-NaN proof, the other half does not. let expr = and( gt(col("float_col"), lit(10f32)), lt(col("int_col"), lit(10)), ); - let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap(); - + let (converted, refs) = checked(&expr, &available_stats_with_nans).unwrap(); + assert_eq!( + refs.map(), + &HashMap::from_iter([( + FieldPath::from_name("int_col"), + HashSet::from_iter([Stat::Min]) + )]) + ); assert_eq!( &converted, - &or( - // NaNCount check is enforced for the float column - and( - and( - eq(col("float_col_nan_count"), lit(0u64)), - // NanCount of a non-NaN float literal is 0 - eq(lit(0u64), lit(0u64)), - ), - // We want the opposite: we can prune IF either one is false. - lt_eq(col("float_col_max"), lit(10f32)), - ), - // NanCount check is skipped for the int column - gt_eq(col("int_col_min"), lit(10)), - ) + // The float branch cannot be proven without AllNonNan stats, so the + // remaining proof is the int branch. + >_eq(col("int_col_min"), lit(10)) ) } @@ -584,7 +592,7 @@ mod tests { upper_strict: StrictComparison::NonStrict, }, ); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( @@ -613,7 +621,7 @@ mod tests { Nullability::NonNullable, ); let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value")); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index 589333304e2..59986171ed1 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -267,11 +267,14 @@ mod tests { let (pruning_expr, st) = checked_pruning_expr( &expr, + &test_harness::struct_dtype(), &FieldPathSet::from_iter([FieldPath::from_iter([ Field::Name("a".into()), Field::Name("null_count".into()), ])]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); assert_eq!( diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 7315fbe8c07..6a971e7ecd4 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -251,11 +251,14 @@ mod tests { let (pruning_expr, st) = checked_pruning_expr( &expr, + &test_harness::struct_dtype(), &FieldPathSet::from_iter([FieldPath::from_iter([ Field::Name("a".into()), Field::Name("null_count".into()), ])]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64))); diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index 978a1da1caf..bbaed489fd1 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -630,14 +630,24 @@ mod tests { )), col("a"), ); + let scope = DType::Struct( + StructFields::new( + ["a"].into(), + vec![DType::Primitive(I32, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ); let (expr, st) = checked_pruning_expr( &expr, + &scope, &FieldPathSet::from_iter([ FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]), FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]), ]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); assert_eq!( diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 714404fdceb..82166e0b1bd 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -5,23 +5,15 @@ use vortex_error::VortexResult; -use crate::aggregate_fn::fns::all_nan::AllNan; -use crate::aggregate_fn::fns::all_non_nan::AllNonNan; -use crate::aggregate_fn::fns::all_non_null::AllNonNull; -use crate::aggregate_fn::fns::all_null::AllNull; -use crate::aggregate_fn::fns::nan_count::NanCount; +use crate::aggregate_fn::AggregateFnRef; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::eq; use crate::expr::lit; use crate::expr::stats::Stat; -use crate::expr::traversal::NodeExt; -use crate::expr::traversal::Transformed; use crate::scalar::Scalar; -use crate::scalar_fn::EmptyOptions; -use crate::scalar_fn::ScalarFnVTableExt; +use crate::scalar_fn::fns::binary::Binary; +use crate::scalar_fn::fns::operators::Operator; use crate::scalar_fn::fns::stat::StatFn; -use crate::scalar_fn::internal::row_count::RowCount; /// A target that can bind abstract statistics to concrete expressions. pub trait StatBinder { @@ -40,6 +32,23 @@ pub trait StatBinder { stat_dtype: &DType, ) -> VortexResult>; + /// Bind `aggregate_fn(input)` to a concrete expression. + /// + /// The default implementation supports aggregate functions with legacy + /// [`Stat`] slots. Binders that store richer aggregate stats can override + /// this method without extending the generic stats binding walker. + fn bind_aggregate( + &mut self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + self.bind_stat(input, stat, stat_dtype) + } + /// Expression to use when a stat is unavailable. /// /// The default is a nullable null literal, which preserves three-valued @@ -48,47 +57,101 @@ pub trait StatBinder { fn missing_stat(&mut self, dtype: DType) -> VortexResult> { Ok(Some(null_expr(dtype))) } + + /// Bind a proof branch, rolling back any binder-local bookkeeping when the + /// branch cannot be bound. + /// + /// Binders that only substitute expressions can use the default + /// implementation. Binders that track required stats should override this + /// so discarded proof branches do not leak requirements. + fn bind_branch(&mut self, bind: F) -> VortexResult> + where + Self: Sized, + F: FnOnce(&mut Self) -> VortexResult>, + { + bind(self) + } } /// Bind all `vortex.stat` expressions in `predicate`. /// -/// The predicate is usually the output of a stats rewrite rule. This function -/// centralizes the legacy aggregate/stat mapping: `all_null` and `all_nan` -/// style aggregate expressions are expanded through exact count stats, while -/// direct aggregate stats are delegated to the supplied binder. +/// The predicate is usually the output of a stats rewrite rule. Rewrite rules +/// are responsible for expressing stat semantics; binding maps aggregate-backed +/// stat requests to the concrete stats representation supported by the binder. pub fn bind_stats( predicate: Expression, binder: &mut impl StatBinder, ) -> VortexResult> { let scope = binder.scope().clone(); - let mut missing_stat = false; - let lowered = predicate - .transform_down(|expr| { - if !expr.is::() { - return Ok(Transformed::no(expr)); - } + bind_stats_expr(predicate, &scope, binder) +} - match bind_stat_fn(&expr, &scope, binder)? { - Some(bound) => Ok(Transformed::yes(bound)), - None => { - let dtype = expr.return_dtype(&scope)?; - match binder.missing_stat(dtype.clone())? { - Some(missing) => Ok(Transformed::yes(missing)), - None => { - missing_stat = true; - Ok(Transformed::yes(null_expr(dtype))) - } - } - } +fn bind_stats_expr( + expr: Expression, + scope: &DType, + binder: &mut impl StatBinder, +) -> VortexResult> { + if expr.is::() { + return match bind_stat_fn(&expr, scope, binder)? { + Some(bound) => Ok(Some(bound)), + None => { + let dtype = expr.return_dtype(scope)?; + binder.missing_stat(dtype) } - })? - .into_inner(); + }; + } + + if expr.is::() { + return bind_binary_expr(expr, scope, binder); + } - if missing_stat { - return Ok(None); + let mut children = Vec::with_capacity(expr.children().len()); + for child in expr.children().iter() { + let Some(child) = bind_stats_expr(child.clone(), scope, binder)? else { + return Ok(None); + }; + children.push(child); } - Ok(Some(lowered)) + Ok(Some(expr.with_children(children)?)) +} + +fn bind_binary_expr( + expr: Expression, + scope: &DType, + binder: &mut impl StatBinder, +) -> VortexResult> { + let operator = expr.as_::(); + + match operator { + Operator::Or => { + let lhs = binder + .bind_branch(|binder| bind_stats_expr(expr.child(0).clone(), scope, binder))?; + let rhs = binder + .bind_branch(|binder| bind_stats_expr(expr.child(1).clone(), scope, binder))?; + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), + (Some(expr), None) | (None, Some(expr)) => Ok(Some(expr)), + (None, None) => Ok(None), + } + } + Operator::And => binder.bind_branch(|binder| { + let lhs = bind_stats_expr(expr.child(0).clone(), scope, binder)?; + let rhs = bind_stats_expr(expr.child(1).clone(), scope, binder)?; + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), + _ => Ok(None), + } + }), + _ => binder.bind_branch(|binder| { + let lhs = bind_stats_expr(expr.child(0).clone(), scope, binder)?; + let rhs = bind_stats_expr(expr.child(1).clone(), scope, binder)?; + match (lhs, rhs) { + (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), + _ => Ok(None), + } + }), + } } fn bind_stat_fn( @@ -99,62 +162,11 @@ fn bind_stat_fn( let options = expr.as_::(); let aggregate_fn = options.aggregate_fn(); let input = expr.child(0); - let input_dtype = input.return_dtype(scope)?; - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(Some(lit(false))); - } - let stat_dtype = expr.return_dtype(scope)?; - return Ok(binder - .bind_stat(input, Stat::NaNCount, &stat_dtype)? - .map(|stat| eq(stat, row_count_expr()))); - } - - if aggregate_fn.is::() { - if !has_nans(&input_dtype) { - return Ok(Some(lit(true))); - } - let stat_dtype = expr.return_dtype(scope)?; - return Ok(binder - .bind_stat(input, Stat::NaNCount, &stat_dtype)? - .map(|stat| eq(stat, lit(0u64)))); - } - - if aggregate_fn.is::() && !has_nans(&input_dtype) { - return Ok(Some(lit(0u64))); - } - - if aggregate_fn.is::() { - let stat_dtype = expr.return_dtype(scope)?; - return Ok(binder - .bind_stat(input, Stat::NullCount, &stat_dtype)? - .map(|stat| eq(stat, row_count_expr()))); - } - - if aggregate_fn.is::() { - let stat_dtype = expr.return_dtype(scope)?; - return Ok(binder - .bind_stat(input, Stat::NullCount, &stat_dtype)? - .map(|stat| eq(stat, lit(0u64)))); - } - - let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { - return Ok(None); - }; let stat_dtype = expr.return_dtype(scope)?; - binder.bind_stat(input, stat, &stat_dtype) -} - -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) + binder.bind_aggregate(input, aggregate_fn, &stat_dtype) } fn null_expr(dtype: DType) -> Expression { lit(Scalar::null(dtype.as_nullable())) } - -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) -} diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index 225d18b561a..9e39a6fab5c 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -22,7 +22,7 @@ use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; use vortex_array::dtype::FieldPathSet; use vortex_array::expr::Expression; -use vortex_array::expr::pruning::checked_pruning_expr_with_session; +use vortex_array::expr::pruning::checked_pruning_expr; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_error::VortexResult; use vortex_layout::LayoutReader; @@ -218,7 +218,7 @@ impl VortexFile { ); let Some((predicate, required_stats)) = - checked_pruning_expr_with_session(filter, self.footer.dtype(), &set, &self.session)? + checked_pruning_expr(filter, self.footer.dtype(), &set, &self.session)? else { return Ok(false); }; diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 5c9fdcdad32..f16082fc90e 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -344,7 +344,7 @@ mod tests { } #[test] - fn all_null_stat_fn_lowers_to_null_count_and_row_count() { + fn all_null_stat_fn_lowers_to_unknown_mask() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -359,11 +359,14 @@ mod tests { .unwrap(); let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, false]) + ); } #[test] - fn all_non_null_stat_fn_lowers_to_null_count() { + fn all_non_null_stat_fn_lowers_to_unknown_mask() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -380,12 +383,12 @@ mod tests { let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([true, false, false]) + BoolArray::from_iter([false, false, false]) ); } #[test] - fn non_float_nan_stat_fns_lower_to_constants() { + fn non_float_nan_stat_fns_error() { let zone_map = ZoneMap::try_new( PType::I32.into(), StructArray::try_new(FieldNames::empty(), vec![], 2, Validity::NonNullable).unwrap(), @@ -395,11 +398,8 @@ mod tests { ) .unwrap(); - let mask = zone_map.prune(&all_nan(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); - - let mask = zone_map.prune(&all_non_nan(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([true, true])); + assert!(zone_map.prune(&all_nan(root()), &SESSION).is_err()); + assert!(zone_map.prune(&all_non_nan(root()), &SESSION).is_err()); } #[test] @@ -429,7 +429,7 @@ mod tests { } #[test] - fn float_min_max_stat_fn_requires_nan_count() { + fn float_min_max_stat_fn_requires_all_non_nan() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[ @@ -483,12 +483,12 @@ mod tests { let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([true, false, false]) + BoolArray::from_iter([false, false, false]) ); } #[test] - fn float_cast_min_max_stat_fn_uses_source_nan_count() { + fn float_cast_min_max_stat_fn_requires_all_non_nan() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[ @@ -525,7 +525,7 @@ mod tests { let pruning_expr = falsify(&expr, PType::F32.into()); let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true])); + assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); } #[test] From cd4867432dbcc158de108402a4bd2b91d8caf55b Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 15:46:35 -0400 Subject: [PATCH 5/8] Simplify stats binding null handling Signed-off-by: "Nicholas Gates" Signed-off-by: Nicholas Gates --- vortex-array/src/expr/pruning/pruning_expr.rs | 94 +++++++++++----- vortex-array/src/scalar_fn/fns/binary/mod.rs | 56 ++++++++++ vortex-array/src/scalar_fn/fns/is_not_null.rs | 7 +- vortex-array/src/scalar_fn/fns/is_null.rs | 9 +- vortex-array/src/stats/bind.rs | 103 ++++-------------- 5 files changed, 157 insertions(+), 112 deletions(-) diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 724dd6a1408..61d578787a0 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -10,6 +10,7 @@ use vortex_error::VortexResult; use vortex_session::VortexSession; #[cfg(test)] use vortex_utils::aliases::hash_map::HashMap; +use vortex_utils::aliases::hash_set::HashSet; use super::relation::Relation; use crate::dtype::DType; @@ -27,6 +28,7 @@ use crate::expr::root; use crate::expr::stats::Stat; use crate::scalar_fn::fns::cast::Cast; use crate::scalar_fn::fns::get_item::GetItem; +use crate::scalar_fn::fns::literal::Literal; use crate::stats::bind::StatBinder; use crate::stats::bind::bind_stats; @@ -78,9 +80,9 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa /// replace those placeholders with the row count for its current scope before /// executing the returned expression. /// -/// The returned expression is lowered to stats-table field references. Proof branches that require -/// stats not present in `available_stats` are discarded; this returns `Ok(None)` if no usable proof -/// remains. +/// The returned expression is lowered to stats-table field references. Stats not present in +/// `available_stats` are replaced with typed null literals, preserving three-valued pruning +/// semantics without requiring callers to materialize unavailable stats. pub fn checked_pruning_expr( expr: &Expression, scope: &DType, @@ -99,8 +101,12 @@ pub fn checked_pruning_expr( let Some(lowered) = bind_stats(predicate, &mut binder)? else { return Ok(None); }; + let required_stats = filter_required_stats(&lowered, binder.required_stats); + if required_stats.map().is_empty() && !matches!(bool_literal(&lowered), Some(Some(true))) { + return Ok(None); + } - Ok(Some((lowered, binder.required_stats))) + Ok(Some((lowered, required_stats))) } struct RequiredStatsBinder<'a> { @@ -141,23 +147,6 @@ impl StatBinder for RequiredStatsBinder<'_> { root(), ))) } - - fn missing_stat(&mut self, _dtype: DType) -> VortexResult> { - Ok(None) - } - - fn bind_branch(&mut self, bind: F) -> VortexResult> - where - Self: Sized, - F: FnOnce(&mut Self) -> VortexResult>, - { - let required_stats = self.required_stats.clone(); - let bound = bind(self)?; - if bound.is_none() { - self.required_stats = required_stats; - } - Ok(bound) - } } fn direct_stat_field_path(expr: &Expression) -> Option { @@ -173,6 +162,44 @@ fn direct_stat_field_path(expr: &Expression) -> Option { direct_stat_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) } +fn filter_required_stats(expr: &Expression, required_stats: RequiredStats) -> RequiredStats { + let referenced_names = referenced_stat_field_names(expr); + let mut filtered = Relation::new(); + for (field_path, stats) in required_stats { + for stat in stats { + if referenced_names.contains(&field_path_stat_field_name(&field_path, stat)) { + filtered.insert(field_path.clone(), stat); + } + } + } + filtered +} + +fn referenced_stat_field_names(expr: &Expression) -> HashSet { + let mut refs = HashSet::new(); + collect_referenced_stat_field_names(expr, &mut refs); + refs +} + +fn collect_referenced_stat_field_names(expr: &Expression, refs: &mut HashSet) { + if let Some(field_name) = expr.as_opt::() + && is_root(expr.child(0)) + { + refs.insert(field_name.clone()); + return; + } + + for child in expr.children().iter() { + collect_referenced_stat_field_names(child, refs); + } +} + +fn bool_literal(expr: &Expression) -> Option> { + expr.as_opt::()? + .as_bool_opt() + .map(|value| value.value()) +} + #[cfg(test)] mod tests { use std::sync::LazyLock; @@ -210,6 +237,7 @@ mod tests { use crate::expr::pruning::field_path_stat_field_name; use crate::expr::root; use crate::expr::stats::Stat; + use crate::scalar::Scalar; use crate::scalar_fn::fns::between::BetweenOptions; use crate::scalar_fn::fns::between::StrictComparison; use crate::stats::session::StatsSession; @@ -568,16 +596,26 @@ mod tests { let (converted, refs) = checked(&expr, &available_stats_with_nans).unwrap(); assert_eq!( refs.map(), - &HashMap::from_iter([( - FieldPath::from_name("int_col"), - HashSet::from_iter([Stat::Min]) - )]) + &HashMap::from_iter([ + ( + FieldPath::from_name("float_col"), + HashSet::from_iter([Stat::Max]) + ), + ( + FieldPath::from_name("int_col"), + HashSet::from_iter([Stat::Min]) + ) + ]) ); assert_eq!( &converted, - // The float branch cannot be proven without AllNonNan stats, so the - // remaining proof is the int branch. - >_eq(col("int_col_min"), lit(10)) + &or( + and( + lit(Scalar::null(DType::Bool(Nullability::Nullable))), + lt_eq(col("float_col_max"), lit(10f32)), + ), + gt_eq(col("int_col_min"), lit(10)), + ) ) } diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 1c860cb75b5..6babe5263d1 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -17,6 +17,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; +use crate::dtype::Nullability; use crate::expr::StatsCatalog; use crate::expr::and; use crate::expr::and_collect; @@ -34,6 +35,7 @@ use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -45,10 +47,45 @@ mod numeric; pub(crate) use numeric::*; use crate::scalar::NumericOperator; +use crate::scalar::Scalar; #[derive(Clone)] pub struct Binary; +fn simplify_and(lhs: &Expression, rhs: &Expression) -> Option { + match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(false)), _) | (_, Some(Some(false))) => Some(lit(false)), + (Some(Some(true)), _) => Some(rhs.clone()), + (_, Some(Some(true))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + } +} + +fn simplify_or(lhs: &Expression, rhs: &Expression) -> Option { + match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(true)), _) | (_, Some(Some(true))) => Some(lit(true)), + (Some(Some(false)), _) => Some(rhs.clone()), + (_, Some(Some(false))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + } +} + +fn bool_literal(expr: &Expression) -> Option> { + expr.as_opt::()? + .as_bool_opt() + .map(|value| value.value()) +} + +fn is_null_literal(expr: &Expression) -> bool { + expr.as_opt::().is_some_and(Scalar::is_null) +} + +fn null_bool() -> Expression { + lit(Scalar::null(DType::Bool(Nullability::Nullable))) +} + impl ScalarFnVTable for Binary { type Options = Operator; @@ -165,6 +202,25 @@ impl ScalarFnVTable for Binary { } } + fn simplify_untyped( + &self, + operator: &Operator, + expr: &Expression, + ) -> VortexResult> { + let lhs = expr.child(0); + let rhs = expr.child(1); + + if operator.is_comparison() && (is_null_literal(lhs) || is_null_literal(rhs)) { + return Ok(Some(null_bool())); + } + + Ok(match operator { + Operator::And => simplify_and(lhs, rhs), + Operator::Or => simplify_or(lhs, rhs), + _ => None, + }) + } + fn stat_falsification( &self, operator: &Operator, diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index 59986171ed1..eb60fc8aec8 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -135,6 +135,8 @@ mod tests { use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_not_null; + use crate::expr::lit; + use crate::expr::or; use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; use crate::expr::stats::Stat; @@ -279,7 +281,10 @@ mod tests { assert_eq!( &pruning_expr, - &eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])) + &or( + eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])), + lit(Scalar::null(DType::Bool(Nullability::Nullable))), + ) ); assert_eq!( st.map(), diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 6a971e7ecd4..bbf8a8f2409 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -125,6 +125,7 @@ mod tests { use crate::expr::get_item; use crate::expr::is_null; use crate::expr::lit; + use crate::expr::or; use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; use crate::expr::stats::Stat; @@ -261,7 +262,13 @@ mod tests { .unwrap() .unwrap(); - assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64))); + assert_eq!( + &pruning_expr, + &or( + eq(col("a_null_count"), lit(0u64)), + lit(Scalar::null(DType::Bool(Nullability::Nullable))), + ) + ); assert_eq!( st.map(), &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]) diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs index 82166e0b1bd..752664396c6 100644 --- a/vortex-array/src/stats/bind.rs +++ b/vortex-array/src/stats/bind.rs @@ -10,9 +10,9 @@ use crate::dtype::DType; use crate::expr::Expression; use crate::expr::lit; use crate::expr::stats::Stat; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; use crate::scalar::Scalar; -use crate::scalar_fn::fns::binary::Binary; -use crate::scalar_fn::fns::operators::Operator; use crate::scalar_fn::fns::stat::StatFn; /// A target that can bind abstract statistics to concrete expressions. @@ -57,20 +57,6 @@ pub trait StatBinder { fn missing_stat(&mut self, dtype: DType) -> VortexResult> { Ok(Some(null_expr(dtype))) } - - /// Bind a proof branch, rolling back any binder-local bookkeeping when the - /// branch cannot be bound. - /// - /// Binders that only substitute expressions can use the default - /// implementation. Binders that track required stats should override this - /// so discarded proof branches do not leak requirements. - fn bind_branch(&mut self, bind: F) -> VortexResult> - where - Self: Sized, - F: FnOnce(&mut Self) -> VortexResult>, - { - bind(self) - } } /// Bind all `vortex.stat` expressions in `predicate`. @@ -83,75 +69,28 @@ pub fn bind_stats( binder: &mut impl StatBinder, ) -> VortexResult> { let scope = binder.scope().clone(); - bind_stats_expr(predicate, &scope, binder) -} - -fn bind_stats_expr( - expr: Expression, - scope: &DType, - binder: &mut impl StatBinder, -) -> VortexResult> { - if expr.is::() { - return match bind_stat_fn(&expr, scope, binder)? { - Some(bound) => Ok(Some(bound)), - None => { - let dtype = expr.return_dtype(scope)?; - binder.missing_stat(dtype) + let lowered = predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); } - }; - } - - if expr.is::() { - return bind_binary_expr(expr, scope, binder); - } - - let mut children = Vec::with_capacity(expr.children().len()); - for child in expr.children().iter() { - let Some(child) = bind_stats_expr(child.clone(), scope, binder)? else { - return Ok(None); - }; - children.push(child); - } - Ok(Some(expr.with_children(children)?)) -} - -fn bind_binary_expr( - expr: Expression, - scope: &DType, - binder: &mut impl StatBinder, -) -> VortexResult> { - let operator = expr.as_::(); - - match operator { - Operator::Or => { - let lhs = binder - .bind_branch(|binder| bind_stats_expr(expr.child(0).clone(), scope, binder))?; - let rhs = binder - .bind_branch(|binder| bind_stats_expr(expr.child(1).clone(), scope, binder))?; - match (lhs, rhs) { - (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), - (Some(expr), None) | (None, Some(expr)) => Ok(Some(expr)), - (None, None) => Ok(None), - } - } - Operator::And => binder.bind_branch(|binder| { - let lhs = bind_stats_expr(expr.child(0).clone(), scope, binder)?; - let rhs = bind_stats_expr(expr.child(1).clone(), scope, binder)?; - match (lhs, rhs) { - (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), - _ => Ok(None), + match bind_stat_fn(&expr, &scope, binder)? { + Some(bound) => Ok(Transformed::yes(bound)), + None => { + let dtype = expr.return_dtype(&scope)?; + match binder.missing_stat(dtype.clone())? { + Some(missing) => Ok(Transformed::yes(missing)), + None => Ok(Transformed::yes(null_expr(dtype))), + } + } } - }), - _ => binder.bind_branch(|binder| { - let lhs = bind_stats_expr(expr.child(0).clone(), scope, binder)?; - let rhs = bind_stats_expr(expr.child(1).clone(), scope, binder)?; - match (lhs, rhs) { - (Some(lhs), Some(rhs)) => Ok(Some(expr.with_children([lhs, rhs])?)), - _ => Ok(None), - } - }), - } + })? + .into_inner(); + + #[expect(deprecated)] + let lowered = lowered.simplify_untyped()?; + Ok(Some(lowered)) } fn bind_stat_fn( From 52fd44337e3a471702db98c08c51ae85de9dbdec Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 16:13:25 -0400 Subject: [PATCH 6/8] Install Java toolchain in CI Signed-off-by: Nicholas Gates --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dea9f3b9ef4..ed68d375724 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -376,6 +376,10 @@ jobs: with: sccache: s3 - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + - uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 + with: + distribution: "corretto" + java-version: "17" - uses: ./.github/actions/setup-prebuild - run: ./gradlew javadoc working-directory: ./java From 429383824e9918bf29a7a3b3f1f8d6653b61b990 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 16:39:09 -0400 Subject: [PATCH 7/8] Remove legacy stat falsification hooks Signed-off-by: Nicholas Gates --- vortex-array/src/expr/expression.rs | 22 ---- vortex-array/src/expr/pruning/pruning_expr.rs | 27 ---- vortex-array/src/scalar_fn/erased.rs | 9 -- vortex-array/src/scalar_fn/fns/between/mod.rs | 19 --- vortex-array/src/scalar_fn/fns/binary/mod.rs | 117 +----------------- vortex-array/src/scalar_fn/fns/dynamic.rs | 45 ------- vortex-array/src/scalar_fn/fns/is_not_null.rs | 17 --- vortex-array/src/scalar_fn/fns/is_null.rs | 15 --- vortex-array/src/scalar_fn/fns/like/mod.rs | 101 --------------- .../src/scalar_fn/fns/list_contains/mod.rs | 45 ------- vortex-array/src/scalar_fn/typed.rs | 13 -- vortex-array/src/scalar_fn/vtable.rs | 13 -- vortex-file/src/v2/file_stats_reader.rs | 13 +- 13 files changed, 9 insertions(+), 447 deletions(-) diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index cc21fb9a9a6..043b61aaaf4 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -114,28 +114,6 @@ impl Expression { self.scalar_fn.validity(self) } - /// An expression over zone-statistics which implies all records in the zone evaluate to false. - /// - /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed - /// that `e` evaluates to false on all records in the zone. However, the inverse is not - /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to - /// true on all records. - /// - /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr. - /// - /// # Examples - /// - /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum - /// value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`. - /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`. - /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x). - /// - /// Some expressions, in theory, have falsifications but this function does not support them - /// such as `x < (y < z)` or `x LIKE "needle%"`. - pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option { - self.scalar_fn().stat_falsification(self, catalog) - } - /// Returns an expression that proves this predicate is definitely false from stats. /// /// `scope` is the dtype of the row this expression evaluates over. diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 61d578787a0..ee775b4f13e 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -1,15 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -#[cfg(test)] -use std::cell::RefCell; use std::iter; use itertools::Itertools; use vortex_error::VortexResult; use vortex_session::VortexSession; -#[cfg(test)] -use vortex_utils::aliases::hash_map::HashMap; use vortex_utils::aliases::hash_set::HashSet; use super::relation::Relation; @@ -19,8 +15,6 @@ use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::expr::Expression; -#[cfg(test)] -use crate::expr::StatsCatalog; use crate::expr::analysis::referenced_field_paths; use crate::expr::get_item; use crate::expr::is_root; @@ -34,27 +28,6 @@ use crate::stats::bind::bind_stats; pub type RequiredStats = Relation; -// A catalog that returns a stat column whenever it is required, tracking all accessed -// stats and returning them later. -#[cfg(test)] -#[derive(Default)] -pub(crate) struct TrackingStatsCatalog { - usage: RefCell>, -} - -#[cfg(test)] -impl StatsCatalog for TrackingStatsCatalog { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - let mut expr = root(); - let name = field_path_stat_field_name(field_path, stat); - expr = get_item(name, expr); - self.usage - .borrow_mut() - .insert((field_path.clone(), stat), expr.clone()); - Some(expr) - } -} - #[doc(hidden)] pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName { field_path diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 10e82d25455..69befb405e0 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -181,15 +181,6 @@ impl ScalarFnRef { self.0.simplify_untyped(expr) } - /// Compute stat falsification expression. - pub(crate) fn stat_falsification( - &self, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - self.0.stat_falsification(expr, catalog) - } - /// Compute stat expression. pub(crate) fn stat_expression( &self, diff --git a/vortex-array/src/scalar_fn/fns/between/mod.rs b/vortex-array/src/scalar_fn/fns/between/mod.rs index 0e0d9949195..bd546bed941 100644 --- a/vortex-array/src/scalar_fn/fns/between/mod.rs +++ b/vortex-array/src/scalar_fn/fns/between/mod.rs @@ -25,7 +25,6 @@ use crate::arrays::Primitive; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::DType::Bool; -use crate::expr::StatsCatalog; use crate::expr::expression::Expression; use crate::scalar::Scalar; use crate::scalar_fn::Arity; @@ -33,8 +32,6 @@ use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; -use crate::scalar_fn::fns::binary::Binary; use crate::scalar_fn::fns::binary::execute_boolean; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -298,22 +295,6 @@ impl ScalarFnVTable for Between { between_canonical(&arr, &lower, &upper, options, ctx) } - fn stat_falsification( - &self, - options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let arr = expr.child(0).clone(); - let lower = expr.child(1).clone(); - let upper = expr.child(2).clone(); - - let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]); - let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]); - - and(lhs, rhs).stat_falsification(catalog) - } - fn validity( &self, _options: &Self::Options, diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 6babe5263d1..b51f86b3188 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -18,18 +18,9 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::dtype::Nullability; -use crate::expr::StatsCatalog; use crate::expr::and; -use crate::expr::and_collect; -use crate::expr::eq; use crate::expr::expression::Expression; -use crate::expr::gt; -use crate::expr::gt_eq; use crate::expr::lit; -use crate::expr::lt; -use crate::expr::lt_eq; -use crate::expr::or_collect; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -221,110 +212,6 @@ impl ScalarFnVTable for Binary { }) } - fn stat_falsification( - &self, - operator: &Operator, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // Wrap another predicate with an optional NaNCount check, if the stat is available. - // - // For example, regular pruning conversion for `A >= B` would be - // - // A.max < B.min - // - // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting - // in: - // - // (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min - // - // Non-floating point column and literal expressions should be unaffected as they do not - // have a nan_count statistic defined. - fn with_nan_predicate( - lhs: &Expression, - rhs: &Expression, - value_predicate: Expression, - catalog: &dyn StatsCatalog, - ) -> Expression { - let nan_predicate = and_collect( - lhs.stat_expression(Stat::NaNCount, catalog) - .into_iter() - .chain(rhs.stat_expression(Stat::NaNCount, catalog)) - .map(|nans| eq(nans, lit(0u64))), - ); - - if let Some(nan_check) = nan_predicate { - and(nan_check, value_predicate) - } else { - value_predicate - } - } - - let lhs = expr.child(0); - let rhs = expr.child(1); - match operator { - Operator::Eq => { - let min_lhs = lhs.stat_min(catalog); - let max_lhs = lhs.stat_max(catalog); - - let min_rhs = rhs.stat_min(catalog); - let max_rhs = rhs.stat_max(catalog); - - let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b)); - let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b)); - - let min_max_check = or_collect(left.into_iter().chain(right))?; - - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::NotEq => { - let min_lhs = lhs.stat_min(catalog)?; - let max_lhs = lhs.stat_max(catalog)?; - - let min_rhs = rhs.stat_min(catalog)?; - let max_rhs = rhs.stat_max(catalog)?; - - let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Gt => { - let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Gte => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Lt => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Lte => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::And => or_collect( - lhs.stat_falsification(catalog) - .into_iter() - .chain(rhs.stat_falsification(catalog)), - ), - Operator::Or => Some(and( - lhs.stat_falsification(catalog)?, - rhs.stat_falsification(catalog)?, - )), - Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, - } - } - fn validity( &self, operator: &Operator, @@ -381,8 +268,12 @@ mod tests { use crate::expr::Expression; use crate::expr::and_collect; use crate::expr::col; + use crate::expr::eq; + use crate::expr::gt; + use crate::expr::gt_eq; use crate::expr::lit; use crate::expr::lt; + use crate::expr::lt_eq; use crate::expr::not_eq; use crate::expr::or; use crate::expr::or_collect; diff --git a/vortex-array/src/scalar_fn/fns/dynamic.rs b/vortex-array/src/scalar_fn/fns/dynamic.rs index 7efebf79220..f6e6619282a 100644 --- a/vortex-array/src/scalar_fn/fns/dynamic.rs +++ b/vortex-array/src/scalar_fn/fns/dynamic.rs @@ -20,7 +20,6 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::traversal::NodeExt; use crate::expr::traversal::NodeVisitor; use crate::expr::traversal::TraversalOrder; @@ -120,50 +119,6 @@ impl ScalarFnVTable for DynamicComparison { .into_array()) } - fn stat_falsification( - &self, - dynamic: &DynamicComparisonExpr, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let lhs = expr.child(0); - match dynamic.operator { - CompareOperator::Eq | CompareOperator::NotEq => None, - CompareOperator::Gt => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Lte, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_max(catalog)?], - )), - CompareOperator::Gte => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Lt, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_max(catalog)?], - )), - CompareOperator::Lt => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Gte, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_min(catalog)?], - )), - CompareOperator::Lte => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Gt, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_min(catalog)?], - )), - } - } - // Defer to the child fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index eb60fc8aec8..a64aab32611 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -3,7 +3,6 @@ use std::fmt::Formatter; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_error::VortexResult; use vortex_session::VortexSession; use vortex_session::registry::CachedId; @@ -15,16 +14,12 @@ use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::eq; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; use crate::validity::Validity; /// Expression that checks for non-null values. @@ -100,18 +95,6 @@ impl ScalarFnVTable for IsNotNull { fn is_fallible(&self, _instance: &Self::Options) -> bool { false } - - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // is_not_null is falsified when ALL values are null, i.e. null_count == row_count. - let child = expr.child(0); - let null_count_expr = child.stat_expression(Stat::NullCount, catalog)?; - Some(eq(null_count_expr, RowCount.new_expr(EmptyOptions, []))) - } } #[cfg(test)] diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index bbf8a8f2409..807b9d9a043 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -12,11 +12,6 @@ use crate::arrays::ConstantArray; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; -use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::eq; -use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -84,16 +79,6 @@ impl ScalarFnVTable for IsNull { } } - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?; - Some(eq(null_count_expr, lit(0u64))) - } - fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } diff --git a/vortex-array/src/scalar_fn/fns/like/mod.rs b/vortex-array/src/scalar_fn/fns/like/mod.rs index b7f357020f1..850484a1450 100644 --- a/vortex-array/src/scalar_fn/fns/like/mod.rs +++ b/vortex-array/src/scalar_fn/fns/like/mod.rs @@ -21,20 +21,12 @@ use crate::arrow::Datum; use crate::arrow::from_arrow_columnar; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::and; -use crate::expr::gt; -use crate::expr::gt_eq; -use crate::expr::lit; -use crate::expr::lt; -use crate::expr::or; -use crate::scalar::StringLike; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::fns::literal::Literal; /// Options for SQL LIKE function #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -165,49 +157,6 @@ impl ScalarFnVTable for Like { fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } - - fn stat_falsification( - &self, - like_opts: &LikeOptions, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // Attempt to do min/max pruning for LIKE 'exact' or LIKE 'prefix%' - - // Don't attempt to handle ilike or negated like - if like_opts.negated || like_opts.case_insensitive { - return None; - } - - // Extract the pattern out - let pat = expr.child(1).as_::(); - - // LIKE NULL is nonsensical, don't try to handle it - let pat_str = pat.as_utf8().value()?; - - let src = expr.child(0).clone(); - let src_min = src.stat_min(catalog)?; - let src_max = src.stat_max(catalog)?; - - match LikeVariant::from_str(pat_str)? { - LikeVariant::Exact(text) => { - // col LIKE 'exact' ==> col.min > 'exact' || col.max < 'exact' - Some(or( - gt(src_min, lit(text.as_ref())), - lt(src_max, lit(text.as_ref())), - )) - } - LikeVariant::Prefix(prefix) => { - // col LIKE 'prefix%' ==> col.max < 'prefix' || col.min >= 'prefiy' - let succ = prefix.to_string().increment().ok()?; - - Some(or( - gt_eq(src_min, lit(succ)), - lt(src_max, lit(prefix.as_ref())), - )) - } - } - } } /// Implementation of LIKE using the Arrow crate. @@ -295,15 +244,11 @@ mod tests { use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; - use crate::expr::col; use crate::expr::get_item; - use crate::expr::ilike; use crate::expr::like; use crate::expr::lit; use crate::expr::not; use crate::expr::not_ilike; - use crate::expr::not_like; - use crate::expr::pruning::pruning_expr::TrackingStatsCatalog; use crate::expr::root; use crate::scalar_fn::fns::like::LikeVariant; @@ -390,50 +335,4 @@ mod tests { assert_eq!(LikeVariant::from_str(r"%\%%"), None); assert_eq!(LikeVariant::from_str("_pattern"), None); } - - #[test] - fn test_like_pushdown() { - // Test that LIKE prefix and exactness filters can be pushed down into stats filtering - // at scan time. - let catalog = TrackingStatsCatalog::default(); - - let pruning_expr = like(col("a"), lit("prefix%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "prefiy") or ($.a_max < "prefix"))"#); - - let pruning_expr = like(col("a"), lit(r"\%%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "&") or ($.a_max < "%"))"#); - - // Multiple wildcards - let pruning_expr = like(col("a"), lit("pref%ix%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "preg") or ($.a_max < "pref"))"#); - - let pruning_expr = like(col("a"), lit("pref_ix_")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "preg") or ($.a_max < "pref"))"#); - - // Exact match - let pruning_expr = like(col("a"), lit("exactly")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min > "exactly") or ($.a_max < "exactly"))"#); - - // Suffix search skips pushdown - let pruning_expr = like(col("a"), lit("%suffix")).stat_falsification(&catalog); - assert_eq!(pruning_expr, None); - - // NOT LIKE, ILIKE not supported currently - assert_eq!( - None, - not_like(col("a"), lit("a")).stat_falsification(&catalog) - ); - assert_eq!(None, ilike(col("a"), lit("a")).stat_falsification(&catalog)); - } } diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index bbaed489fd1..236bdd646eb 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -33,13 +33,6 @@ use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::IntegerPType; use crate::dtype::Nullability; -use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::and_collect; -use crate::expr::gt; -use crate::expr::lit; -use crate::expr::lt; -use crate::expr::or; use crate::match_each_integer_ptype; use crate::match_each_unsigned_integer_ptype; use crate::scalar::ListScalar; @@ -51,7 +44,6 @@ use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; use crate::scalar_fn::fns::binary::Binary; -use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::Operator; use crate::validity::Validity; @@ -129,43 +121,6 @@ impl ScalarFnVTable for ListContains { compute_list_contains(&list_array, &value_array, ctx) } - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let list = expr.child(0); - let needle = expr.child(1); - - // falsification(contains([1,2,5], x)) => - // falsification(x != 1) and falsification(x != 2) and falsification(x != 5) - let min = list.stat_min(catalog)?; - let max = list.stat_max(catalog)?; - // If the list is constant when we can compare each element to the value - if min == max { - let list_ = min - .as_opt::() - .and_then(|l| l.as_list_opt()) - .and_then(|l| l.elements())?; - if list_.is_empty() { - // contains([], x) is always false. - return Some(lit(true)); - } - let value_max = needle.stat_max(catalog)?; - let value_min = needle.stat_min(catalog)?; - - return and_collect(list_.iter().map(move |v| { - or( - lt(value_max.clone(), lit(v.clone())), - gt(value_min.clone(), lit(v.clone())), - ) - })); - } - - None - } - // Nullability matters for contains([], x) where x is false. fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index a2ef9549bff..83d2bfea496 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -101,11 +101,6 @@ pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { ) -> VortexResult>; fn simplify_untyped(&self, expression: &Expression) -> VortexResult>; fn validity(&self, expression: &Expression) -> VortexResult>; - fn stat_falsification( - &self, - expression: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option; fn stat_expression( &self, expression: &Expression, @@ -223,14 +218,6 @@ impl DynScalarFn for TypedScalarFnInstance { V::validity(&self.vtable, &self.options, expression) } - fn stat_falsification( - &self, - expression: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - V::stat_falsification(&self.vtable, &self.options, expression, catalog) - } - fn stat_expression( &self, expression: &Expression, diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index f4862f6876a..1556354f9f1 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -179,19 +179,6 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { Ok(None) } - /// See [`Expression::stat_falsification`]. - fn stat_falsification( - &self, - options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - _ = options; - _ = expr; - _ = catalog; - None - } - /// See [`Expression::stat_expression`]. fn stat_expression( &self, diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 7d9bd0d66cf..b697becbc0f 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -21,7 +21,6 @@ use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; -use vortex_array::expr::StatsCatalog; use vortex_array::expr::is_root; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; @@ -94,9 +93,8 @@ impl FileStatsLayoutReader { }; let pruning_expr = self.lower_stats(pruning_expr)?; - // Given how we implemented the StatsCatalog, we know the expression must be all literals - // or row_count placeholders. We can therefore optimize with a null scope since there are - // no field references that need to be resolved. + // Stats lowering replaces available stats with literals and unavailable stats with nulls, + // so only row_count placeholders remain unresolved here. let simplified = pruning_expr.optimize_recursive(&DType::Null)?; if let Some(result) = simplified.as_opt::() { // Can prune if the result is non-nullable and true @@ -152,7 +150,7 @@ impl StatBinder for FileStatsBinder<'_> { let Some(field_path) = direct_field_path(input) else { return Ok(None); }; - Ok(self.reader.stats_ref(&field_path, stat)) + Ok(self.reader.stat_ref(&field_path, stat)) } } @@ -165,9 +163,8 @@ fn direct_field_path(expr: &Expression) -> Option { direct_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) } -/// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. -impl StatsCatalog for FileStatsLayoutReader { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { +impl FileStatsLayoutReader { + fn stat_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { // FileStats currently only holds top-level field statistics. if field_path.parts().len() != 1 { return None; From ffbccee5e987e734132ebb8dcebe6b63b9ec3b93 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 16:45:16 -0400 Subject: [PATCH 8/8] Remove legacy stat expression hooks Signed-off-by: Nicholas Gates --- vortex-array/src/expr/expression.rs | 24 ----------- vortex-array/src/expr/mod.rs | 1 - vortex-array/src/expr/pruning/mod.rs | 17 -------- vortex-array/src/scalar_fn/erased.rs | 12 ------ vortex-array/src/scalar_fn/fns/cast/mod.rs | 36 ----------------- vortex-array/src/scalar_fn/fns/get_item.rs | 21 ---------- vortex-array/src/scalar_fn/fns/literal.rs | 47 ---------------------- vortex-array/src/scalar_fn/fns/root.rs | 13 ------ vortex-array/src/scalar_fn/typed.rs | 17 -------- vortex-array/src/scalar_fn/vtable.rs | 17 -------- 10 files changed, 205 deletions(-) diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index 043b61aaaf4..41cd61a369d 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -15,9 +15,7 @@ use vortex_error::vortex_ensure; use vortex_session::VortexSession; use crate::dtype::DType; -use crate::expr::StatsCatalog; use crate::expr::display::DisplayTreeExpr; -use crate::expr::stats::Stat; use crate::scalar_fn::ScalarFnRef; use crate::scalar_fn::fns::root::Root; @@ -142,28 +140,6 @@ impl Expression { crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self) } - /// Returns an expression representing the zoned statistic for the given stat, if available. - /// - /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a - /// scope. Expressions can implement this function to propagate such statistics through the - /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`. - /// - /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an - /// issue to discuss a solution to this. - pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option { - self.scalar_fn().stat_expression(self, stat, catalog) - } - - /// Returns an expression representing the zoned maximum statistic, if available. - pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option { - self.stat_expression(Stat::Min, catalog) - } - - /// Returns an expression representing the zoned maximum statistic, if available. - pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option { - self.stat_expression(Stat::Max, catalog) - } - /// Format the expression as a compact string. /// /// Since this is a recursive formatter, it is exposed on the public Expression type. diff --git a/vortex-array/src/expr/mod.rs b/vortex-array/src/expr/mod.rs index a5d32510443..72969baf23a 100644 --- a/vortex-array/src/expr/mod.rs +++ b/vortex-array/src/expr/mod.rs @@ -42,7 +42,6 @@ pub mod traversal; pub use analysis::*; pub use expression::*; pub use exprs::*; -pub use pruning::StatsCatalog; pub trait VortexExprExt { /// Accumulate all field references from this expression and its children in a set diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs index 7c20508b7a8..5ce2785f446 100644 --- a/vortex-array/src/expr/pruning/mod.rs +++ b/vortex-array/src/expr/pruning/mod.rs @@ -8,20 +8,3 @@ pub use pruning_expr::RequiredStats; pub use pruning_expr::checked_pruning_expr; pub use pruning_expr::field_path_stat_field_name; pub use relation::Relation; - -use crate::dtype::FieldPath; -use crate::expr::Expression; -use crate::expr::stats::Stat; - -/// A catalog of available stats that are associated with field paths. -pub trait StatsCatalog { - /// Given a field path and statistic, return an expression that when evaluated over the catalog - /// will return that stat for the referenced field. - /// - /// This is likely to be a column expression, or a literal. - /// - /// Returns `None` if the stat is not available for the field path. - fn stats_ref(&self, _field_path: &FieldPath, _stat: Stat) -> Option { - None - } -} diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 69befb405e0..6e0011c297a 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -20,8 +20,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ReduceCtx; @@ -180,16 +178,6 @@ impl ScalarFnRef { pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult> { self.0.simplify_untyped(expr) } - - /// Compute stat expression. - pub(crate) fn stat_expression( - &self, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - self.0.stat_expression(expr, stat, catalog) - } } impl Debug for ScalarFnRef { diff --git a/vortex-array/src/scalar_fn/fns/cast/mod.rs b/vortex-array/src/scalar_fn/fns/cast/mod.rs index abc59af2c9a..20852779d42 100644 --- a/vortex-array/src/scalar_fn/fns/cast/mod.rs +++ b/vortex-array/src/scalar_fn/fns/cast/mod.rs @@ -32,11 +32,8 @@ use crate::arrays::VarBinView; use crate::arrays::struct_::compute::cast::struct_cast; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; -use crate::expr::StatsCatalog; -use crate::expr::cast; use crate::expr::expression::Expression; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -152,39 +149,6 @@ impl ScalarFnVTable for Cast { Ok(None) } - fn stat_expression( - &self, - dtype: &DType, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - match stat { - Stat::IsConstant - | Stat::IsSorted - | Stat::IsStrictSorted - | Stat::NaNCount - | Stat::Sum - | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog), - Stat::Max | Stat::Min => { - // We cast min/max to the new type - expr.child(0) - .stat_expression(stat, catalog) - .map(|x| cast(x, dtype.clone())) - } - Stat::NullCount => { - // if !expr.data().is_nullable() { - // NOTE(ngates): we should decide on the semantics here. In theory, the null - // count of something cast to non-nullable will be zero. But if we return - // that we know this to be zero, then a pruning predicate may eliminate data - // that would otherwise have caused the cast to error. - // return Some(lit(0u64)); - // } - None - } - } - } - fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult> { Ok(Some(if dtype.is_nullable() { expression.child(0).validity()? diff --git a/vortex-array/src/scalar_fn/fns/get_item.rs b/vortex-array/src/scalar_fn/fns/get_item.rs index de7e45ca9b0..b9cb9202f38 100644 --- a/vortex-array/src/scalar_fn/fns/get_item.rs +++ b/vortex-array/src/scalar_fn/fns/get_item.rs @@ -18,12 +18,9 @@ use crate::builtins::ArrayBuiltins; use crate::builtins::ExprBuiltins; use crate::dtype::DType; use crate::dtype::FieldName; -use crate::dtype::FieldPath; use crate::dtype::Nullability; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -188,24 +185,6 @@ impl ScalarFnVTable for GetItem { Ok(None) } - fn stat_expression( - &self, - field_name: &FieldName, - _expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): I think we can do better here and support stats over nested fields. - // It would be nice if delegating to our child would return a struct of statistics - // matching the nested DType such that we can write: - // `get_item(expr.child(0).stat_expression(...), expr.data().field_name())` - - // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same - // name as a field in the root struct. This should be resolved with upcoming change to - // falsify expressions, but for now I'm preserving the existing buggy behavior. - catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat) - } - // This will apply struct nullability field. We could add a dtype?? fn is_null_sensitive(&self, _field_name: &FieldName) -> bool { true diff --git a/vortex-array/src/scalar_fn/fns/literal.rs b/vortex-array/src/scalar_fn/fns/literal.rs index 16b112e5a78..5181a5250dd 100644 --- a/vortex-array/src/scalar_fn/fns/literal.rs +++ b/vortex-array/src/scalar_fn/fns/literal.rs @@ -16,9 +16,6 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; -use crate::match_each_float_ptype; use crate::scalar::Scalar; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -96,50 +93,6 @@ impl ScalarFnVTable for Literal { Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array()) } - fn stat_expression( - &self, - scalar: &Scalar, - _expr: &Expression, - stat: Stat, - _catalog: &dyn StatsCatalog, - ) -> Option { - // NOTE(ngates): we return incorrect `1` values for counts here since we don't have - // row-count information. We could resolve this in the future by introducing a `count()` - // expression that evaluates to the row count of the provided scope. But since this is - // only currently used for pruning, it doesn't change the outcome. - - match stat { - Stat::Min | Stat::Max => Some(lit(scalar.clone())), - Stat::IsConstant => Some(lit(true)), - Stat::NaNCount => { - // The NaNCount for a non-float literal is not defined. - // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise. - let value = scalar.as_primitive_opt()?; - if !value.ptype().is_float() { - return None; - } - - match_each_float_ptype!(value.ptype(), |T| { - if value.typed_value::().is_some_and(|v| v.is_nan()) { - Some(lit(1u64)) - } else { - Some(lit(0u64)) - } - }) - } - Stat::NullCount => { - if scalar.is_null() { - Some(lit(1u64)) - } else { - Some(lit(0u64)) - } - } - Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => { - None - } - } - } - fn validity( &self, scalar: &Scalar, diff --git a/vortex-array/src/scalar_fn/fns/root.rs b/vortex-array/src/scalar_fn/fns/root.rs index 87b8b62ccf4..7bd5b758796 100644 --- a/vortex-array/src/scalar_fn/fns/root.rs +++ b/vortex-array/src/scalar_fn/fns/root.rs @@ -11,10 +11,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; -use crate::dtype::FieldPath; -use crate::expr::StatsCatalog; use crate::expr::expression::Expression; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -80,16 +77,6 @@ impl ScalarFnVTable for Root { vortex_bail!("Root expression is not executable") } - fn stat_expression( - &self, - _options: &Self::Options, - _expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - catalog.stats_ref(&FieldPath::root(), stat) - } - fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index 83d2bfea496..e31cdc79ed0 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -22,8 +22,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -101,12 +99,6 @@ pub(super) trait DynScalarFn: 'static + Send + Sync + super::sealed::Sealed { ) -> VortexResult>; fn simplify_untyped(&self, expression: &Expression) -> VortexResult>; fn validity(&self, expression: &Expression) -> VortexResult>; - fn stat_expression( - &self, - expression: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option; // Options operations — self-contained fn options_serialize(&self) -> VortexResult>>; @@ -218,15 +210,6 @@ impl DynScalarFn for TypedScalarFnInstance { V::validity(&self.vtable, &self.options, expression) } - fn stat_expression( - &self, - expression: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - V::stat_expression(&self.vtable, &self.options, expression, stat, catalog) - } - fn options_serialize(&self) -> VortexResult>> { V::serialize(&self.vtable, &self.options) } diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index 1556354f9f1..c66afc34932 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -20,8 +20,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::expr::traversal::Node; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnRef; @@ -179,21 +177,6 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { Ok(None) } - /// See [`Expression::stat_expression`]. - fn stat_expression( - &self, - options: &Self::Options, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - _ = options; - _ = expr; - _ = stat; - _ = catalog; - None - } - /// Returns an expression that evaluates to the validity of the result of this expression. /// /// If a validity expression cannot be constructed, returns `None` and the expression will