Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions STYLE.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
- Maintain a clear separation between logical and physical types
- Keep functions focused and reasonably sized
- Separate public API from internal implementation details
- Prefer one public entrypoint for each piece of functionality; keep helper APIs crate-private
unless callers need them independently.
- Use modules to organize related functionality
- Place tests in a `tests` module or separate test files

Expand Down
38 changes: 38 additions & 0 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12778,10 +12778,14 @@ pub fn vortex_array::expr::Expression::children(&self) -> &alloc::sync::Arc<allo

pub fn vortex_array::expr::Expression::display_tree(&self) -> impl core::fmt::Display

pub fn vortex_array::expr::Expression::falsify(&self, &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::expr::Expression::fmt_sql(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result

pub fn vortex_array::expr::Expression::return_dtype(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>

pub fn vortex_array::expr::Expression::satisfy(&self, &vortex_session::VortexSession) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>

pub fn vortex_array::expr::Expression::scalar_fn(&self) -> &vortex_array::scalar_fn::ScalarFnRef

pub fn vortex_array::expr::Expression::stat_expression(&self, vortex_array::expr::stats::Stat, &dyn vortex_array::expr::pruning::StatsCatalog) -> core::option::Option<vortex_array::expr::Expression>
Expand Down Expand Up @@ -19774,6 +19778,24 @@ pub fn vortex_array::stats::expr::sum(vortex_array::expr::Expression) -> vortex_

pub mod vortex_array::stats::flatbuffers

pub mod vortex_array::stats::session

pub struct vortex_array::stats::session::StatsRewriteSession

impl core::default::Default for vortex_array::stats::StatsRewriteSession

pub fn vortex_array::stats::StatsRewriteSession::default() -> vortex_array::stats::StatsRewriteSession

impl core::fmt::Debug for vortex_array::stats::StatsRewriteSession

pub fn vortex_array::stats::StatsRewriteSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl vortex_session::SessionVar for vortex_array::stats::StatsRewriteSession

pub fn vortex_array::stats::StatsRewriteSession::as_any(&self) -> &dyn core::any::Any

pub fn vortex_array::stats::StatsRewriteSession::as_any_mut(&mut self) -> &mut dyn core::any::Any

pub struct vortex_array::stats::ArrayStats

impl vortex_array::stats::ArrayStats
Expand Down Expand Up @@ -19834,6 +19856,22 @@ pub fn vortex_array::stats::MutTypedStatsSetRef<'_, '_>::is_empty(&self) -> bool

pub fn vortex_array::stats::MutTypedStatsSetRef<'_, '_>::len(&self) -> usize

pub struct vortex_array::stats::StatsRewriteSession

impl core::default::Default for vortex_array::stats::StatsRewriteSession

pub fn vortex_array::stats::StatsRewriteSession::default() -> vortex_array::stats::StatsRewriteSession

impl core::fmt::Debug for vortex_array::stats::StatsRewriteSession

pub fn vortex_array::stats::StatsRewriteSession::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result

impl vortex_session::SessionVar for vortex_array::stats::StatsRewriteSession

pub fn vortex_array::stats::StatsRewriteSession::as_any(&self) -> &dyn core::any::Any

pub fn vortex_array::stats::StatsRewriteSession::as_any_mut(&mut self) -> &mut dyn core::any::Any

pub struct vortex_array::stats::StatsSet

impl vortex_array::stats::StatsSet
Expand Down
17 changes: 17 additions & 0 deletions vortex-array/src/expr/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::sync::Arc;
use itertools::Itertools;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_session::VortexSession;

use crate::dtype::DType;
use crate::expr::StatsCatalog;
Expand Down Expand Up @@ -135,6 +136,22 @@ impl Expression {
self.scalar_fn().stat_falsification(self, catalog)
}

/// Returns an expression that proves this predicate is definitely false from stats.
///
/// If the returned expression evaluates to `true` for a stats scope, this expression is
/// guaranteed to be false for every row in that scope. `false` and `null` are unknown.
///
/// Only the rewrite rules registered in `session` for this expression's own scalar function
/// are consulted. Returns `Ok(None)` when no rules are registered for that function; when
/// several rules produce a falsifier, the results are combined with a logical OR.
///
/// # Errors
///
/// Propagates the first error returned by a registered rule.
pub fn falsify(&self, session: &VortexSession) -> VortexResult<Option<Expression>> {
crate::stats::rewrite::StatsRewriteCtx::new(session).falsify(self)
}

/// Returns an expression that proves this predicate is definitely true from stats.
///
/// If the returned expression evaluates to `true` for a stats scope, this expression is
/// guaranteed to be true for every row in that scope. `false` and `null` are unknown.
///
/// Only the rewrite rules registered in `session` for this expression's own scalar function
/// are consulted. Returns `Ok(None)` when no rules are registered for that function; when
/// several rules produce a satisfier, the results are combined with a logical OR.
///
/// # Errors
///
/// Propagates the first error returned by a registered rule.
pub fn satisfy(&self, session: &VortexSession) -> VortexResult<Option<Expression>> {
crate::stats::rewrite::StatsRewriteCtx::new(session).satisfy(self)
}

/// Returns an expression representing the zoned statistic for the given stat, if available.
///
/// The [`StatsCatalog`] returns expressions that can be evaluated using the zone map as a
Expand Down
3 changes: 3 additions & 0 deletions vortex-array/src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ pub use stats_set::*;
mod array;
pub mod expr;
pub mod flatbuffers;
pub(crate) mod rewrite;
pub mod session;
mod stats_set;

pub use array::*;
pub use session::*;
use vortex_error::VortexExpect;

use crate::expr::stats::Stat;
Expand Down
192 changes: 192 additions & 0 deletions vortex-array/src/stats/rewrite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Session-registered rewrite rules for aggregate-backed stats expressions.

use std::fmt::Debug;
use std::sync::Arc;

use vortex_error::VortexResult;
use vortex_session::VortexSession;

use crate::expr::Expression;
use crate::expr::or_collect;
use crate::scalar_fn::ScalarFnId;
use crate::stats::session::StatsRewriteSessionExt;

/// Shared reference to a stats rewrite rule.
///
/// Rules are held as `Arc<dyn StatsRewriteRule>` so one rule instance can be stored in the
/// session registry and handed out to callers without copying the rule itself.
pub(crate) type StatsRewriteRuleRef = Arc<dyn StatsRewriteRule>;

/// A rule that rewrites a predicate into a stats-backed proof expression.
///
/// Two kinds of proof can be produced:
///
/// - a *falsifier*, which evaluates to `true` only when the original predicate is definitely
///   false for the current stats scope;
/// - a *satisfier*, which evaluates to `true` only when the original predicate is definitely
///   true for the current stats scope.
///
/// Both methods default to `Ok(None)`, meaning the rule cannot prove anything for the
/// expression. Implementors override only the direction(s) they support.
// NOTE(review): the `dead_code` allowance is presumably needed because, in this change, rules
// are only implemented and registered from test code — confirm and drop once a production rule
// exists.
#[allow(dead_code)]
pub(crate) trait StatsRewriteRule: Debug + Send + Sync + 'static {
    /// The scalar function ID this rule applies to.
    ///
    /// The session registry groups rules under this ID.
    fn scalar_fn_id(&self) -> ScalarFnId;

    /// Rewrite `expr` into a stats-backed falsifier, or `Ok(None)` if nothing can be proven.
    ///
    /// The default implementation ignores its arguments and returns `Ok(None)`.
    fn falsify(
        &self,
        expr: &Expression,
        ctx: &StatsRewriteCtx<'_>,
    ) -> VortexResult<Option<Expression>> {
        // The parameters keep their descriptive names for implementors; discard them here.
        let _ = (expr, ctx);
        Ok(None)
    }

    /// Rewrite `expr` into a stats-backed satisfier, or `Ok(None)` if nothing can be proven.
    ///
    /// The default implementation ignores its arguments and returns `Ok(None)`.
    fn satisfy(
        &self,
        expr: &Expression,
        ctx: &StatsRewriteCtx<'_>,
    ) -> VortexResult<Option<Expression>> {
        // The parameters keep their descriptive names for implementors; discard them here.
        let _ = (expr, ctx);
        Ok(None)
    }
}

/// Context passed to stats rewrite rules.
///
/// Borrows the [`VortexSession`] for `'a`. Rules receive this context in their `falsify` /
/// `satisfy` methods and can reach the session through [`StatsRewriteCtx::session`].
pub(crate) struct StatsRewriteCtx<'a> {
// Session whose stats rewrite registry is consulted when rewriting an expression.
session: &'a VortexSession,
}

impl<'a> StatsRewriteCtx<'a> {
/// Create a rewrite context for `session`.
///
/// The context only stores the borrow; nothing is looked up until a rewrite is requested.
pub(crate) fn new(session: &'a VortexSession) -> Self {
Self { session }
}

/// Returns the session that owns the rewrite registry.
///
/// The returned reference carries the full `'a` lifetime, not the lifetime of `&self`.
pub(crate) fn session(&self) -> &'a VortexSession {
self.session
}

/// Rewrite `expr` into a stats-backed falsifier.
///
/// Delegates to `rewrite`, applying [`StatsRewriteRule::falsify`] for every rule registered
/// under `expr`'s scalar function. Returns `Ok(None)` when no rules are registered for it.
pub(crate) fn falsify(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
rewrite(expr, self, StatsRewriteRule::falsify)
}

/// Rewrite `expr` into a stats-backed satisfier.
///
/// Delegates to `rewrite`, applying [`StatsRewriteRule::satisfy`] for every rule registered
/// under `expr`'s scalar function. Returns `Ok(None)` when no rules are registered for it.
pub(crate) fn satisfy(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
rewrite(expr, self, StatsRewriteRule::satisfy)
}
}

