Skip to content

Commit 78a34a0

Browse files
committed
Add ability to override function behaviour via registry in VortexSession
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 223d1df commit 78a34a0

6 files changed

Lines changed: 220 additions & 10 deletions

File tree

vortex-array/public-api.lock

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13388,18 +13388,50 @@ pub fn vortex_array::optimizer::rules::ParentReduceRuleAdapter<V, K>::matches(&s
1338813388

1338913389
pub fn vortex_array::optimizer::rules::ParentReduceRuleAdapter<V, K>::reduce_parent(&self, child: vortex_array::ArrayView<'_, V>, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>
1339013390

13391+
pub mod vortex_array::optimizer::session
13392+
13393+
pub struct vortex_array::optimizer::session::OptimizerSession
13394+
13395+
impl vortex_array::optimizer::session::OptimizerSession
13396+
13397+
pub fn vortex_array::optimizer::session::OptimizerSession::empty() -> Self
13398+
13399+
pub fn vortex_array::optimizer::session::OptimizerSession::registry(&self) -> &vortex_session::registry::FnRegistry
13400+
13401+
impl core::default::Default for vortex_array::optimizer::session::OptimizerSession
13402+
13403+
pub fn vortex_array::optimizer::session::OptimizerSession::default() -> vortex_array::optimizer::session::OptimizerSession
13404+
13405+
impl core::fmt::Debug for vortex_array::optimizer::session::OptimizerSession
13406+
13407+
pub fn vortex_array::optimizer::session::OptimizerSession::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
13408+
13409+
pub trait vortex_array::optimizer::session::OptimizerSessionExt: vortex_session::SessionExt
13410+
13411+
pub fn vortex_array::optimizer::session::OptimizerSessionExt::optimizer(&self) -> vortex_session::Ref<'_, vortex_array::optimizer::session::OptimizerSession>
13412+
13413+
impl<S: vortex_session::SessionExt> vortex_array::optimizer::session::OptimizerSessionExt for S
13414+
13415+
pub fn S::optimizer(&self) -> vortex_session::Ref<'_, vortex_array::optimizer::session::OptimizerSession>
13416+
1339113417
pub trait vortex_array::optimizer::ArrayOptimizer
1339213418

1339313419
pub fn vortex_array::optimizer::ArrayOptimizer::optimize(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1339413420

13421+
pub fn vortex_array::optimizer::ArrayOptimizer::optimize_ctx(&self, ctx: &vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::ArrayRef>
13422+
1339513423
pub fn vortex_array::optimizer::ArrayOptimizer::optimize_recursive(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1339613424

1339713425
impl vortex_array::optimizer::ArrayOptimizer for vortex_array::ArrayRef
1339813426

1339913427
pub fn vortex_array::ArrayRef::optimize(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1340013428

13429+
pub fn vortex_array::ArrayRef::optimize_ctx(&self, ctx: &vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::ArrayRef>
13430+
1340113431
pub fn vortex_array::ArrayRef::optimize_recursive(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1340213432

13433+
pub type vortex_array::optimizer::ReduceParentFn = fn(child: &vortex_array::ArrayRef, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>
13434+
1340313435
pub mod vortex_array::patches
1340413436

1340513437
pub struct vortex_array::patches::Patches
@@ -22356,6 +22388,8 @@ impl vortex_array::optimizer::ArrayOptimizer for vortex_array::ArrayRef
2235622388

2235722389
pub fn vortex_array::ArrayRef::optimize(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
2235822390

22391+
pub fn vortex_array::ArrayRef::optimize_ctx(&self, ctx: &vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::ArrayRef>
22392+
2235922393
pub fn vortex_array::ArrayRef::optimize_recursive(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
2236022394

2236122395
impl vortex_array::scalar_fn::ReduceNode for vortex_array::ArrayRef

vortex-array/src/executor.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ impl ArrayRef {
107107
/// For safety, we will error when the number of execution iterations reaches a configurable
108108
/// maximum (default 128, override with `VORTEX_MAX_ITERATIONS`).
109109
pub fn execute_until<M: Matcher>(self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
110-
let mut current = self.optimize()?;
110+
let mut current = self.optimize_ctx(ctx)?;
111111
let mut stack: Vec<StackFrame> = Vec::new();
112112

113113
for _ in 0..max_iterations() {
@@ -122,7 +122,7 @@ impl ArrayRef {
122122
return Ok(current);
123123
}
124124
Some(frame) => {
125-
current = frame.put_back(current)?.optimize()?;
125+
current = frame.put_back(current)?.optimize_ctx(ctx)?;
126126
continue;
127127
}
128128
}
@@ -138,9 +138,9 @@ impl ArrayRef {
138138
"execute_parent rewrote {} -> {}",
139139
current, rewritten
140140
));
141-
current = rewritten.optimize()?;
141+
current = rewritten.optimize_ctx(ctx)?;
142142
if let Some(frame) = stack.pop() {
143-
current = frame.put_back(current)?.optimize()?;
143+
current = frame.put_back(current)?.optimize_ctx(ctx)?;
144144
}
145145
continue;
146146
}
@@ -159,7 +159,7 @@ impl ArrayRef {
159159
));
160160
let frame = StackFrame::new(parent, i, done, &child);
161161
stack.push(frame);
162-
current = child.optimize()?;
162+
current = child.optimize_ctx(ctx)?;
163163
}
164164
ExecutionStep::Done => {
165165
ctx.log(format_args!("Done: {}", array));

vortex-array/src/optimizer/mod.rs

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,89 @@
66
//!
77
//! Optimization runs between execution steps, which is what enables cross-step optimizations:
88
//! after a child is decoded, new `reduce_parent` rules may match that were previously blocked.
9+
//!
10+
//! There are two entry points:
11+
//!
12+
//! * [`ArrayOptimizer::optimize`] — runs the static rules only (the child encoding's
13+
//! `PARENT_RULES`). It does not require an execution context and is used by helpers like
14+
//! `ArrayBuiltins::cast` and `ArrayRef::slice` that build wrapped expressions and need them
15+
//! normalized inline.
16+
//! * [`ArrayOptimizer::optimize_ctx`] — runs the static rules and additionally consults the
17+
//! session's [`OptimizerSession`] registry keyed by `(parent_encoding_id, child_encoding_id)`
18+
//! before each `reduce_parent` step. The execute loop calls this entry point so plugin-
19+
//! registered parent-reduce rules fire during execution.
20+
21+
use std::sync::Arc;
922

1023
use vortex_error::VortexResult;
1124
use vortex_error::vortex_bail;
25+
use vortex_session::SessionExt;
1226

1327
use crate::ArrayRef;
28+
use crate::ExecutionCtx;
29+
use crate::optimizer::session::OptimizerSession;
1430

1531
pub mod rules;
32+
pub mod session;
33+
34+
/// Pluggable parent-reduce function signature used by [`OptimizerSession`].
35+
///
36+
/// A function of this type rewrites the parent array that holds `child` at `child_idx`, given
37+
/// the child itself and its parent. Returns `Ok(None)` when the function doesn't apply.
38+
///
39+
/// Registered under `(parent_encoding_id, child_encoding_id)`; callers downcast the erased
40+
/// `child`/`parent` to their expected types before applying logic.
41+
pub type ReduceParentFn =
42+
fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
1643

1744
/// Extension trait for optimizing array trees using reduce/reduce_parent rules.
1845
pub trait ArrayOptimizer {
19-
/// Optimize the root array node only by running reduce and reduce_parent rules to fixpoint.
46+
/// Optimize the root array node by running reduce and reduce_parent rules to fixpoint.
47+
///
48+
/// Uses only the child encoding's static `PARENT_RULES`. Use [`Self::optimize_ctx`] from
49+
/// inside the execute loop to also consult the session-scoped [`OptimizerSession`] registry.
2050
fn optimize(&self) -> VortexResult<ArrayRef>;
2151

22-
/// Optimize the entire array tree recursively (root and all descendants).
52+
/// Like [`Self::optimize`], but additionally consults the session's [`OptimizerSession`]
53+
/// registry for each `(parent_encoding_id, child_encoding_id)` pair before the static
54+
/// vtable rules.
55+
fn optimize_ctx(&self, ctx: &ExecutionCtx) -> VortexResult<ArrayRef>;
56+
57+
/// Optimize the entire array tree recursively (root and all descendants), static rules only.
2358
fn optimize_recursive(&self) -> VortexResult<ArrayRef>;
2459
}
2560

2661
impl ArrayOptimizer for ArrayRef {
2762
fn optimize(&self) -> VortexResult<ArrayRef> {
28-
Ok(try_optimize(self)?.unwrap_or_else(|| self.clone()))
63+
Ok(try_optimize(self, None)?.unwrap_or_else(|| self.clone()))
64+
}
65+
66+
fn optimize_ctx(&self, ctx: &ExecutionCtx) -> VortexResult<ArrayRef> {
67+
Ok(try_optimize(self, Some(ctx))?.unwrap_or_else(|| self.clone()))
2968
}
3069

3170
fn optimize_recursive(&self) -> VortexResult<ArrayRef> {
3271
Ok(try_optimize_recursive(self)?.unwrap_or_else(|| self.clone()))
3372
}
3473
}
3574

36-
fn try_optimize(array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
75+
/// Resolve a pluggable [`ReduceParentFn`] for `(parent, child)` from the session registry.
76+
///
77+
/// Returns `None` when no [`OptimizerSession`] is registered, or no function is registered under
78+
/// `(parent.encoding_id(), child.encoding_id())`. The returned `Arc` is owned so the caller is
79+
/// free to drop the session borrow before invoking it.
80+
fn plugin_reduce_parent(
81+
ctx: &ExecutionCtx,
82+
parent: &ArrayRef,
83+
child: &ArrayRef,
84+
) -> Option<Arc<ReduceParentFn>> {
85+
ctx.session().get_opt::<OptimizerSession>().and_then(|s| {
86+
s.registry()
87+
.find::<ReduceParentFn>(parent.encoding_id(), child.encoding_id())
88+
})
89+
}
90+
91+
fn try_optimize(array: &ArrayRef, ctx: Option<&ExecutionCtx>) -> VortexResult<Option<ArrayRef>> {
3792
let mut current_array = array.clone();
3893
let mut any_optimizations = false;
3994

@@ -55,6 +110,17 @@ fn try_optimize(array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
55110
// Its important to take all slots here, as `current_array` can change inside the loop.
56111
for (slot_idx, slot) in current_array.slots().iter().enumerate() {
57112
let Some(child) = slot else { continue };
113+
114+
// Registry-based override: tried before the child encoding's static PARENT_RULES.
115+
if let Some(ctx) = ctx
116+
&& let Some(plugin) = plugin_reduce_parent(ctx, &current_array, child)
117+
&& let Some(new_array) = plugin(child, &current_array, slot_idx)?
118+
{
119+
current_array = new_array;
120+
any_optimizations = true;
121+
continue 'outer;
122+
}
123+
58124
if let Some(new_array) = child.reduce_parent(&current_array, slot_idx)? {
59125
// If the parent was replaced, then we attempt to reduce it again.
60126
current_array = new_array;
@@ -80,7 +146,7 @@ fn try_optimize_recursive(array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
80146
let mut current_array = array.clone();
81147
let mut any_optimizations = false;
82148

83-
if let Some(new_array) = try_optimize(&current_array)? {
149+
if let Some(new_array) = try_optimize(&current_array, None)? {
84150
current_array = new_array;
85151
any_optimizations = true;
86152
}

vortex-session/public-api.lock

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,30 @@ impl<T> core::default::Default for vortex_session::registry::Context<T>
4040

4141
pub fn vortex_session::registry::Context<T>::default() -> Self
4242

43+
pub struct vortex_session::registry::FnRegistry(_)
44+
45+
impl vortex_session::registry::FnRegistry
46+
47+
pub fn vortex_session::registry::FnRegistry::contains(&self, outer: vortex_session::registry::Id, inner: vortex_session::registry::Id) -> bool
48+
49+
pub fn vortex_session::registry::FnRegistry::empty() -> Self
50+
51+
pub fn vortex_session::registry::FnRegistry::find<F: core::any::Any + core::marker::Send + core::marker::Sync>(&self, outer: vortex_session::registry::Id, inner: vortex_session::registry::Id) -> core::option::Option<alloc::sync::Arc<F>>
52+
53+
pub fn vortex_session::registry::FnRegistry::register<F: core::any::Any + core::marker::Send + core::marker::Sync>(&self, outer: vortex_session::registry::Id, inner: vortex_session::registry::Id, f: F)
54+
55+
impl core::clone::Clone for vortex_session::registry::FnRegistry
56+
57+
pub fn vortex_session::registry::FnRegistry::clone(&self) -> vortex_session::registry::FnRegistry
58+
59+
impl core::default::Default for vortex_session::registry::FnRegistry
60+
61+
pub fn vortex_session::registry::FnRegistry::default() -> vortex_session::registry::FnRegistry
62+
63+
impl core::fmt::Debug for vortex_session::registry::FnRegistry
64+
65+
pub fn vortex_session::registry::FnRegistry::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
66+
4367
pub struct vortex_session::registry::Id(_)
4468

4569
impl vortex_session::registry::Id

vortex-session/src/registry.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
//! Many session types use a registry of objects that can be looked up by name to construct
55
//! contexts. This module provides a generic registry type for that purpose.
66
7+
use std::any::Any;
78
use std::cmp::Ordering;
89
use std::fmt;
910
use std::fmt::Debug;
@@ -287,3 +288,86 @@ impl<T: Clone> Context<T> {
287288
self.ids.read().clone()
288289
}
289290
}
291+
292+
/// A registry of type-erased function values keyed by a pair of [`Id`] values.
293+
///
294+
/// Each entry stores an `Arc<dyn Any + Send + Sync>` wrapping a caller-supplied concrete type `F`
295+
/// (typically a function pointer). Callers recover the original type by passing the same `F` to
296+
/// [`FnRegistry::find`]; `find` returns `Some(Arc<F>)` on a type match and `None` otherwise.
297+
///
298+
/// Used for pluggable dispatch keyed by an `(outer, inner)` identifier pair — for example the
299+
/// optimizer's parent-reduce registry keys by `(parent_encoding_id, child_encoding_id)` so that
300+
/// downstream crates can override the rule that would normally run from the child encoding's
301+
/// static `PARENT_RULES` set.
302+
#[derive(Clone, Debug, Default)]
303+
pub struct FnRegistry(Arc<DashMap<(Id, Id), Arc<dyn Any + Send + Sync>>>);
304+
305+
impl FnRegistry {
306+
/// Create a new, empty registry.
307+
pub fn empty() -> Self {
308+
Self::default()
309+
}
310+
311+
/// Register a function under `(outer, inner)`, replacing any existing entry.
312+
pub fn register<F: Any + Send + Sync>(&self, outer: Id, inner: Id, f: F) {
313+
self.0.insert((outer, inner), Arc::new(f));
314+
}
315+
316+
/// Look up a function registered under `(outer, inner)`, downcasting to `F`.
317+
///
318+
/// Returns `None` if no function is registered, or if the registered value is not of type `F`.
319+
pub fn find<F: Any + Send + Sync>(&self, outer: Id, inner: Id) -> Option<Arc<F>> {
320+
let entry = self.0.get(&(outer, inner))?;
321+
Arc::clone(entry.value()).downcast::<F>().ok()
322+
}
323+
324+
/// Return `true` if any function is registered under `(outer, inner)`.
325+
pub fn contains(&self, outer: Id, inner: Id) -> bool {
326+
self.0.contains_key(&(outer, inner))
327+
}
328+
}
329+
330+
#[cfg(test)]
331+
mod fn_registry_tests {
332+
use super::FnRegistry;
333+
use super::Id;
334+
335+
type DoubleFn = fn(i64) -> i64;
336+
337+
fn double(x: i64) -> i64 {
338+
x * 2
339+
}
340+
341+
#[test]
342+
fn register_and_find() {
343+
let registry = FnRegistry::default();
344+
let outer = Id::new("test.double");
345+
let inner = Id::new("test.int");
346+
347+
assert!(!registry.contains(outer, inner));
348+
registry.register::<DoubleFn>(outer, inner, double);
349+
350+
assert!(registry.contains(outer, inner));
351+
let f = registry.find::<DoubleFn>(outer, inner).unwrap();
352+
assert_eq!(f(21), 42);
353+
}
354+
355+
#[test]
356+
fn find_with_wrong_type_returns_none() {
357+
let registry = FnRegistry::default();
358+
let outer = Id::new("test.double");
359+
let inner = Id::new("test.int");
360+
registry.register::<DoubleFn>(outer, inner, double);
361+
362+
type OtherFn = fn(i32) -> i32;
363+
assert!(registry.find::<OtherFn>(outer, inner).is_none());
364+
}
365+
366+
#[test]
367+
fn missing_entry_returns_none() {
368+
let registry = FnRegistry::default();
369+
let outer = Id::new("test.missing");
370+
let inner = Id::new("test.int");
371+
assert!(registry.find::<DoubleFn>(outer, inner).is_none());
372+
}
373+
}

vortex/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use vortex_array::dtype::session::DTypeSession;
1212
// vortex::expr is in the process of having its dependencies inverted, and will eventually be
1313
// pulled back out into a vortex_expr crate.
1414
pub use vortex_array::expr;
15+
use vortex_array::optimizer::session::OptimizerSession;
1516
pub use vortex_array::scalar_fn;
1617
use vortex_array::scalar_fn::session::ScalarFnSession;
1718
use vortex_array::session::ArraySession;
@@ -165,6 +166,7 @@ impl VortexSessionDefault for VortexSession {
165166
.with::<ArraySession>()
166167
.with::<LayoutSession>()
167168
.with::<ScalarFnSession>()
169+
.with::<OptimizerSession>()
168170
.with::<AggregateFnSession>()
169171
.with::<RuntimeSession>();
170172

0 commit comments

Comments
 (0)