Skip to content

Commit 3186c7c

Browse files
committed
fixes
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 357c352 commit 3186c7c

9 files changed

Lines changed: 136 additions & 87 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ version = "0.1.0"
8989
aho-corasick = "1.1.3"
9090
anyhow = "1.0.97"
9191
arbitrary = "1.3.2"
92-
arc-swap = "1.8"
92+
arc-swap = "1.9"
9393
arcref = "0.2.0"
9494
arrow-arith = "58"
9595
arrow-array = "58"

vortex-array/src/arrays/extension/compute/rules.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ mod tests {
8888
use crate::IntoArray;
8989
#[expect(deprecated)]
9090
use crate::ToCanonical as _;
91-
use crate::VortexSessionExecute;
9291
use crate::arrays::Constant;
9392
use crate::arrays::ConstantArray;
9493
use crate::arrays::Extension;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Session state for pluggable parent-reduce rules.
5+
//!
6+
//! [`ArrayKernels`] wraps an [`FnRegistry`] keyed by `(parent_encoding_id, child_encoding_id)`
7+
//! and is consulted by the optimizer during execution, before the child encoding's static
8+
//! `PARENT_RULES` are tried. Entries are typed as [`ReduceParentFn`].
9+
//!
10+
//! The registry is empty by default. Downstream crates register `ReduceParentFn` values to add
11+
//! new parent-reduce rules or override ones that the child encoding would otherwise run from its
12+
//! static `PARENT_RULES`.
13+
14+
use std::hash::BuildHasher;
15+
use std::sync::Arc;
16+
use std::sync::LazyLock;
17+
18+
use vortex_error::VortexResult;
19+
use vortex_session::Ref;
20+
use vortex_session::SessionExt;
21+
use vortex_session::registry::FnRegistry;
22+
use vortex_session::registry::Id;
23+
use vortex_utils::aliases::DefaultHashBuilder;
24+
25+
use crate::ArrayRef;
26+
27+
static FN_HASHER: LazyLock<DefaultHashBuilder> = LazyLock::new(DefaultHashBuilder::default);
28+
29+
/// Pluggable parent-reduce function signature used by [`ArrayKernels`].
30+
///
31+
/// A function of this type rewrites the parent array that holds `child` at `child_idx`, given
32+
/// the child itself and its parent. Returns `Ok(None)` when the function doesn't apply.
33+
///
34+
/// Registered under `(parent_encoding_id, child_encoding_id)`; callers downcast the erased
35+
/// `child`/`parent` to their expected types before applying logic.
36+
pub type ReduceParentFn =
37+
fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
38+
39+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
40+
enum FnKind {
41+
Reduce,
42+
ReduceParent,
43+
ExecuteParent,
44+
Execute,
45+
}
46+
47+
/// Session state for pluggable parent-reduce dispatch keyed by `(parent_id, child_id)`.
48+
#[derive(Debug, Default)]
49+
pub struct ArrayKernels {
50+
registry: FnRegistry,
51+
}
52+
53+
impl ArrayKernels {
54+
/// Create an empty session with no rules registered.
55+
pub fn empty() -> Self {
56+
Self::default()
57+
}
58+
59+
pub fn register_reduce_parent(&self, outer: Id, child: Id, f: ReduceParentFn) {
60+
self.registry
61+
.register(self.hash_fn_ids(outer, child, FnKind::ReduceParent), f)
62+
}
63+
64+
pub fn find_reduce_parent(&self, outer: Id, child: Id) -> Option<Arc<ReduceParentFn>> {
65+
self.registry
66+
.find(self.hash_fn_ids(outer, child, FnKind::ReduceParent))
67+
}
68+
69+
pub fn contains_reduce_parent(&self, outer: Id, child: Id) -> bool {
70+
self.registry
71+
.contains(self.hash_fn_ids(outer, child, FnKind::ReduceParent))
72+
}
73+
74+
fn hash_fn_ids(&self, outer: Id, child: Id, fn_kind: FnKind) -> u64 {
75+
FN_HASHER.hash_one((outer, child, fn_kind))
76+
}
77+
}
78+
79+
/// Extension trait for accessing the [`ArrayKernels`] registry from a Vortex session.
80+
pub trait ArrayKernelsExt: SessionExt {
81+
/// Returns the [`ArrayKernels`] session variable.
82+
fn kernels(&self) -> Ref<'_, ArrayKernels> {
83+
self.get::<ArrayKernels>()
84+
}
85+
}
86+
87+
impl<S: SessionExt> ArrayKernelsExt for S {}