/// Run every rule registered for `expr`'s scalar function and OR the resulting proofs together.
///
/// `apply` selects the direction of the rewrite (`StatsRewriteRule::falsify` or
/// `StatsRewriteRule::satisfy`). Rules that return `Ok(None)` contribute nothing; the first
/// rule error aborts the rewrite and is returned to the caller. Returns `Ok(None)` when no
/// rules are registered for the scalar function.
fn rewrite(
    expr: &Expression,
    ctx: &StatsRewriteCtx<'_>,
    apply: fn(
        &dyn StatsRewriteRule,
        &Expression,
        &StatsRewriteCtx<'_>,
    ) -> VortexResult<Option<Expression>>,
) -> VortexResult<Option<Expression>> {
    // Look the rule set up in its own statement so the registry handle is released before any
    // rule runs.
    let registered = ctx
        .session()
        .stats_rewrites()
        .rules_for(expr.scalar_fn().id());
    let rule_set = match registered {
        Some(rule_set) => rule_set,
        None => return Ok(None),
    };

    // `transpose` turns `Ok(None)` into a skipped item, while `Err` stops collection early.
    let proofs = rule_set
        .iter()
        .filter_map(|rule| apply(rule.as_ref(), expr, ctx).transpose())
        .collect::<VortexResult<Vec<_>>>()?;

    Ok(or_collect(proofs))
}

