Skip to content

Commit f418844

Browse files
authored
Add ability to override function behaviour via registry in VortexSession (#7588)
This logic isn't used yet but will be used to allow us to customise behaviour of functions depending on an integration point, i.e. Datafusion can have it's casting logic that is different from arrow casting logic while everyone using vortex can still continue calling `cast` and not specialize for the engine they're using Thing to consider is whether we want require passing session to optimise or whether we should remove the implicit optimise calls and defer them to execute loop The next pr will replace struct casting logic with Arrow and DF specific behaviour. --------- Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent a001ba8 commit f418844

12 files changed

Lines changed: 260 additions & 29 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ version = "0.1.0"
8989
aho-corasick = "1.1.3"
9090
anyhow = "1.0.97"
9191
arbitrary = "1.3.2"
92-
arc-swap = "1.8"
92+
arc-swap = "1.9"
9393
arcref = "0.2.0"
9494
arrow-arith = "58"
9595
arrow-array = "58"

vortex-array/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ workspace = true
2121

2222
[dependencies]
2323
arbitrary = { workspace = true, optional = true }
24+
arc-swap = { workspace = true }
2425
arcref = { workspace = true }
2526
arrow-arith = { workspace = true }
2627
arrow-array = { workspace = true, features = ["ffi"] }

vortex-array/public-api.lock

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13222,6 +13222,36 @@ pub vortex_array::normalize::NormalizeOptions::operation: vortex_array::normaliz
1322213222

1322313223
pub mod vortex_array::optimizer
1322413224

13225+
pub mod vortex_array::optimizer::kernels
13226+
13227+
pub struct vortex_array::optimizer::kernels::ArrayKernels
13228+
13229+
impl vortex_array::optimizer::kernels::ArrayKernels
13230+
13231+
pub fn vortex_array::optimizer::kernels::ArrayKernels::empty() -> Self
13232+
13233+
pub fn vortex_array::optimizer::kernels::ArrayKernels::find_reduce_parent(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id) -> core::option::Option<alloc::sync::Arc<[vortex_array::optimizer::kernels::ReduceParentFn]>>
13234+
13235+
pub fn vortex_array::optimizer::kernels::ArrayKernels::register_reduce_parent<I: core::iter::traits::collect::IntoIterator<Item = vortex_array::optimizer::kernels::ReduceParentFn>>(&self, parent: vortex_session::registry::Id, child: vortex_session::registry::Id, fns: I)
13236+
13237+
impl core::default::Default for vortex_array::optimizer::kernels::ArrayKernels
13238+
13239+
pub fn vortex_array::optimizer::kernels::ArrayKernels::default() -> vortex_array::optimizer::kernels::ArrayKernels
13240+
13241+
impl core::fmt::Debug for vortex_array::optimizer::kernels::ArrayKernels
13242+
13243+
pub fn vortex_array::optimizer::kernels::ArrayKernels::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
13244+
13245+
pub trait vortex_array::optimizer::kernels::ArrayKernelsExt: vortex_session::SessionExt
13246+
13247+
pub fn vortex_array::optimizer::kernels::ArrayKernelsExt::kernels(&self) -> vortex_session::Ref<'_, vortex_array::optimizer::kernels::ArrayKernels>
13248+
13249+
impl<S: vortex_session::SessionExt> vortex_array::optimizer::kernels::ArrayKernelsExt for S
13250+
13251+
pub fn S::kernels(&self) -> vortex_session::Ref<'_, vortex_array::optimizer::kernels::ArrayKernels>
13252+
13253+
pub type vortex_array::optimizer::kernels::ReduceParentFn = fn(child: &vortex_array::ArrayRef, parent: &vortex_array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>
13254+
1322513255
pub mod vortex_array::optimizer::rules
1322613256

1322713257
pub struct vortex_array::optimizer::rules::ParentReduceRuleAdapter<V, R>
@@ -13364,13 +13394,17 @@ pub trait vortex_array::optimizer::ArrayOptimizer
1336413394

1336513395
pub fn vortex_array::optimizer::ArrayOptimizer::optimize(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1336613396

13367-
pub fn vortex_array::optimizer::ArrayOptimizer::optimize_recursive(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
13397+
pub fn vortex_array::optimizer::ArrayOptimizer::optimize_ctx(&self, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::ArrayRef>
13398+
13399+
pub fn vortex_array::optimizer::ArrayOptimizer::optimize_recursive(&self, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1336813400

1336913401
impl vortex_array::optimizer::ArrayOptimizer for vortex_array::ArrayRef
1337013402

1337113403
pub fn vortex_array::ArrayRef::optimize(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1337213404

13373-
pub fn vortex_array::ArrayRef::optimize_recursive(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
13405+
pub fn vortex_array::ArrayRef::optimize_ctx(&self, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::ArrayRef>
13406+
13407+
pub fn vortex_array::ArrayRef::optimize_recursive(&self, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1337413408

1337513409
pub mod vortex_array::patches
1337613410

@@ -22328,7 +22362,9 @@ impl vortex_array::optimizer::ArrayOptimizer for vortex_array::ArrayRef
2232822362

2232922363
pub fn vortex_array::ArrayRef::optimize(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
2233022364

22331-
pub fn vortex_array::ArrayRef::optimize_recursive(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
22365+
pub fn vortex_array::ArrayRef::optimize_ctx(&self, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::ArrayRef>
22366+
22367+
pub fn vortex_array::ArrayRef::optimize_recursive(&self, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_array::ArrayRef>
2233222368

2233322369
impl vortex_array::scalar_fn::ReduceNode for vortex_array::ArrayRef
2233422370

vortex-array/src/arrays/extension/compute/rules.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,12 @@ impl ArrayParentReduceRule<Extension> for ExtensionFilterPushDownRule {
7878

7979
#[cfg(test)]
8080
mod tests {
81+
use std::sync::LazyLock;
82+
8183
use vortex_buffer::buffer;
8284
use vortex_error::VortexResult;
8385
use vortex_mask::Mask;
86+
use vortex_session::VortexSession;
8487

8588
use crate::IntoArray;
8689
#[expect(deprecated)]
@@ -108,6 +111,10 @@ mod tests {
108111
use crate::scalar::ScalarValue;
109112
use crate::scalar_fn::fns::binary::Binary;
110113
use crate::scalar_fn::fns::operators::Operator;
114+
use crate::session::ArraySession;
115+
116+
static SESSION: LazyLock<VortexSession> =
117+
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
111118

112119
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
113120
struct TestExt;
@@ -220,7 +227,7 @@ mod tests {
220227
.try_new_array(3, Operator::Lt, [constant_ext, ext_array])
221228
.unwrap();
222229

223-
let optimized = scalar_fn_array.optimize_recursive().unwrap();
230+
let optimized = scalar_fn_array.optimize_recursive(&SESSION).unwrap();
224231
let scalar_fn = optimized.as_opt::<ScalarFn>().unwrap();
225232
let children = scalar_fn.children();
226233
let constant = children[0]

vortex-array/src/executor.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
//! 3. **`execute_parent`** -- child-driven fused execution (may read buffers).
1111
//! 4. **`execute`** -- the encoding's own decode step (most expensive).
1212
//!
13-
//! The main entry point is [`DynArray::execute_until`], which uses an explicit work stack
13+
//! The main entry point is [`ArrayRef::execute_until`], which uses an explicit work stack
1414
//! to drive execution iteratively without recursion. Between steps, the optimizer runs
15-
//! reduce/reduce_parent rules to fixpoint.
15+
//! reduce/reduce_parent rules to fixpoint using the active [`ExecutionCtx`] session, so
16+
//! session-registered optimizer kernels participate during execution.
1617
//!
1718
//! See <https://docs.vortex.dev/developer-guide/internals/execution> for a full description
1819
//! of the model.
@@ -88,17 +89,21 @@ impl ArrayRef {
8889
///
8990
/// Each iteration proceeds through three steps in order:
9091
///
91-
/// 1. **Done / canonical check** if `current` satisfies the active done predicate or is
92+
/// 1. **Done / canonical check** - if `current` satisfies the active done predicate or is
9293
/// canonical, splice it back into the stacked parent (if any) and continue, or return.
93-
/// 2. **`execute_parent` on children** try each child's `execute_parent` against `current`
94+
/// 2. **`execute_parent` on children** - try each child's `execute_parent` against `current`
9495
/// as the parent (e.g. `Filter(RunEnd)` → `FilterExecuteAdaptor` fires from RunEnd).
9596
/// If there is a stacked parent frame, the rewritten child is spliced back into it so
9697
/// that optimize and further `execute_parent` can fire on the reconstructed parent
9798
/// (e.g. `Slice(RunEnd)` → `RunEnd` spliced into stacked `Filter` → `Filter(RunEnd)`
9899
/// whose `FilterExecuteAdaptor` fires on the next iteration).
99-
/// 3. **`execute`** call the encoding's own execute step, which either returns `Done` or
100+
/// 3. **`execute`** - call the encoding's own execute step, which either returns `Done` or
100101
/// `ExecuteSlot(i)` to push a child onto the stack for focused execution.
101102
///
103+
/// Optimizer calls in this loop use [`ExecutionCtx::session`], so kernels registered on the
104+
/// session's [`ArrayKernels`](crate::optimizer::kernels::ArrayKernels) are visible between
105+
/// execution steps.
106+
///
102107
/// Note: the returned array may not match `M`. If execution converges to a canonical form
103108
/// that does not match `M`, the canonical array is returned since no further execution
104109
/// progress is possible.
@@ -110,7 +115,7 @@ impl ArrayRef {
110115
let mut stack: Vec<StackFrame> = Vec::new();
111116

112117
for _ in 0..max_iterations() {
113-
// Step 1: done / canonical splice back into stacked parent or return.
118+
// Step 1: done / canonical - splice back into stacked parent or return.
114119
let is_done = stack
115120
.last()
116121
.map_or(M::matches as DonePredicate, |frame| frame.done);
@@ -121,7 +126,7 @@ impl ArrayRef {
121126
return Ok(current);
122127
}
123128
Some(frame) => {
124-
current = frame.put_back(current)?.optimize()?;
129+
current = frame.put_back(current)?.optimize_ctx(ctx.session())?;
125130
continue;
126131
}
127132
}
@@ -137,9 +142,9 @@ impl ArrayRef {
137142
"execute_parent rewrote {} -> {}",
138143
current, rewritten
139144
));
140-
current = rewritten.optimize()?;
145+
current = rewritten.optimize_ctx(ctx.session())?;
141146
if let Some(frame) = stack.pop() {
142-
current = frame.put_back(current)?.optimize()?;
147+
current = frame.put_back(current)?.optimize_ctx(ctx.session())?;
143148
}
144149
continue;
145150
}
@@ -158,7 +163,7 @@ impl ArrayRef {
158163
));
159164
let frame = StackFrame::new(parent, i, done, &child);
160165
stack.push(frame);
161-
current = child.optimize()?;
166+
current = child.optimize_ctx(ctx.session())?;
162167
}
163168
ExecutionStep::Done => {
164169
ctx.log(format_args!("Done: {}", array));
@@ -523,7 +528,7 @@ macro_rules! require_child {
523528
/// execution of child `$idx`.
524529
///
525530
/// Unlike `require_child!`, this is a statement macro (no value produced) and does not clone
526-
/// `$parent` it is moved into the early-return path.
531+
/// `$parent` - it is moved into the early-return path.
527532
///
528533
/// ```ignore
529534
/// require_opt_child!(array, array.patches().map(|p| p.indices()), 1 => Primitive);
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Session-scoped registry for optimizer kernels.
5+
//!
6+
//! [`ArrayKernels`] stores function pointers that participate in array optimization without
7+
//! adding rules to an encoding vtable. The optimizer currently consults it for parent-reduce
8+
//! rewrites before the child encoding's static `PARENT_RULES`. A registered function can
9+
//! therefore add a rule for an extension encoding or take precedence over a built-in rule.
10+
//!
11+
//! Kernel entries are addressed by `(outer_id, child_id, kind)`. For parent-reduce kernels,
12+
//! `outer_id` is the id returned by the parent array's `encoding_id()` and `child_id` is the
13+
//! child array's `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the parent
14+
//! id is the scalar function id.
15+
//!
16+
//! Sessions created by the top-level `vortex` crate install an empty registry by default. Other
17+
//! sessions can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely
18+
//! on [`ArrayKernelsExt::kernels`] to insert the default value.
19+
20+
use std::hash::BuildHasher;
21+
use std::sync::Arc;
22+
use std::sync::LazyLock;
23+
24+
use arc_swap::ArcSwap;
25+
use vortex_error::VortexResult;
26+
use vortex_session::Ref;
27+
use vortex_session::SessionExt;
28+
use vortex_session::registry::Id;
29+
use vortex_utils::aliases::DefaultHashBuilder;
30+
use vortex_utils::aliases::hash_map::HashMap;
31+
32+
use crate::ArrayRef;
33+
34+
/// Shared hasher used to combine `(outer, child, FnKind)` tuples into [`FnRegistry`] keys.
35+
static FN_HASHER: LazyLock<DefaultHashBuilder> = LazyLock::new(DefaultHashBuilder::default);
36+
37+
/// Function pointer for a plugin-provided parent-reduce rewrite.
38+
///
39+
/// The optimizer calls this with the matched `child`, its `parent`, and the slot index where the
40+
/// child appears. Return `Ok(Some(new_parent))` to replace the parent, or `Ok(None)` when the
41+
/// rewrite does not apply.
42+
///
43+
/// Implementations must preserve the parent's logical length and dtype, matching the invariant
44+
/// required of static parent-reduce rules.
45+
pub type ReduceParentFn =
46+
fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
47+
48+
/// Session-scoped registry of optimizer kernel functions.
49+
#[derive(Debug, Default)]
50+
pub struct ArrayKernels {
51+
reduce_parent: ArcSwap<HashMap<u64, Arc<[ReduceParentFn]>>>,
52+
}
53+
54+
impl ArrayKernels {
55+
/// Create an empty [`ArrayKernels`] with no kernels registered.
56+
pub fn empty() -> Self {
57+
Self::default()
58+
}
59+
60+
/// Register a [`ReduceParentFn`] for `(outer, child)`.
61+
///
62+
/// The optimizer will invoke `f` when it sees a parent with encoding id `outer` holding a
63+
/// child with encoding id `child` during a `reduce_parent` step, before trying the child
64+
/// encoding's static `PARENT_RULES`. `outer` is usually the parent array's encoding id. For
65+
/// `ScalarFnArray`, it is the scalar function id, for example `Cast.id()`.
66+
///
67+
/// Replaces any function already registered for the same pair.
68+
pub fn register_reduce_parent<I: IntoIterator<Item = ReduceParentFn>>(
69+
&self,
70+
parent: Id,
71+
child: Id,
72+
fns: I,
73+
) {
74+
let registry = self.reduce_parent.load();
75+
let id = self.hash_fn_ids(parent, child);
76+
let mut owned_registry = registry.as_ref().clone();
77+
if let Some(existing) = owned_registry.remove(&id) {
78+
owned_registry.insert(id, existing.as_ref().iter().cloned().chain(fns).collect());
79+
} else {
80+
owned_registry.insert(id, fns.into_iter().collect());
81+
}
82+
self.reduce_parent.store(Arc::new(owned_registry));
83+
}
84+
85+
/// Look up the [`ReduceParentFn`] registered for `(outer, child)`.
86+
///
87+
/// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
88+
/// function.
89+
pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option<Arc<[ReduceParentFn]>> {
90+
let id = self.hash_fn_ids(parent, child);
91+
let map = self.reduce_parent.load();
92+
let entry = map.get(&id)?;
93+
Some(Arc::clone(entry))
94+
}
95+
96+
/// Combine a typed kernel id tuple into the `u64` key expected by the underlying
97+
/// [`FnRegistry`]. All typed helpers use this path so registration and lookup agree.
98+
fn hash_fn_ids(&self, parent: Id, child: Id) -> u64 {
99+
FN_HASHER.hash_one((parent, child))
100+
}
101+
}
102+
103+
/// Extension trait for accessing optimizer kernels from a
104+
/// [`VortexSession`](vortex_session::VortexSession).
105+
pub trait ArrayKernelsExt: SessionExt {
106+
/// Returns the [`ArrayKernels`] session variable, inserting a default-constructed one if
107+
/// none has been registered on the session yet.
108+
fn kernels(&self) -> Ref<'_, ArrayKernels> {
109+
self.get::<ArrayKernels>()
110+
}
111+
}
112+
113+
impl<S: SessionExt> ArrayKernelsExt for S {}

0 commit comments

Comments
 (0)