diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dea9f3b9ef4..ed68d375724 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -376,6 +376,10 @@ jobs: with: sccache: s3 - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + - uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 + with: + distribution: "corretto" + java-version: "17" - uses: ./.github/actions/setup-prebuild - run: ./gradlew javadoc working-directory: ./java diff --git a/vortex-array/src/expr/expression.rs b/vortex-array/src/expr/expression.rs index cc21fb9a9a6..41cd61a369d 100644 --- a/vortex-array/src/expr/expression.rs +++ b/vortex-array/src/expr/expression.rs @@ -15,9 +15,7 @@ use vortex_error::vortex_ensure; use vortex_session::VortexSession; use crate::dtype::DType; -use crate::expr::StatsCatalog; use crate::expr::display::DisplayTreeExpr; -use crate::expr::stats::Stat; use crate::scalar_fn::ScalarFnRef; use crate::scalar_fn::fns::root::Root; @@ -114,28 +112,6 @@ impl Expression { self.scalar_fn.validity(self) } - /// An expression over zone-statistics which implies all records in the zone evaluate to false. - /// - /// Given an expression, `e`, if `e.stat_falsification(..)` evaluates to true, it is guaranteed - /// that `e` evaluates to false on all records in the zone. However, the inverse is not - /// necessarily true: even if the falsification evaluates to false, `e` need not evaluate to - /// true on all records. - /// - /// The [`StatsCatalog`] can be used to constrain or rename stats used in the final expr. - /// - /// # Examples - /// - /// - An expression over one variable: `x > 0` is false for all records in a zone if the maximum - /// value of the column `x` in that zone is less than or equal to zero: `max(x) <= 0`. - /// - An expression over two variables: `x > y` becomes `max(x) <= min(y)`. - /// - A conjunctive expression: `x > y AND z < x` becomes `max(x) <= min(y) OR min(z) >= max(x). 
- /// - /// Some expressions, in theory, have falsifications but this function does not support them - /// such as `x < (y < z)` or `x LIKE "needle%"`. - pub fn stat_falsification(&self, catalog: &dyn StatsCatalog) -> Option { - self.scalar_fn().stat_falsification(self, catalog) - } - /// Returns an expression that proves this predicate is definitely false from stats. /// /// `scope` is the dtype of the row this expression evaluates over. @@ -164,28 +140,6 @@ impl Expression { crate::stats::rewrite::StatsRewriteCtx::new(session, scope).satisfy(self) } - /// Returns an expression representing the zoned statistic for the given stat, if available. - /// - /// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a - /// scope. Expressions can implement this function to propagate such statistics through the - /// expression tree. For example, the `a + 10` expression could propagate `min: min(a) + 10`. - /// - /// NOTE(gatesn): we currently cannot represent statistics over nested fields. Please file an - /// issue to discuss a solution to this. - pub fn stat_expression(&self, stat: Stat, catalog: &dyn StatsCatalog) -> Option { - self.scalar_fn().stat_expression(self, stat, catalog) - } - - /// Returns an expression representing the zoned maximum statistic, if available. - pub fn stat_min(&self, catalog: &dyn StatsCatalog) -> Option { - self.stat_expression(Stat::Min, catalog) - } - - /// Returns an expression representing the zoned maximum statistic, if available. - pub fn stat_max(&self, catalog: &dyn StatsCatalog) -> Option { - self.stat_expression(Stat::Max, catalog) - } - /// Format the expression as a compact string. /// /// Since this is a recursive formatter, it is exposed on the public Expression type. 
diff --git a/vortex-array/src/expr/mod.rs b/vortex-array/src/expr/mod.rs index a5d32510443..72969baf23a 100644 --- a/vortex-array/src/expr/mod.rs +++ b/vortex-array/src/expr/mod.rs @@ -42,7 +42,6 @@ pub mod traversal; pub use analysis::*; pub use expression::*; pub use exprs::*; -pub use pruning::StatsCatalog; pub trait VortexExprExt { /// Accumulate all field references from this expression and its children in a set diff --git a/vortex-array/src/expr/pruning/mod.rs b/vortex-array/src/expr/pruning/mod.rs index 7c20508b7a8..5ce2785f446 100644 --- a/vortex-array/src/expr/pruning/mod.rs +++ b/vortex-array/src/expr/pruning/mod.rs @@ -8,20 +8,3 @@ pub use pruning_expr::RequiredStats; pub use pruning_expr::checked_pruning_expr; pub use pruning_expr::field_path_stat_field_name; pub use relation::Relation; - -use crate::dtype::FieldPath; -use crate::expr::Expression; -use crate::expr::stats::Stat; - -/// A catalog of available stats that are associated with field paths. -pub trait StatsCatalog { - /// Given a field path and statistic, return an expression that when evaluated over the catalog - /// will return that stat for the referenced field. - /// - /// This is likely to be a column expression, or a literal. - /// - /// Returns `None` if the stat is not available for the field path. 
- fn stats_ref(&self, _field_path: &FieldPath, _stat: Stat) -> Option { - None - } -} diff --git a/vortex-array/src/expr/pruning/pruning_expr.rs b/vortex-array/src/expr/pruning/pruning_expr.rs index 00d29fbcf99..ee775b4f13e 100644 --- a/vortex-array/src/expr/pruning/pruning_expr.rs +++ b/vortex-array/src/expr/pruning/pruning_expr.rs @@ -1,70 +1,33 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::cell::RefCell; use std::iter; use itertools::Itertools; -use vortex_utils::aliases::hash_map::HashMap; +use vortex_error::VortexResult; +use vortex_session::VortexSession; +use vortex_utils::aliases::hash_set::HashSet; use super::relation::Relation; +use crate::dtype::DType; use crate::dtype::Field; use crate::dtype::FieldName; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::expr::Expression; -use crate::expr::StatsCatalog; +use crate::expr::analysis::referenced_field_paths; use crate::expr::get_item; +use crate::expr::is_root; use crate::expr::root; use crate::expr::stats::Stat; +use crate::scalar_fn::fns::cast::Cast; +use crate::scalar_fn::fns::get_item::GetItem; +use crate::scalar_fn::fns::literal::Literal; +use crate::stats::bind::StatBinder; +use crate::stats::bind::bind_stats; pub type RequiredStats = Relation; -// A catalog that return a stat column whenever it is required, tracking all accessed -// stats and returning them later. -#[derive(Default)] -pub(crate) struct TrackingStatsCatalog { - usage: RefCell>, -} - -impl TrackingStatsCatalog { - /// Consume the catalog, yielding a map of field statistics that were required - /// for each expression. - fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> { - self.usage.into_inner() - } -} - -// A catalog that return a stat column if it exists in the given scope. 
-struct ScopeStatsCatalog<'a> { - inner: TrackingStatsCatalog, - available_stats: &'a FieldPathSet, -} - -impl StatsCatalog for ScopeStatsCatalog<'_> { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - let stat_path = field_path.clone().push(stat.name()); - - if self.available_stats.contains(&stat_path) { - self.inner.stats_ref(field_path, stat) - } else { - None - } - } -} - -impl StatsCatalog for TrackingStatsCatalog { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { - let mut expr = root(); - let name = field_path_stat_field_name(field_path, stat); - expr = get_item(name, expr); - self.usage - .borrow_mut() - .insert((field_path.clone(), stat), expr.clone()); - Some(expr) - } -} - #[doc(hidden)] pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName { field_path @@ -79,8 +42,7 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa .into() } -/// Build a pruning expr mask, using an existing set of stats. -/// The available stats are provided as a set of [`FieldPath`]. +/// Build a pruning expression using session-registered stats rewrite rules. /// /// A pruning expression is one that returns `true` for all positions where the original expression /// cannot hold, and false if it cannot be determined from stats alone whether the positions can @@ -91,42 +53,146 @@ pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldNa /// replace those placeholders with the row count for its current scope before /// executing the returned expression. /// -/// If the falsification logic attempts to access an unknown stat, -/// this function will return `None`. +/// The returned expression is lowered to stats-table field references. Stats not present in +/// `available_stats` are replaced with typed null literals, preserving three-valued pruning +/// semantics without requiring callers to materialize unavailable stats. 
pub fn checked_pruning_expr( expr: &Expression, + scope: &DType, available_stats: &FieldPathSet, -) -> Option<(Expression, RequiredStats)> { - let catalog = ScopeStatsCatalog { - inner: Default::default(), + session: &VortexSession, +) -> VortexResult> { + let Some(predicate) = expr.falsify(scope, session)? else { + return Ok(None); + }; + + let mut binder = RequiredStatsBinder { + scope, available_stats, + required_stats: Relation::new(), }; + let Some(lowered) = bind_stats(predicate, &mut binder)? else { + return Ok(None); + }; + let required_stats = filter_required_stats(&lowered, binder.required_stats); + if required_stats.map().is_empty() && !matches!(bool_literal(&lowered), Some(Some(true))) { + return Ok(None); + } + + Ok(Some((lowered, required_stats))) +} + +struct RequiredStatsBinder<'a> { + scope: &'a DType, + available_stats: &'a FieldPathSet, + required_stats: RequiredStats, +} + +impl StatBinder for RequiredStatsBinder<'_> { + fn scope(&self) -> &DType { + self.scope + } + + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + let field_path = match direct_stat_field_path(input) { + Some(field_path) => field_path, + None => { + let field_paths = referenced_field_paths(input, self.scope)?; + let Some(field_path) = field_paths.iter().exactly_one().ok() else { + return Ok(None); + }; + field_path.clone() + } + }; + let stat_path = field_path.clone().push(stat.name()); + if !self.available_stats.contains(&stat_path) { + return Ok(None); + } + + self.required_stats.insert(field_path.clone(), stat); + Ok(Some(get_item( + field_path_stat_field_name(&field_path, stat), + root(), + ))) + } +} + +fn direct_stat_field_path(expr: &Expression) -> Option { + if is_root(expr) { + return Some(FieldPath::root()); + } + + if expr.is::() { + return direct_stat_field_path(expr.child(0)); + } - let expr = expr.stat_falsification(&catalog)?; + let field_name = expr.as_opt::()?; + 
direct_stat_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) +} + +fn filter_required_stats(expr: &Expression, required_stats: RequiredStats) -> RequiredStats { + let referenced_names = referenced_stat_field_names(expr); + let mut filtered = Relation::new(); + for (field_path, stats) in required_stats { + for stat in stats { + if referenced_names.contains(&field_path_stat_field_name(&field_path, stat)) { + filtered.insert(field_path.clone(), stat); + } + } + } + filtered +} - // TODO(joe): filter access by used exprs - let mut relation: Relation = Relation::new(); - for ((field_path, stat), _) in catalog.inner.into_usages() { - relation.insert(field_path, stat) +fn referenced_stat_field_names(expr: &Expression) -> HashSet { + let mut refs = HashSet::new(); + collect_referenced_stat_field_names(expr, &mut refs); + refs +} + +fn collect_referenced_stat_field_names(expr: &Expression, refs: &mut HashSet) { + if let Some(field_name) = expr.as_opt::() + && is_root(expr.child(0)) + { + refs.insert(field_name.clone()); + return; } - Some((expr, relation)) + for child in expr.children().iter() { + collect_referenced_stat_field_names(child, refs); + } +} + +fn bool_literal(expr: &Expression) -> Option> { + expr.as_opt::()? 
+ .as_bool_opt() + .map(|value| value.value()) } #[cfg(test)] mod tests { + use std::sync::LazyLock; + use rstest::fixture; use rstest::rstest; + use vortex_session::VortexSession; + use vortex_utils::aliases::hash_map::HashMap; use vortex_utils::aliases::hash_set::HashSet; - use super::HashMap; + use super::RequiredStats; use crate::dtype::DType; use crate::dtype::FieldName; use crate::dtype::FieldNames; use crate::dtype::FieldPath; use crate::dtype::FieldPathSet; use crate::dtype::Nullability; + use crate::dtype::PType; use crate::dtype::StructFields; + use crate::expr::Expression; use crate::expr::and; use crate::expr::between; use crate::expr::cast; @@ -144,8 +210,41 @@ mod tests { use crate::expr::pruning::field_path_stat_field_name; use crate::expr::root; use crate::expr::stats::Stat; + use crate::scalar::Scalar; use crate::scalar_fn::fns::between::BetweenOptions; use crate::scalar_fn::fns::between::StrictComparison; + use crate::stats::session::StatsSession; + + static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + + fn test_scope() -> DType { + DType::Struct( + StructFields::from_iter([ + ("a", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("b", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("x", DType::Bool(Nullability::NonNullable)), + ("y", DType::Primitive(PType::I32, Nullability::NonNullable)), + ("z", DType::Primitive(PType::I32, Nullability::NonNullable)), + ( + "float_col", + DType::Primitive(PType::F32, Nullability::NonNullable), + ), + ( + "int_col", + DType::Primitive(PType::I32, Nullability::NonNullable), + ), + ]), + Nullability::NonNullable, + ) + } + + fn checked( + expr: &Expression, + available_stats: &FieldPathSet, + ) -> Option<(Expression, RequiredStats)> { + checked_pruning_expr(expr, &test_scope(), available_stats, &SESSION).unwrap() + } // Implement some checked pruning expressions. 
#[fixture] @@ -166,7 +265,7 @@ mod tests { let name = FieldName::from("a"); let literal_eq = lit(42); let eq_expr = eq(get_item("a", root()), literal_eq.clone()); - let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap(); + let (converted, _refs) = checked(&eq_expr, &available_stats).unwrap(); let expected_expr = or( gt( get_item( @@ -192,7 +291,7 @@ mod tests { let other_col = FieldName::from("b"); let eq_expr = eq(col(column.clone()), col(other_col.clone())); - let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(&eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -237,7 +336,7 @@ mod tests { let other_col = FieldName::from("b"); let not_eq_expr = not_eq(col(column.clone()), col(other_col.clone())); - let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(&not_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -284,7 +383,7 @@ mod tests { let other_expr = col(other_col.clone()); let not_eq_expr = gt(col(column.clone()), other_expr); - let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(&not_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([ @@ -317,7 +416,7 @@ mod tests { let other_col = lit(42); let not_eq_expr = gt(col(column.clone()), other_col.clone()); - let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(&not_eq_expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( @@ -342,7 +441,7 @@ mod tests { let other_expr = col(other_col.clone()); let not_eq_expr = lt(col(column.clone()), other_expr); - let (converted, refs) = checked_pruning_expr(&not_eq_expr, &available_stats).unwrap(); + let (converted, refs) = checked(&not_eq_expr, &available_stats).unwrap(); assert_eq!(
refs.map(), &HashMap::from_iter([ @@ -375,7 +474,7 @@ mod tests { // pruning expr => a.min >= 42 let expr = lt(col("a"), lit(42)); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))]) @@ -387,7 +486,7 @@ mod tests { fn pruning_identity(available_stats: FieldPathSet) { let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50))); - let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (predicate, _) = checked(&expr, &available_stats).unwrap(); let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50))); assert_eq!(&predicate.to_string(), &expected_expr.to_string()); @@ -397,7 +496,7 @@ mod tests { // Test case: a > 10 AND a < 50 let column = FieldName::from("a"); let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50))); - let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap(); + let (predicate, _) = checked(&and_expr, &available_stats).unwrap(); // Expected: a_max <= 10 OR a_min >= 50 assert_eq!( @@ -436,7 +535,7 @@ mod tests { // True > False // True let expr = gt_eq(col("x"), gt(col("y"), col("z"))); - assert!(checked_pruning_expr(&expr, &available_stats).is_none()); + assert!(checked(&expr, &available_stats).is_none()); // TODO(DK): a sufficiently complex pruner would produce: `x_max <= (y_max > z_min)` } @@ -459,44 +558,35 @@ mod tests { #[rstest] fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) { let expr = gt_eq(col("float_col"), lit(f32::NAN)); - let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap(); - assert_eq!( - &converted, - &and( - and( - eq(col("float_col_nan_count"), lit(0u64)), - // NaNCount of NaN is 1 - eq(lit(1u64), lit(0u64)), - ), - // This is the standard conversion of the >= operator. 
Comparing NAN to a max - // stat is nonsensical, as min/max stats ignore NaNs, but this should be short-circuited - // by the previous check for nan_count anyway. - lt(col("float_col_max"), lit(f32::NAN)), - ) - ); + assert!(checked(&expr, &available_stats_with_nans).is_none()); - // One half of the expression requires NAN count check, the other half does not. + // One half of the expression requires an all-non-NaN proof, the other half does not. let expr = and( gt(col("float_col"), lit(10f32)), lt(col("int_col"), lit(10)), ); - let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap(); - + let (converted, refs) = checked(&expr, &available_stats_with_nans).unwrap(); + assert_eq!( + refs.map(), + &HashMap::from_iter([ + ( + FieldPath::from_name("float_col"), + HashSet::from_iter([Stat::Max]) + ), + ( + FieldPath::from_name("int_col"), + HashSet::from_iter([Stat::Min]) + ) + ]) + ); assert_eq!( &converted, &or( - // NaNCount check is enforced for the float column and( - and( - eq(col("float_col_nan_count"), lit(0u64)), - // NanCount of a non-NaN float literal is 0 - eq(lit(0u64), lit(0u64)), - ), - // We want the opposite: we can prune IF either one is false. 
+ lit(Scalar::null(DType::Bool(Nullability::Nullable))), lt_eq(col("float_col_max"), lit(10f32)), ), - // NanCount check is skipped for the int column gt_eq(col("int_col_min"), lit(10)), ) ) @@ -513,7 +603,7 @@ mod tests { upper_strict: StrictComparison::NonStrict, }, ); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( @@ -542,7 +632,7 @@ mod tests { Nullability::NonNullable, ); let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value")); - let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let (converted, refs) = checked(&expr, &available_stats).unwrap(); assert_eq!( refs.map(), &HashMap::from_iter([( diff --git a/vortex-array/src/scalar_fn/erased.rs b/vortex-array/src/scalar_fn/erased.rs index 10e82d25455..6e0011c297a 100644 --- a/vortex-array/src/scalar_fn/erased.rs +++ b/vortex-array/src/scalar_fn/erased.rs @@ -20,8 +20,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ReduceCtx; @@ -180,25 +178,6 @@ impl ScalarFnRef { pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult> { self.0.simplify_untyped(expr) } - - /// Compute stat falsification expression. - pub(crate) fn stat_falsification( - &self, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - self.0.stat_falsification(expr, catalog) - } - - /// Compute stat expression. 
- pub(crate) fn stat_expression( - &self, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - self.0.stat_expression(expr, stat, catalog) - } } impl Debug for ScalarFnRef { diff --git a/vortex-array/src/scalar_fn/fns/between/mod.rs b/vortex-array/src/scalar_fn/fns/between/mod.rs index 0e0d9949195..bd546bed941 100644 --- a/vortex-array/src/scalar_fn/fns/between/mod.rs +++ b/vortex-array/src/scalar_fn/fns/between/mod.rs @@ -25,7 +25,6 @@ use crate::arrays::Primitive; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::DType::Bool; -use crate::expr::StatsCatalog; use crate::expr::expression::Expression; use crate::scalar::Scalar; use crate::scalar_fn::Arity; @@ -33,8 +32,6 @@ use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; -use crate::scalar_fn::fns::binary::Binary; use crate::scalar_fn::fns::binary::execute_boolean; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -298,22 +295,6 @@ impl ScalarFnVTable for Between { between_canonical(&arr, &lower, &upper, options, ctx) } - fn stat_falsification( - &self, - options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let arr = expr.child(0).clone(); - let lower = expr.child(1).clone(); - let upper = expr.child(2).clone(); - - let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]); - let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]); - - and(lhs, rhs).stat_falsification(catalog) - } - fn validity( &self, _options: &Self::Options, diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 1c860cb75b5..b51f86b3188 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -17,23 +17,16 
@@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; -use crate::expr::StatsCatalog; +use crate::dtype::Nullability; use crate::expr::and; -use crate::expr::and_collect; -use crate::expr::eq; use crate::expr::expression::Expression; -use crate::expr::gt; -use crate::expr::gt_eq; use crate::expr::lit; -use crate::expr::lt; -use crate::expr::lt_eq; -use crate::expr::or_collect; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; @@ -45,10 +38,45 @@ mod numeric; pub(crate) use numeric::*; use crate::scalar::NumericOperator; +use crate::scalar::Scalar; #[derive(Clone)] pub struct Binary; +fn simplify_and(lhs: &Expression, rhs: &Expression) -> Option { + match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(false)), _) | (_, Some(Some(false))) => Some(lit(false)), + (Some(Some(true)), _) => Some(rhs.clone()), + (_, Some(Some(true))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + } +} + +fn simplify_or(lhs: &Expression, rhs: &Expression) -> Option { + match (bool_literal(lhs), bool_literal(rhs)) { + (Some(Some(true)), _) | (_, Some(Some(true))) => Some(lit(true)), + (Some(Some(false)), _) => Some(rhs.clone()), + (_, Some(Some(false))) => Some(lhs.clone()), + (Some(None), Some(None)) => Some(lhs.clone()), + _ => None, + } +} + +fn bool_literal(expr: &Expression) -> Option> { + expr.as_opt::()? 
+ .as_bool_opt() + .map(|value| value.value()) +} + +fn is_null_literal(expr: &Expression) -> bool { + expr.as_opt::().is_some_and(Scalar::is_null) +} + +fn null_bool() -> Expression { + lit(Scalar::null(DType::Bool(Nullability::Nullable))) +} + impl ScalarFnVTable for Binary { type Options = Operator; @@ -165,108 +193,23 @@ impl ScalarFnVTable for Binary { } } - fn stat_falsification( + fn simplify_untyped( &self, operator: &Operator, expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // Wrap another predicate with an optional NaNCount check, if the stat is available. - // - // For example, regular pruning conversion for `A >= B` would be - // - // A.max < B.min - // - // With NaN predicate introduction, we'd conjunct it with a check for NaNCount, resulting - // in: - // - // (A.nan_count = 0) AND (B.nan_count = 0) AND A.max < B.min - // - // Non-floating point column and literal expressions should be unaffected as they do not - // have a nan_count statistic defined. - fn with_nan_predicate( - lhs: &Expression, - rhs: &Expression, - value_predicate: Expression, - catalog: &dyn StatsCatalog, - ) -> Expression { - let nan_predicate = and_collect( - lhs.stat_expression(Stat::NaNCount, catalog) - .into_iter() - .chain(rhs.stat_expression(Stat::NaNCount, catalog)) - .map(|nans| eq(nans, lit(0u64))), - ); - - if let Some(nan_check) = nan_predicate { - and(nan_check, value_predicate) - } else { - value_predicate - } - } - + ) -> VortexResult> { let lhs = expr.child(0); let rhs = expr.child(1); - match operator { - Operator::Eq => { - let min_lhs = lhs.stat_min(catalog); - let max_lhs = lhs.stat_max(catalog); - - let min_rhs = rhs.stat_min(catalog); - let max_rhs = rhs.stat_max(catalog); - let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b)); - let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b)); - - let min_max_check = or_collect(left.into_iter().chain(right))?; - - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning 
- Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::NotEq => { - let min_lhs = lhs.stat_min(catalog)?; - let max_lhs = lhs.stat_max(catalog)?; - - let min_rhs = rhs.stat_min(catalog)?; - let max_rhs = rhs.stat_max(catalog)?; - - let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Gt => { - let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Gte => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Lt => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::Lte => { - // NaN is not captured by the min/max stat, so we must check NaNCount before pruning - let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?); - - Some(with_nan_predicate(lhs, rhs, min_max_check, catalog)) - } - Operator::And => or_collect( - lhs.stat_falsification(catalog) - .into_iter() - .chain(rhs.stat_falsification(catalog)), - ), - Operator::Or => Some(and( - lhs.stat_falsification(catalog)?, - rhs.stat_falsification(catalog)?, - )), - Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None, + if operator.is_comparison() && (is_null_literal(lhs) || is_null_literal(rhs)) { + return Ok(Some(null_bool())); } + + Ok(match operator { + Operator::And => simplify_and(lhs, rhs), + Operator::Or => simplify_or(lhs, rhs), + _ => None, + }) } fn validity( @@ -325,8 +268,12 @@ mod tests { use crate::expr::Expression; use crate::expr::and_collect; use crate::expr::col; + use 
crate::expr::eq; + use crate::expr::gt; + use crate::expr::gt_eq; use crate::expr::lit; use crate::expr::lt; + use crate::expr::lt_eq; use crate::expr::not_eq; use crate::expr::or; use crate::expr::or_collect; diff --git a/vortex-array/src/scalar_fn/fns/cast/mod.rs b/vortex-array/src/scalar_fn/fns/cast/mod.rs index abc59af2c9a..20852779d42 100644 --- a/vortex-array/src/scalar_fn/fns/cast/mod.rs +++ b/vortex-array/src/scalar_fn/fns/cast/mod.rs @@ -32,11 +32,8 @@ use crate::arrays::VarBinView; use crate::arrays::struct_::compute::cast::struct_cast; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; -use crate::expr::StatsCatalog; -use crate::expr::cast; use crate::expr::expression::Expression; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -152,39 +149,6 @@ impl ScalarFnVTable for Cast { Ok(None) } - fn stat_expression( - &self, - dtype: &DType, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - match stat { - Stat::IsConstant - | Stat::IsSorted - | Stat::IsStrictSorted - | Stat::NaNCount - | Stat::Sum - | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog), - Stat::Max | Stat::Min => { - // We cast min/max to the new type - expr.child(0) - .stat_expression(stat, catalog) - .map(|x| cast(x, dtype.clone())) - } - Stat::NullCount => { - // if !expr.data().is_nullable() { - // NOTE(ngates): we should decide on the semantics here. In theory, the null - // count of something cast to non-nullable will be zero. But if we return - // that we know this to be zero, then a pruning predicate may eliminate data - // that would otherwise have caused the cast to error. - // return Some(lit(0u64)); - // } - None - } - } - } - fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult> { Ok(Some(if dtype.is_nullable() { expression.child(0).validity()? 
diff --git a/vortex-array/src/scalar_fn/fns/dynamic.rs b/vortex-array/src/scalar_fn/fns/dynamic.rs index 7efebf79220..f6e6619282a 100644 --- a/vortex-array/src/scalar_fn/fns/dynamic.rs +++ b/vortex-array/src/scalar_fn/fns/dynamic.rs @@ -20,7 +20,6 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::traversal::NodeExt; use crate::expr::traversal::NodeVisitor; use crate::expr::traversal::TraversalOrder; @@ -120,50 +119,6 @@ impl ScalarFnVTable for DynamicComparison { .into_array()) } - fn stat_falsification( - &self, - dynamic: &DynamicComparisonExpr, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let lhs = expr.child(0); - match dynamic.operator { - CompareOperator::Eq | CompareOperator::NotEq => None, - CompareOperator::Gt => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Lte, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_max(catalog)?], - )), - CompareOperator::Gte => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Lt, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_max(catalog)?], - )), - CompareOperator::Lt => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Gte, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_min(catalog)?], - )), - CompareOperator::Lte => Some(DynamicComparison.new_expr( - DynamicComparisonExpr { - operator: CompareOperator::Gt, - rhs: Arc::clone(&dynamic.rhs), - default: !dynamic.default, - }, - vec![lhs.stat_min(catalog)?], - )), - } - } - // Defer to the child fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false diff --git a/vortex-array/src/scalar_fn/fns/get_item.rs b/vortex-array/src/scalar_fn/fns/get_item.rs index de7e45ca9b0..b9cb9202f38 100644 --- 
a/vortex-array/src/scalar_fn/fns/get_item.rs +++ b/vortex-array/src/scalar_fn/fns/get_item.rs @@ -18,12 +18,9 @@ use crate::builtins::ArrayBuiltins; use crate::builtins::ExprBuiltins; use crate::dtype::DType; use crate::dtype::FieldName; -use crate::dtype::FieldPath; use crate::dtype::Nullability; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -188,24 +185,6 @@ impl ScalarFnVTable for GetItem { Ok(None) } - fn stat_expression( - &self, - field_name: &FieldName, - _expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - // TODO(ngates): I think we can do better here and support stats over nested fields. - // It would be nice if delegating to our child would return a struct of statistics - // matching the nested DType such that we can write: - // `get_item(expr.child(0).stat_expression(...), expr.data().field_name())` - - // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same - // name as a field in the root struct. This should be resolved with upcoming change to - // falsify expressions, but for now I'm preserving the existing buggy behavior. - catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat) - } - // This will apply struct nullability field. We could add a dtype?? 
fn is_null_sensitive(&self, _field_name: &FieldName) -> bool { true diff --git a/vortex-array/src/scalar_fn/fns/is_not_null.rs b/vortex-array/src/scalar_fn/fns/is_not_null.rs index 589333304e2..a64aab32611 100644 --- a/vortex-array/src/scalar_fn/fns/is_not_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_not_null.rs @@ -3,7 +3,6 @@ use std::fmt::Formatter; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_error::VortexResult; use vortex_session::VortexSession; use vortex_session::registry::CachedId; @@ -15,16 +14,12 @@ use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::eq; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::ScalarFnVTableExt; use crate::validity::Validity; /// Expression that checks for non-null values. @@ -100,18 +95,6 @@ impl ScalarFnVTable for IsNotNull { fn is_fallible(&self, _instance: &Self::Options) -> bool { false } - - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // is_not_null is falsified when ALL values are null, i.e. null_count == row_count. 
- let child = expr.child(0); - let null_count_expr = child.stat_expression(Stat::NullCount, catalog)?; - Some(eq(null_count_expr, RowCount.new_expr(EmptyOptions, []))) - } } #[cfg(test)] @@ -135,6 +118,8 @@ mod tests { use crate::expr::eq; use crate::expr::get_item; use crate::expr::is_not_null; + use crate::expr::lit; + use crate::expr::or; use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; use crate::expr::stats::Stat; @@ -267,16 +252,22 @@ mod tests { let (pruning_expr, st) = checked_pruning_expr( &expr, + &test_harness::struct_dtype(), &FieldPathSet::from_iter([FieldPath::from_iter([ Field::Name("a".into()), Field::Name("null_count".into()), ])]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); assert_eq!( &pruning_expr, - &eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])) + &or( + eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, [])), + lit(Scalar::null(DType::Bool(Nullability::Nullable))), + ) ); assert_eq!( st.map(), diff --git a/vortex-array/src/scalar_fn/fns/is_null.rs b/vortex-array/src/scalar_fn/fns/is_null.rs index 7315fbe8c07..807b9d9a043 100644 --- a/vortex-array/src/scalar_fn/fns/is_null.rs +++ b/vortex-array/src/scalar_fn/fns/is_null.rs @@ -12,11 +12,6 @@ use crate::arrays::ConstantArray; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; -use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::eq; -use crate::expr::lit; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -84,16 +79,6 @@ impl ScalarFnVTable for IsNull { } } - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?; - Some(eq(null_count_expr, lit(0u64))) - } - fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true } @@ -125,6 +110,7 @@ mod tests 
{ use crate::expr::get_item; use crate::expr::is_null; use crate::expr::lit; + use crate::expr::or; use crate::expr::pruning::checked_pruning_expr; use crate::expr::root; use crate::expr::stats::Stat; @@ -251,14 +237,23 @@ mod tests { let (pruning_expr, st) = checked_pruning_expr( &expr, + &test_harness::struct_dtype(), &FieldPathSet::from_iter([FieldPath::from_iter([ Field::Name("a".into()), Field::Name("null_count".into()), ])]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); - assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64))); + assert_eq!( + &pruning_expr, + &or( + eq(col("a_null_count"), lit(0u64)), + lit(Scalar::null(DType::Bool(Nullability::Nullable))), + ) + ); assert_eq!( st.map(), &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]) diff --git a/vortex-array/src/scalar_fn/fns/like/mod.rs b/vortex-array/src/scalar_fn/fns/like/mod.rs index b7f357020f1..850484a1450 100644 --- a/vortex-array/src/scalar_fn/fns/like/mod.rs +++ b/vortex-array/src/scalar_fn/fns/like/mod.rs @@ -21,20 +21,12 @@ use crate::arrow::Datum; use crate::arrow::from_arrow_columnar; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; use crate::expr::and; -use crate::expr::gt; -use crate::expr::gt_eq; -use crate::expr::lit; -use crate::expr::lt; -use crate::expr::or; -use crate::scalar::StringLike; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; -use crate::scalar_fn::fns::literal::Literal; /// Options for SQL LIKE function #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -165,49 +157,6 @@ impl ScalarFnVTable for Like { fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { false } - - fn stat_falsification( - &self, - like_opts: &LikeOptions, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - // Attempt to do min/max pruning for LIKE 'exact' or LIKE 
'prefix%' - - // Don't attempt to handle ilike or negated like - if like_opts.negated || like_opts.case_insensitive { - return None; - } - - // Extract the pattern out - let pat = expr.child(1).as_::(); - - // LIKE NULL is nonsensical, don't try to handle it - let pat_str = pat.as_utf8().value()?; - - let src = expr.child(0).clone(); - let src_min = src.stat_min(catalog)?; - let src_max = src.stat_max(catalog)?; - - match LikeVariant::from_str(pat_str)? { - LikeVariant::Exact(text) => { - // col LIKE 'exact' ==> col.min > 'exact' || col.max < 'exact' - Some(or( - gt(src_min, lit(text.as_ref())), - lt(src_max, lit(text.as_ref())), - )) - } - LikeVariant::Prefix(prefix) => { - // col LIKE 'prefix%' ==> col.max < 'prefix' || col.min >= 'prefiy' - let succ = prefix.to_string().increment().ok()?; - - Some(or( - gt_eq(src_min, lit(succ)), - lt(src_max, lit(prefix.as_ref())), - )) - } - } - } } /// Implementation of LIKE using the Arrow crate. @@ -295,15 +244,11 @@ mod tests { use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; - use crate::expr::col; use crate::expr::get_item; - use crate::expr::ilike; use crate::expr::like; use crate::expr::lit; use crate::expr::not; use crate::expr::not_ilike; - use crate::expr::not_like; - use crate::expr::pruning::pruning_expr::TrackingStatsCatalog; use crate::expr::root; use crate::scalar_fn::fns::like::LikeVariant; @@ -390,50 +335,4 @@ mod tests { assert_eq!(LikeVariant::from_str(r"%\%%"), None); assert_eq!(LikeVariant::from_str("_pattern"), None); } - - #[test] - fn test_like_pushdown() { - // Test that LIKE prefix and exactness filters can be pushed down into stats filtering - // at scan time. 
- let catalog = TrackingStatsCatalog::default(); - - let pruning_expr = like(col("a"), lit("prefix%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "prefiy") or ($.a_max < "prefix"))"#); - - let pruning_expr = like(col("a"), lit(r"\%%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "&") or ($.a_max < "%"))"#); - - // Multiple wildcards - let pruning_expr = like(col("a"), lit("pref%ix%")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "preg") or ($.a_max < "pref"))"#); - - let pruning_expr = like(col("a"), lit("pref_ix_")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min >= "preg") or ($.a_max < "pref"))"#); - - // Exact match - let pruning_expr = like(col("a"), lit("exactly")) - .stat_falsification(&catalog) - .expect("LIKE stat falsification"); - insta::assert_snapshot!(pruning_expr, @r#"(($.a_min > "exactly") or ($.a_max < "exactly"))"#); - - // Suffix search skips pushdown - let pruning_expr = like(col("a"), lit("%suffix")).stat_falsification(&catalog); - assert_eq!(pruning_expr, None); - - // NOT LIKE, ILIKE not supported currently - assert_eq!( - None, - not_like(col("a"), lit("a")).stat_falsification(&catalog) - ); - assert_eq!(None, ilike(col("a"), lit("a")).stat_falsification(&catalog)); - } } diff --git a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs index 978a1da1caf..236bdd646eb 100644 --- a/vortex-array/src/scalar_fn/fns/list_contains/mod.rs +++ b/vortex-array/src/scalar_fn/fns/list_contains/mod.rs @@ -33,13 +33,6 @@ use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::IntegerPType; use crate::dtype::Nullability; -use crate::expr::Expression; -use 
crate::expr::StatsCatalog; -use crate::expr::and_collect; -use crate::expr::gt; -use crate::expr::lit; -use crate::expr::lt; -use crate::expr::or; use crate::match_each_integer_ptype; use crate::match_each_unsigned_integer_ptype; use crate::scalar::ListScalar; @@ -51,7 +44,6 @@ use crate::scalar_fn::ExecutionArgs; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; use crate::scalar_fn::fns::binary::Binary; -use crate::scalar_fn::fns::literal::Literal; use crate::scalar_fn::fns::operators::Operator; use crate::validity::Validity; @@ -129,43 +121,6 @@ impl ScalarFnVTable for ListContains { compute_list_contains(&list_array, &value_array, ctx) } - fn stat_falsification( - &self, - _options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - let list = expr.child(0); - let needle = expr.child(1); - - // falsification(contains([1,2,5], x)) => - // falsification(x != 1) and falsification(x != 2) and falsification(x != 5) - let min = list.stat_min(catalog)?; - let max = list.stat_max(catalog)?; - // If the list is constant when we can compare each element to the value - if min == max { - let list_ = min - .as_opt::() - .and_then(|l| l.as_list_opt()) - .and_then(|l| l.elements())?; - if list_.is_empty() { - // contains([], x) is always false. - return Some(lit(true)); - } - let value_max = needle.stat_max(catalog)?; - let value_min = needle.stat_min(catalog)?; - - return and_collect(list_.iter().map(move |v| { - or( - lt(value_max.clone(), lit(v.clone())), - gt(value_min.clone(), lit(v.clone())), - ) - })); - } - - None - } - // Nullability matters for contains([], x) where x is false. 
fn is_null_sensitive(&self, _instance: &Self::Options) -> bool { true @@ -630,14 +585,24 @@ mod tests { )), col("a"), ); + let scope = DType::Struct( + StructFields::new( + ["a"].into(), + vec![DType::Primitive(I32, Nullability::NonNullable)], + ), + Nullability::NonNullable, + ); let (expr, st) = checked_pruning_expr( &expr, + &scope, &FieldPathSet::from_iter([ FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]), FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]), ]), + &LEGACY_SESSION, ) + .unwrap() .unwrap(); assert_eq!( diff --git a/vortex-array/src/scalar_fn/fns/literal.rs b/vortex-array/src/scalar_fn/fns/literal.rs index 16b112e5a78..5181a5250dd 100644 --- a/vortex-array/src/scalar_fn/fns/literal.rs +++ b/vortex-array/src/scalar_fn/fns/literal.rs @@ -16,9 +16,6 @@ use crate::IntoArray; use crate::arrays::ConstantArray; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; -use crate::match_each_float_ptype; use crate::scalar::Scalar; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; @@ -96,50 +93,6 @@ impl ScalarFnVTable for Literal { Ok(ConstantArray::new(scalar.clone(), args.row_count()).into_array()) } - fn stat_expression( - &self, - scalar: &Scalar, - _expr: &Expression, - stat: Stat, - _catalog: &dyn StatsCatalog, - ) -> Option { - // NOTE(ngates): we return incorrect `1` values for counts here since we don't have - // row-count information. We could resolve this in the future by introducing a `count()` - // expression that evaluates to the row count of the provided scope. But since this is - // only currently used for pruning, it doesn't change the outcome. - - match stat { - Stat::Min | Stat::Max => Some(lit(scalar.clone())), - Stat::IsConstant => Some(lit(true)), - Stat::NaNCount => { - // The NaNCount for a non-float literal is not defined. - // For floating point types, the NaNCount is 1 for lit(NaN), and 0 otherwise. 
- let value = scalar.as_primitive_opt()?; - if !value.ptype().is_float() { - return None; - } - - match_each_float_ptype!(value.ptype(), |T| { - if value.typed_value::().is_some_and(|v| v.is_nan()) { - Some(lit(1u64)) - } else { - Some(lit(0u64)) - } - }) - } - Stat::NullCount => { - if scalar.is_null() { - Some(lit(1u64)) - } else { - Some(lit(0u64)) - } - } - Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum | Stat::UncompressedSizeInBytes => { - None - } - } - } - fn validity( &self, scalar: &Scalar, diff --git a/vortex-array/src/scalar_fn/fns/root.rs b/vortex-array/src/scalar_fn/fns/root.rs index 87b8b62ccf4..7bd5b758796 100644 --- a/vortex-array/src/scalar_fn/fns/root.rs +++ b/vortex-array/src/scalar_fn/fns/root.rs @@ -11,10 +11,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; -use crate::dtype::FieldPath; -use crate::expr::StatsCatalog; use crate::expr::expression::Expression; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; @@ -80,16 +77,6 @@ impl ScalarFnVTable for Root { vortex_bail!("Root expression is not executable") } - fn stat_expression( - &self, - _options: &Self::Options, - _expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - catalog.stats_ref(&FieldPath::root(), stat) - } - fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } diff --git a/vortex-array/src/scalar_fn/typed.rs b/vortex-array/src/scalar_fn/typed.rs index a2ef9549bff..e31cdc79ed0 100644 --- a/vortex-array/src/scalar_fn/typed.rs +++ b/vortex-array/src/scalar_fn/typed.rs @@ -22,8 +22,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::ExecutionArgs; @@ -101,17 +99,6 @@ pub(super) trait DynScalarFn: 
'static + Send + Sync + super::sealed::Sealed { ) -> VortexResult>; fn simplify_untyped(&self, expression: &Expression) -> VortexResult>; fn validity(&self, expression: &Expression) -> VortexResult>; - fn stat_falsification( - &self, - expression: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option; - fn stat_expression( - &self, - expression: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option; // Options operations — self-contained fn options_serialize(&self) -> VortexResult>>; @@ -223,23 +210,6 @@ impl DynScalarFn for TypedScalarFnInstance { V::validity(&self.vtable, &self.options, expression) } - fn stat_falsification( - &self, - expression: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - V::stat_falsification(&self.vtable, &self.options, expression, catalog) - } - - fn stat_expression( - &self, - expression: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - V::stat_expression(&self.vtable, &self.options, expression, stat, catalog) - } - fn options_serialize(&self) -> VortexResult>> { V::serialize(&self.vtable, &self.options) } diff --git a/vortex-array/src/scalar_fn/vtable.rs b/vortex-array/src/scalar_fn/vtable.rs index f4862f6876a..c66afc34932 100644 --- a/vortex-array/src/scalar_fn/vtable.rs +++ b/vortex-array/src/scalar_fn/vtable.rs @@ -20,8 +20,6 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::expr::Expression; -use crate::expr::StatsCatalog; -use crate::expr::stats::Stat; use crate::expr::traversal::Node; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnRef; @@ -179,34 +177,6 @@ pub trait ScalarFnVTable: 'static + Sized + Clone + Send + Sync { Ok(None) } - /// See [`Expression::stat_falsification`]. - fn stat_falsification( - &self, - options: &Self::Options, - expr: &Expression, - catalog: &dyn StatsCatalog, - ) -> Option { - _ = options; - _ = expr; - _ = catalog; - None - } - - /// See [`Expression::stat_expression`]. 
- fn stat_expression( - &self, - options: &Self::Options, - expr: &Expression, - stat: Stat, - catalog: &dyn StatsCatalog, - ) -> Option { - _ = options; - _ = expr; - _ = stat; - _ = catalog; - None - } - /// Returns an expression that evaluates to the validity of the result of this expression. /// /// If a validity expression cannot be constructed, returns `None` and the expression will diff --git a/vortex-array/src/stats/bind.rs b/vortex-array/src/stats/bind.rs new file mode 100644 index 00000000000..752664396c6 --- /dev/null +++ b/vortex-array/src/stats/bind.rs @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Bind abstract `vortex.stat` expressions to a concrete stats representation. + +use vortex_error::VortexResult; + +use crate::aggregate_fn::AggregateFnRef; +use crate::dtype::DType; +use crate::expr::Expression; +use crate::expr::lit; +use crate::expr::stats::Stat; +use crate::expr::traversal::NodeExt; +use crate::expr::traversal::Transformed; +use crate::scalar::Scalar; +use crate::scalar_fn::fns::stat::StatFn; + +/// A target that can bind abstract statistics to concrete expressions. +pub trait StatBinder { + /// The dtype scope used to type-check expressions before stats are bound. + fn scope(&self) -> &DType; + + /// Bind `stat(input)` to a concrete expression. + /// + /// Returning `Ok(None)` marks the stat as unavailable. [`bind_stats`] will + /// then call [`Self::missing_stat`] with the dtype expected from the + /// original `vortex.stat` expression. + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + stat_dtype: &DType, + ) -> VortexResult>; + + /// Bind `aggregate_fn(input)` to a concrete expression. + /// + /// The default implementation supports aggregate functions with legacy + /// [`Stat`] slots. Binders that store richer aggregate stats can override + /// this method without extending the generic stats binding walker. 
+ fn bind_aggregate( + &mut self, + input: &Expression, + aggregate_fn: &AggregateFnRef, + stat_dtype: &DType, + ) -> VortexResult> { + let Some(stat) = Stat::from_aggregate_fn(aggregate_fn) else { + return Ok(None); + }; + self.bind_stat(input, stat, stat_dtype) + } + + /// Expression to use when a stat is unavailable. + /// + /// The default is a nullable null literal, which preserves three-valued + /// pruning semantics for stats-table execution. Catalog-like binders can + /// return `Ok(None)` to reject expressions that require unavailable stats. + fn missing_stat(&mut self, dtype: DType) -> VortexResult> { + Ok(Some(null_expr(dtype))) + } +} + +/// Bind all `vortex.stat` expressions in `predicate`. +/// +/// The predicate is usually the output of a stats rewrite rule. Rewrite rules +/// are responsible for expressing stat semantics; binding maps aggregate-backed +/// stat requests to the concrete stats representation supported by the binder. +pub fn bind_stats( + predicate: Expression, + binder: &mut impl StatBinder, +) -> VortexResult> { + let scope = binder.scope().clone(); + let lowered = predicate + .transform_down(|expr| { + if !expr.is::() { + return Ok(Transformed::no(expr)); + } + + match bind_stat_fn(&expr, &scope, binder)? { + Some(bound) => Ok(Transformed::yes(bound)), + None => { + let dtype = expr.return_dtype(&scope)?; + match binder.missing_stat(dtype.clone())? { + Some(missing) => Ok(Transformed::yes(missing)), + None => Ok(Transformed::yes(null_expr(dtype))), + } + } + } + })? 
+ .into_inner(); + + #[expect(deprecated)] + let lowered = lowered.simplify_untyped()?; + Ok(Some(lowered)) +} + +fn bind_stat_fn( + expr: &Expression, + scope: &DType, + binder: &mut impl StatBinder, +) -> VortexResult> { + let options = expr.as_::(); + let aggregate_fn = options.aggregate_fn(); + let input = expr.child(0); + + let stat_dtype = expr.return_dtype(scope)?; + binder.bind_aggregate(input, aggregate_fn, &stat_dtype) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index ceb085e0815..5f5684dbde2 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -19,9 +19,10 @@ pub use expr::sum; pub use stats_set::*; mod array; +pub mod bind; pub mod expr; pub mod flatbuffers; -pub(crate) mod rewrite; +pub mod rewrite; pub mod session; mod stats_set; diff --git a/vortex-array/src/stats/rewrite.rs b/vortex-array/src/stats/rewrite.rs index 52d354df1a0..dfba62ded40 100644 --- a/vortex-array/src/stats/rewrite.rs +++ b/vortex-array/src/stats/rewrite.rs @@ -21,7 +21,7 @@ mod builtins; pub(crate) use builtins::register_builtins; /// Shared reference to a stats rewrite rule. -pub(crate) type StatsRewriteRuleRef = Arc; +pub type StatsRewriteRuleRef = Arc; /// A plugin-provided rule for predicates whose root scalar function matches this rule. /// @@ -40,7 +40,7 @@ pub(crate) type StatsRewriteRuleRef = Arc; /// `expr` is the full predicate expression whose root scalar function id is /// [`Self::scalar_fn_id`]. Use [`StatsRewriteCtx`] to resolve dtypes and recursively rewrite child /// predicates. -pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { +pub trait StatsRewriteRule: Debug + Send + Sync + 'static { /// Returns the scalar function id handled by this rule. 
fn scalar_fn_id(&self) -> ScalarFnId; @@ -83,35 +83,35 @@ pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static { } /// Context passed to stats rewrite rules. -pub(crate) struct StatsRewriteCtx<'a> { +pub struct StatsRewriteCtx<'a> { session: &'a VortexSession, scope: &'a DType, } impl<'a> StatsRewriteCtx<'a> { /// Create a rewrite context for `session`. - pub(crate) fn new(session: &'a VortexSession, scope: &'a DType) -> Self { + pub fn new(session: &'a VortexSession, scope: &'a DType) -> Self { Self { session, scope } } /// Returns the session that owns the rewrite registry. - pub(crate) fn session(&self) -> &'a VortexSession { + pub fn session(&self) -> &'a VortexSession { self.session } /// Return the dtype of `expr` within this rewrite scope. - pub(crate) fn return_dtype(&self, expr: &Expression) -> VortexResult { + pub fn return_dtype(&self, expr: &Expression) -> VortexResult { expr.return_dtype(self.scope) } /// Rewrite `expr` into a stats-backed falsifier. - pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult> { + pub fn falsify(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::falsify) } /// Rewrite `expr` into a stats-backed satisfier. - pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult> { + pub fn satisfy(&self, expr: &Expression) -> VortexResult> { self.ensure_predicate(expr)?; rewrite(expr, self, StatsRewriteRule::satisfy) } diff --git a/vortex-array/src/stats/session.rs b/vortex-array/src/stats/session.rs index 2d4325b2cd7..91eae4a4fa9 100644 --- a/vortex-array/src/stats/session.rs +++ b/vortex-array/src/stats/session.rs @@ -37,14 +37,12 @@ impl Default for StatsSession { impl StatsSession { /// Register a stats rewrite rule. - #[allow(dead_code)] - pub(crate) fn register_rewrite(&self, rule: R) { + pub fn register_rewrite(&self, rule: R) { self.register_rewrite_ref(Arc::new(rule)); } /// Register a shared stats rewrite rule. 
- #[allow(dead_code)] - pub(crate) fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { + pub fn register_rewrite_ref(&self, rule: StatsRewriteRuleRef) { let mut rules = self.rewrite_rules.write(); let rule_id = rule.scalar_fn_id(); let mut updated_rules = rules @@ -75,7 +73,7 @@ impl SessionVar for StatsSession { } /// Extension trait for accessing stats session data. -pub(crate) trait StatsSessionExt: SessionExt { +pub trait StatsSessionExt: SessionExt { /// Returns the stats session state. fn stats(&self) -> Ref<'_, StatsSession> { self.get::() diff --git a/vortex-file/src/file.rs b/vortex-file/src/file.rs index ded986f6210..9e39a6fab5c 100644 --- a/vortex-file/src/file.rs +++ b/vortex-file/src/file.rs @@ -217,7 +217,9 @@ impl VortexFile { }), ); - let Some((predicate, required_stats)) = checked_pruning_expr(filter, &set) else { + let Some((predicate, required_stats)) = + checked_pruning_expr(filter, self.footer.dtype(), &set, &self.session)? + else { return Ok(false); }; diff --git a/vortex-file/src/v2/file_stats_reader.rs b/vortex-file/src/v2/file_stats_reader.rs index 0121c12b07d..b697becbc0f 100644 --- a/vortex-file/src/v2/file_stats_reader.rs +++ b/vortex-file/src/v2/file_stats_reader.rs @@ -21,13 +21,17 @@ use vortex_array::dtype::FieldMask; use vortex_array::dtype::FieldPath; use vortex_array::dtype::StructFields; use vortex_array::expr::Expression; -use vortex_array::expr::StatsCatalog; +use vortex_array::expr::is_root; use vortex_array::expr::lit; use vortex_array::expr::stats::Stat; use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::fns::get_item::GetItem; use vortex_array::scalar_fn::fns::literal::Literal; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_stats; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_layout::ArrayFuture; use vortex_layout::LayoutReader; use vortex_layout::LayoutReaderRef; @@ 
-83,14 +87,14 @@ impl FileStatsLayoutReader { /// Row-count placeholders are resolved against the full file row count, /// independent of the requested row range. fn evaluate_file_stats(&self, expr: &Expression) -> VortexResult { - let Some(pruning_expr) = expr.stat_falsification(self) else { + let Some(pruning_expr) = expr.falsify(self.child.dtype(), &self.session)? else { // If there is no pruning expression, we can't prune. return Ok(false); }; + let pruning_expr = self.lower_stats(pruning_expr)?; - // Given how we implemented the StatsCatalog, we know the expression must be all literals - // or row_count placeholders. We can therefore optimize with a null scope since there are - // no field references that need to be resolved. + // Stats lowering replaces available stats with literals and unavailable stats with nulls, + // so only row_count placeholders remain unresolved here. let simplified = pruning_expr.optimize_recursive(&DType::Null)?; if let Some(result) = simplified.as_opt::() { // Can prune if the result is non-nullable and true @@ -115,14 +119,52 @@ impl FileStatsLayoutReader { Ok(result.as_bool().value() == Some(true)) } + fn lower_stats(&self, predicate: Expression) -> VortexResult { + let mut binder = FileStatsBinder { reader: self }; + let Some(predicate) = bind_stats(predicate, &mut binder)? else { + vortex_bail!("missing stats should lower to null literals"); + }; + Ok(predicate) + } + pub fn file_stats(&self) -> &FileStatistics { &self.file_stats } } -/// Implements [`StatsCatalog`] to provide file-level stats to expressions during pruning evaluation. 
-impl StatsCatalog for FileStatsLayoutReader { - fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { +struct FileStatsBinder<'a> { + reader: &'a FileStatsLayoutReader, +} + +impl StatBinder for FileStatsBinder<'_> { + fn scope(&self) -> &DType { + self.reader.child.dtype() + } + + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + let Some(field_path) = direct_field_path(input) else { + return Ok(None); + }; + Ok(self.reader.stat_ref(&field_path, stat)) + } +} + +fn direct_field_path(expr: &Expression) -> Option { + if is_root(expr) { + return Some(FieldPath::root()); + } + + let field_name = expr.as_opt::()?; + direct_field_path(expr.child(0)).map(|path| path.push(field_name.clone())) +} + +impl FileStatsLayoutReader { + fn stat_ref(&self, field_path: &FieldPath, stat: Stat) -> Option { // FileStats currently only holds top-level field statistics. if field_path.parts().len() != 1 { return None; @@ -224,6 +266,7 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; + use vortex_array::expr::checked_add; use vortex_array::expr::get_item; use vortex_array::expr::gt; use vortex_array::expr::is_not_null; @@ -365,6 +408,43 @@ mod tests { }) } + #[test] + fn no_pruning_for_computed_expression_stats() -> VortexResult<()> { + block_on(|handle| async { + let session = SESSION.clone().with_handle(handle); + let ctx = ArrayContext::empty(); + let segments = Arc::new(TestSegments::default()); + let (ptr, eof) = SequenceId::root().split(); + let struct_array = + StructArray::from_fields([("col", buffer![0i32, 100].into_array())].as_slice())?; + let strategy = TableStrategy::new( + Arc::new(FlatLayoutStrategy::default()), + Arc::new(FlatLayoutStrategy::default()), + ); + let layout = strategy + .write_stream( + ctx, + Arc::::clone(&segments), + struct_array.into_array().to_array_stream().sequenced(ptr), + eof, + &session, + ) + .await?; + + 
let child = layout.new_reader("".into(), segments, &SESSION, &Default::default())?; + let reader = + FileStatsLayoutReader::new(child, test_file_stats(0, 100), SESSION.clone()); + + let expr = gt(checked_add(get_item("col", root()), lit(5i32)), lit(102i32)); + let mask = Mask::new_true(2); + let result = reader.pruning_evaluation(&(0..2), &expr, mask)?.await?; + + assert_eq!(result, Mask::new_true(2)); + + Ok(()) + }) + } + /// Regression test: `IS NULL` on a nullable timestamp column must not fail with a /// dtype mismatch. The bug was that `stats_ref` used the *field* dtype (timestamp) /// for the `NullCount` stat scalar instead of the stat's own dtype (u64). diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 96154e69571..f16082fc90e 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,33 +8,20 @@ use std::sync::Arc; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; -use vortex_array::aggregate_fn::fns::all_nan::AllNan; -use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; -use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; -use vortex_array::aggregate_fn::fns::all_null::AllNull; -use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; use vortex_array::expr::Expression; -use vortex_array::expr::eq; use vortex_array::expr::get_item; use vortex_array::expr::is_root; -use vortex_array::expr::lit; use vortex_array::expr::root; use vortex_array::expr::stats::Stat; -use vortex_array::expr::traversal::NodeExt; -use vortex_array::expr::traversal::Transformed; -use vortex_array::scalar::Scalar; -use vortex_array::scalar_fn::EmptyOptions; -use 
vortex_array::scalar_fn::ScalarFnVTableExt; -use vortex_array::scalar_fn::fns::stat::StatFn; -use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; +use vortex_array::stats::bind::StatBinder; +use vortex_array::stats::bind::bind_stats; use vortex_array::validity::Validity; use vortex_buffer::buffer; use vortex_error::VortexResult; @@ -132,107 +119,42 @@ impl ZoneMap { } fn lower_stats(&self, predicate: Expression) -> VortexResult { - // Rewritten predicates are evaluated against the stats table, not the data - // column. Lower each StatFn before execution so unavailable stats become - // nullable "unknown" constants rather than prune signals. - predicate - .transform_down(|expr| { - if expr.is::() { - return self.lower_stat_fn(expr).map(Transformed::yes); - } - - Ok(Transformed::no(expr)) - }) - .map(Transformed::into_inner) - } - - fn lower_stat_fn(&self, expr: Expression) -> VortexResult { - // This is the bridge from aggregate-backed bound expressions to the legacy - // zoned stats columns. Exact NullCount and NanCount can prove richer - // all-* aggregates; non-root or missing stats lower to nullable unknowns. 
- let options = expr.as_::(); - let input = expr.child(0); - let input_dtype = input.return_dtype(&self.column_dtype)?; - let input_is_root = is_root(input); - - if options.aggregate_fn().is::() { - if !has_nans(&input_dtype) { - return Ok(lit(false)); - } - if !input_is_root { - return Ok(null_expr(DType::Bool(Nullability::NonNullable))); - } - return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, row_count_expr())); - } - - if options.aggregate_fn().is::() { - if !has_nans(&input_dtype) { - return Ok(lit(true)); - } - if !input_is_root { - return Ok(null_expr(DType::Bool(Nullability::NonNullable))); - } - return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, lit(0u64))); - } - - if options.aggregate_fn().is::() && !has_nans(&input_dtype) { - return Ok(lit(0u64)); - } - - let return_dtype = match options.aggregate_fn().return_dtype(&input_dtype) { - Some(return_dtype) => return_dtype, - None => vortex_bail!( - "Aggregate function {} does not support input dtype {}", - options.aggregate_fn(), - input_dtype - ), - }; - - if !input_is_root { - return Ok(null_expr(return_dtype)); - } - - if options.aggregate_fn().is::() { - return Ok(eq(self.stat_field_expr(Stat::NullCount)?, row_count_expr())); - } - - if options.aggregate_fn().is::() { - return Ok(eq(self.stat_field_expr(Stat::NullCount)?, lit(0u64))); - } - - let Some(stat) = Stat::from_aggregate_fn(options.aggregate_fn()) else { - return Ok(null_expr(return_dtype)); + let mut binder = ZoneMapStatsBinder { zone_map: self }; + let Some(predicate) = bind_stats(predicate, &mut binder)? 
else { + vortex_bail!("missing stats should lower to null literals"); }; - - self.stat_field_expr(stat) - } - - fn stat_field_expr(&self, stat: Stat) -> VortexResult { - if self.array.unmasked_field_by_name_opt(stat.name()).is_some() { - return Ok(get_item(stat.name(), root())); - } - - let Some(dtype) = stat.dtype(&self.column_dtype) else { - vortex_bail!( - "Stat {} does not support column dtype {}", - stat, - self.column_dtype - ); - }; - Ok(null_expr(dtype)) + Ok(predicate) } } -fn row_count_expr() -> Expression { - RowCount.new_expr(EmptyOptions, []) +struct ZoneMapStatsBinder<'a> { + zone_map: &'a ZoneMap, } -fn null_expr(dtype: DType) -> Expression { - lit(Scalar::null(dtype.as_nullable())) -} +impl StatBinder for ZoneMapStatsBinder<'_> { + fn scope(&self) -> &DType { + &self.zone_map.column_dtype + } -fn has_nans(dtype: &DType) -> bool { - matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) + fn bind_stat( + &mut self, + input: &Expression, + stat: Stat, + _stat_dtype: &DType, + ) -> VortexResult> { + if !is_root(input) { + return Ok(None); + } + if self + .zone_map + .array + .unmasked_field_by_name_opt(stat.name()) + .is_none() + { + return Ok(None); + } + Ok(Some(get_item(stat.name(), root()))) + } } /// Build per-zone row counts for a zone map. 
@@ -422,7 +344,7 @@ mod tests { } #[test] - fn all_null_stat_fn_lowers_to_null_count_and_row_count() { + fn all_null_stat_fn_lowers_to_unknown_mask() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -437,11 +359,14 @@ mod tests { .unwrap(); let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, false]) + ); } #[test] - fn all_non_null_stat_fn_lowers_to_null_count() { + fn all_non_null_stat_fn_lowers_to_unknown_mask() { let zone_map = ZoneMap::try_new( PType::U64.into(), StructArray::from_fields(&[( @@ -458,12 +383,12 @@ mod tests { let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([true, false, false]) + BoolArray::from_iter([false, false, false]) ); } #[test] - fn non_float_nan_stat_fns_lower_to_constants() { + fn non_float_nan_stat_fns_error() { let zone_map = ZoneMap::try_new( PType::I32.into(), StructArray::try_new(FieldNames::empty(), vec![], 2, Validity::NonNullable).unwrap(), @@ -473,11 +398,8 @@ mod tests { ) .unwrap(); - let mask = zone_map.prune(&all_nan(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); - - let mask = zone_map.prune(&all_non_nan(root()), &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([true, true])); + assert!(zone_map.prune(&all_nan(root()), &SESSION).is_err()); + assert!(zone_map.prune(&all_non_nan(root()), &SESSION).is_err()); } #[test] @@ -507,7 +429,7 @@ mod tests { } #[test] - fn float_min_max_stat_fn_requires_nan_count() { + fn float_min_max_stat_fn_requires_all_non_nan() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[ @@ -561,12 +483,12 @@ mod tests { let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); 
assert_arrays_eq!( mask.into_array(), - BoolArray::from_iter([true, false, false]) + BoolArray::from_iter([false, false, false]) ); } #[test] - fn float_cast_min_max_stat_fn_uses_source_nan_count() { + fn float_cast_min_max_stat_fn_requires_all_non_nan() { let zone_map = ZoneMap::try_new( PType::F32.into(), StructArray::from_fields(&[ @@ -603,7 +525,7 @@ mod tests { let pruning_expr = falsify(&expr, PType::F32.into()); let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); - assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true])); + assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); } #[test]