#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use super::StatsRewriteCtx;
    use super::StatsRewriteRule;
    use crate::expr::Expression;
    use crate::expr::lit;
    use crate::expr::or;
    use crate::scalar_fn::ScalarFnId;
    use crate::scalar_fn::ScalarFnVTable;
    use crate::scalar_fn::fns::literal::Literal;
    use crate::stats::session::StatsRewriteSession;
    use crate::stats::session::StatsRewriteSessionExt;

    /// Test rule bound to the literal scalar function that returns canned rewrites.
    #[derive(Debug)]
    struct StaticLiteralRule {
        falsifier: Option<Expression>,
        satisfier: Option<Expression>,
    }

    impl StaticLiteralRule {
        /// A rule that only ever produces `proof` as a falsifier.
        fn falsifying(proof: Expression) -> Self {
            Self {
                falsifier: Some(proof),
                satisfier: None,
            }
        }

        /// A rule that only ever produces `proof` as a satisfier.
        fn satisfying(proof: Expression) -> Self {
            Self {
                falsifier: None,
                satisfier: Some(proof),
            }
        }
    }

    impl StatsRewriteRule for StaticLiteralRule {
        fn scalar_fn_id(&self) -> ScalarFnId {
            Literal.id()
        }

        fn falsify(
            &self,
            _expr: &Expression,
            _ctx: &StatsRewriteCtx<'_>,
        ) -> VortexResult<Option<Expression>> {
            Ok(self.falsifier.clone())
        }

        fn satisfy(
            &self,
            _expr: &Expression,
            _ctx: &StatsRewriteCtx<'_>,
        ) -> VortexResult<Option<Expression>> {
            Ok(self.satisfier.clone())
        }
    }

    /// Build a session holding only the stats rewrite registry, with `rules` registered in order.
    fn session_with(rules: Vec<StaticLiteralRule>) -> VortexSession {
        let session = VortexSession::empty().with::<StatsRewriteSession>();
        for rule in rules {
            session.stats_rewrites().register(rule);
        }
        session
    }

    #[test]
    fn combines_multiple_falsifiers_with_or() -> VortexResult<()> {
        let session = session_with(vec![
            StaticLiteralRule::falsifying(lit(false)),
            StaticLiteralRule::falsifying(lit(true)),
        ]);

        let expected = Some(or(lit(false), lit(true)));
        assert_eq!(lit(7).falsify(&session)?, expected);
        Ok(())
    }

    #[test]
    fn combines_multiple_satisfiers_with_or() -> VortexResult<()> {
        let session = session_with(vec![
            StaticLiteralRule::satisfying(lit(false)),
            StaticLiteralRule::satisfying(lit(true)),
        ]);

        let expected = Some(or(lit(false), lit(true)));
        assert_eq!(lit(7).satisfy(&session)?, expected);
        Ok(())
    }

    #[test]
    fn unregistered_expression_has_no_rewrite() -> VortexResult<()> {
        let session = session_with(Vec::new());
        let predicate = lit(7);

        assert_eq!(predicate.falsify(&session)?, None);
        assert_eq!(predicate.satisfy(&session)?, None);
        Ok(())
    }
}
70 changes: 70 additions & 0 deletions vortex-array/src/stats/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Session state for stats rewrite rules.