vortex-array/src/optimizer/mod.rs

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
//! `ArrayBuiltins::cast` and `ArrayRef::slice` that build wrapped expressions and need them
1515
//! normalized inline.
1616
//! * [`ArrayOptimizer::optimize_ctx`] — runs the static rules and additionally consults the
17-
//! session's [`OptimizerSession`] registry keyed by `(parent_encoding_id, child_encoding_id)`
17+
//! session's [`ArrayKernels`] registry keyed by `(parent_encoding_id, child_encoding_id)`
1818
//! before each `reduce_parent` step. The execute loop calls this entry point so plugin-
1919
//! registered parent-reduce rules fire during execution.
2020
@@ -26,38 +26,29 @@ use vortex_session::SessionExt;
2626
use vortex_session::VortexSession;
2727

2828
use crate::ArrayRef;
29-
use crate::optimizer::session::OptimizerSession;
29+
use crate::optimizer::kernels::ArrayKernels;
30+
use crate::optimizer::kernels::ReduceParentFn;
3031

32+
pub mod kernels;
3133
pub mod rules;
32-
pub mod session;
33-
34-
/// Pluggable parent-reduce function signature used by [`OptimizerSession`].
35-
///
36-
/// A function of this type rewrites the parent array that holds `child` at `child_idx`, given
37-
/// the child itself and its parent. Returns `Ok(None)` when the function doesn't apply.
38-
///
39-
/// Registered under `(parent_encoding_id, child_encoding_id)`; callers downcast the erased
40-
/// `child`/`parent` to their expected types before applying logic.
41-
pub type ReduceParentFn =
42-
fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
4334