use std::any::Any;
use std::sync::Arc;

use parking_lot::RwLock;
use vortex_session::Ref;
use vortex_session::SessionExt;
use vortex_session::SessionVar;
use vortex_utils::aliases::hash_map::HashMap;

use crate::scalar_fn::ScalarFnId;
use crate::stats::rewrite::StatsRewriteRule;
use crate::stats::rewrite::StatsRewriteRuleRef;

/// The rules registered for a single scalar function, shared as an immutable slice so that
/// lookups can hand out a clone of the `Arc` without copying the rules.
type StatsRewriteRuleSet = Arc<[StatsRewriteRuleRef]>;

/// Session state for stats rewrite rules.
///
/// Holds the registered rules grouped by the [`ScalarFnId`] they apply to. The map sits behind
/// a [`RwLock`], so registration and lookup both work through a shared `&self`.
#[derive(Debug, Default)]
pub struct StatsRewriteSession {
// Rule sets keyed by scalar function ID; registration replaces the whole set for a key.
rules: RwLock<HashMap<ScalarFnId, StatsRewriteRuleSet>>,
}

impl StatsRewriteSession {
    /// Register a stats rewrite rule, wrapping it in an [`Arc`] first.
    // NOTE(review): `dead_code` is presumably allowed because only tests register rules in this
    // change — confirm.
    #[allow(dead_code)]
    pub(crate) fn register<R: StatsRewriteRule>(&self, rule: R) {
        let shared: StatsRewriteRuleRef = Arc::new(rule);
        self.register_ref(shared);
    }

    /// Register a shared stats rewrite rule.
    ///
    /// The rule is appended after any rules already registered for the same scalar function.
    /// The stored set is an immutable slice, so it is rebuilt with the new rule and swapped in
    /// under the write lock.
    // NOTE(review): `dead_code` is presumably allowed because only tests register rules in this
    // change — confirm.
    #[allow(dead_code)]
    pub(crate) fn register_ref(&self, rule: StatsRewriteRuleRef) {
        let mut registry = self.rules.write();
        let key = rule.scalar_fn_id();

        // Start from a copy of the existing set, or from scratch for a new key.
        let mut rule_set = match registry.get(&key) {
            Some(existing) => existing.to_vec(),
            None => Vec::new(),
        };
        rule_set.push(rule);

        registry.insert(key, rule_set.into());
    }

    /// Return the rewrite rules registered for `scalar_fn_id`, or `None` if there are none.
    ///
    /// The returned set is a clone of the shared handle, not of the rules themselves.
    pub(crate) fn rules_for(&self, scalar_fn_id: ScalarFnId) -> Option<StatsRewriteRuleSet> {
        self.rules.read().get(&scalar_fn_id).cloned()
    }
}

// Allows `StatsRewriteSession` to be stored as a session variable. Both accessors return `self`
// type-erased as `dyn Any`.
impl SessionVar for StatsRewriteSession {
/// Returns `self` as a shared `dyn Any` reference.
fn as_any(&self) -> &dyn Any {
self
}

/// Returns `self` as a mutable `dyn Any` reference.
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
}

/// Extension trait for accessing stats rewrite session data.
///
/// Implemented for every [`SessionExt`] type by the blanket impl that follows the trait, so the
/// accessor is available wherever this trait is in scope.
pub(crate) trait StatsRewriteSessionExt: SessionExt {
/// Returns the stats rewrite rule registry.
///
/// Fetches the [`StatsRewriteSession`] variable through `SessionExt::get`.
// NOTE(review): what happens when the session was built without `StatsRewriteSession` depends
// on `SessionExt::get`, which is not visible here — confirm.
fn stats_rewrites(&self) -> Ref<'_, StatsRewriteSession> {
self.get::<StatsRewriteSession>()
}
}
// Blanket impl: every `SessionExt` type gets `stats_rewrites` via the default method above.
impl<S: SessionExt> StatsRewriteSessionExt for S {}
2 changes: 2 additions & 0 deletions vortex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use vortex_array::optimizer::kernels::ArrayKernels;
pub use vortex_array::scalar_fn;
use vortex_array::scalar_fn::session::ScalarFnSession;
use vortex_array::session::ArraySession;
use vortex_array::stats::session::StatsRewriteSession;
use vortex_io::session::RuntimeSession;
use vortex_layout::session::LayoutSession;
use vortex_session::VortexSession;
Expand Down Expand Up @@ -167,6 +168,7 @@ impl VortexSessionDefault for VortexSession {
.with::<ArraySession>()
.with::<LayoutSession>()
.with::<ScalarFnSession>()
.with::<StatsRewriteSession>()
.with::<ArrayKernels>()
.with::<AggregateFnSession>()
.with::<RuntimeSession>();
Expand Down
Loading