4435
/// Extension trait for optimizing array trees using reduce/reduce_parent rules.
4536
pub trait ArrayOptimizer {
4637
/// Optimize the root array node by running reduce and reduce_parent rules to fixpoint.
4738
///
4839
/// Uses only the child encoding's static `PARENT_RULES`. Use [`Self::optimize_ctx`] from
49-
/// inside the execute loop to also consult the session-scoped [`OptimizerSession`] registry.
40+
/// inside the execute loop to also consult the session-scoped [`ArrayKernels`] registry.
5041
fn optimize(&self) -> VortexResult<ArrayRef>;
5142

52-
/// Like [`Self::optimize`], but additionally consults the [`OptimizerSession`] registered on
43+
/// Like [`Self::optimize`], but additionally consults the [`ArrayKernels`] registered on
5344
/// `session` for each `(parent_encoding_id, child_encoding_id)` pair before the static
54-
/// vtable rules. If `session` does not have an [`OptimizerSession`] registered, falls
45+
/// vtable rules. If `session` does not have an [`ArrayKernels`] registered, falls
5546
/// through to the static rules.
5647
fn optimize_ctx(&self, session: &VortexSession) -> VortexResult<ArrayRef>;
5748

5849
/// Optimize the entire array tree recursively (root and all descendants).
5950
///
60-
/// Consults the [`OptimizerSession`] registered on `session` for each parent/child pair
51+
/// Consults the [`ArrayKernels`] registered on `session` for each parent/child pair
6152
/// encountered during the recursive walk, so plugin-registered rules apply throughout the
6253
/// tree. Requires a [`VortexSession`] unconditionally so the registry is always honored
6354
/// when a recursive optimization is requested.
@@ -80,18 +71,17 @@ impl ArrayOptimizer for ArrayRef {
8071

8172
/// Resolve a pluggable [`ReduceParentFn`] for `(parent, child)` from `session`.
8273
///
83-
/// Returns `None` when no [`OptimizerSession`] is registered, or no function is registered under
74+
/// Returns `None` when no [`ArrayKernels`] is registered, or no function is registered under
8475
/// `(parent.encoding_id(), child.encoding_id())`. The returned `Arc` is owned so the caller can
8576
/// drop the session borrow before invoking it.
8677
fn plugin_reduce_parent(
8778
session: &VortexSession,
8879
parent: &ArrayRef,
8980
child: &ArrayRef,
9081
) -> Option<Arc<ReduceParentFn>> {
91-
session.get_opt::<OptimizerSession>().and_then(|s| {
92-
s.registry()
93-
.find::<ReduceParentFn>(parent.encoding_id(), child.encoding_id())
94-
})
82+
session
83+
.get_opt::<ArrayKernels>()
84+
.and_then(|s| s.find_reduce_parent(parent.encoding_id(), child.encoding_id()))
9585
}
9686

9787
fn try_optimize(

vortex-array/src/optimizer/session.rs

Lines changed: 0 additions & 43 deletions
This file was deleted.

vortex-session/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ all-features = true
2020
workspace = true
2121

2222
[dependencies]
23-
arcref = { workspace = true }
23+
arc-swap = { workspace = true }
2424
dashmap = { workspace = true }
2525
lasso = { workspace = true }
2626
parking_lot = { workspace = true }

vortex-session/src/registry.rs

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,19 @@ use std::fmt;
1010
use std::fmt::Debug;
1111
use std::fmt::Display;
1212
use std::fmt::Formatter;
13+
use std::hash::Hash;
1314
use std::ops::Deref;
1415
use std::sync::Arc;
1516
use std::sync::LazyLock;
1617
use std::sync::OnceLock;
1718

19+
use arc_swap::ArcSwap;
1820
use lasso::Spur;
1921
use lasso::ThreadedRodeo;
2022
use parking_lot::RwLock;
2123
use vortex_error::VortexExpect;
2224
use vortex_utils::aliases::dash_map::DashMap;
25+
use vortex_utils::aliases::hash_map::HashMap;
2326

2427
/// Global string interner for [`Id`] values.
2528
static INTERNER: LazyLock<ThreadedRodeo> = LazyLock::new(ThreadedRodeo::new);
@@ -299,8 +302,8 @@ impl<T: Clone> Context<T> {
299302
/// optimizer's parent-reduce registry keys by `(parent_encoding_id, child_encoding_id)` so that
300303
/// downstream crates can override the rule that would normally run from the child encoding's
301304
/// static `PARENT_RULES` set.
302-
#[derive(Clone, Debug, Default)]
303-
pub struct FnRegistry(Arc<DashMap<(Id, Id), Arc<dyn Any + Send + Sync>>>);
305+
#[derive(Debug, Default)]
306+
pub struct FnRegistry(ArcSwap<HashMap<u64, Arc<dyn Any + Send + Sync>>>);
304307

305308
impl FnRegistry {
306309
/// Create a new, empty registry.
@@ -309,26 +312,35 @@ impl FnRegistry {
309312
}
310313

311314
/// Register a function under `(outer, inner)`, replacing any existing entry.
312-
pub fn register<F: Any + Send + Sync>(&self, outer: Id, inner: Id, f: F) {
313-
self.0.insert((outer, inner), Arc::new(f));
315+
pub fn register<F: Any + Send + Sync>(&self, id: u64, f: F) {
316+
let registry = self.0.load();
317+
let mut owned_registry = registry.as_ref().clone();
318+
owned_registry.insert(id, Arc::new(f));
319+
self.0.store(Arc::new(owned_registry));
314320
}
315321

316322
/// Look up the function registered under the key `id`, downcasting to `F`.
317323
///
318324
/// Returns `None` if no function is registered, or if the registered value is not of type `F`.
319-
pub fn find<F: Any + Send + Sync>(&self, outer: Id, inner: Id) -> Option<Arc<F>> {
320-
let entry = self.0.get(&(outer, inner))?;
321-
Arc::clone(entry.value()).downcast::<F>().ok()
325+
pub fn find<F: Any + Send + Sync>(&self, id: u64) -> Option<Arc<F>> {
326+
let map = self.0.load();
327+
let entry = map.get(&id)?;
328+
Arc::clone(entry).downcast::<F>().ok()
322329
}
323330

324331
/// Return `true` if any function is registered under the key `id`.
325-
pub fn contains(&self, outer: Id, inner: Id) -> bool {
326-
self.0.contains_key(&(outer, inner))
332+
pub fn contains(&self, id: u64) -> bool {
333+
let map = self.0.load();
334+
map.contains_key(&id)
327335
}
328336
}
329337

330338
#[cfg(test)]
331339
mod fn_registry_tests {
340+
use std::hash::BuildHasher;
341+
342+
use vortex_utils::aliases::DefaultHashBuilder;
343+
332344
use super::FnRegistry;
333345
use super::Id;
334346

@@ -343,12 +355,13 @@ mod fn_registry_tests {
343355
let registry = FnRegistry::default();
344356
let outer = Id::new("test.double");
345357
let inner = Id::new("test.int");
358+
let id = DefaultHashBuilder::default().hash_one((outer, inner));
346359

347-
assert!(!registry.contains(outer, inner));
348-
registry.register::<DoubleFn>(outer, inner, double);
360+
assert!(!registry.contains(id));
361+
registry.register::<DoubleFn>(id, double);
349362

350-
assert!(registry.contains(outer, inner));
351-
let f = registry.find::<DoubleFn>(outer, inner).unwrap();
363+
assert!(registry.contains(id));
364+
let f = registry.find::<DoubleFn>(id).unwrap();
352365
assert_eq!(f(21), 42);
353366
}
354367

@@ -357,17 +370,20 @@ mod fn_registry_tests {
357370
let registry = FnRegistry::default();
358371
let outer = Id::new("test.double");
359372
let inner = Id::new("test.int");
360-
registry.register::<DoubleFn>(outer, inner, double);
373+
let id = DefaultHashBuilder::default().hash_one((outer, inner));
374+
375+
registry.register::<DoubleFn>(id, double);
361376

362377
type OtherFn = fn(i32) -> i32;
363-
assert!(registry.find::<OtherFn>(outer, inner).is_none());
378+
assert!(registry.find::<OtherFn>(id).is_none());
364379
}
365380

366381
#[test]
367382
fn missing_entry_returns_none() {
368383
let registry = FnRegistry::default();
369384
let outer = Id::new("test.missing");
370385
let inner = Id::new("test.int");
371-
assert!(registry.find::<DoubleFn>(outer, inner).is_none());
386+
let id = DefaultHashBuilder::default().hash_one((outer, inner));
387+
assert!(registry.find::<DoubleFn>(id).is_none());
372388
}
373389
}

vortex/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use vortex_array::dtype::session::DTypeSession;
1212
// vortex::expr is in the process of having its dependencies inverted, and will eventually be
1313
// pulled back out into a vortex_expr crate.
1414
pub use vortex_array::expr;
15-
use vortex_array::optimizer::session::OptimizerSession;
15+
use vortex_array::optimizer::kernels::ArrayKernels;
1616
pub use vortex_array::scalar_fn;
1717
use vortex_array::scalar_fn::session::ScalarFnSession;
1818
use vortex_array::session::ArraySession;
@@ -166,7 +166,7 @@ impl VortexSessionDefault for VortexSession {
166166
.with::<ArraySession>()
167167
.with::<LayoutSession>()
168168
.with::<ScalarFnSession>()
169-
.with::<OptimizerSession>()
169+
.with::<ArrayKernels>()
170170
.with::<AggregateFnSession>()
171171
.with::<RuntimeSession>();
172172

0 commit comments

Comments
 (0)