Skip to content

Commit 238e467

Browse files
committed
perf[array]: add the SimplifyCache to optimize
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 1f6fb0a commit 238e467

3 files changed

Lines changed: 106 additions & 16 deletions

File tree

vortex-array/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,11 @@ name = "expr_case_when"
143143
path = "benches/expr/case_when_bench.rs"
144144
harness = false
145145

146+
[[bench]]
147+
name = "expr_optimize"
148+
path = "benches/expr/optimize_bench.rs"
149+
harness = false
150+
146151
[[bench]]
147152
name = "chunked_dict_builder"
148153
harness = false

vortex-array/public-api.lock

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12812,8 +12812,6 @@ pub fn vortex_array::expr::Expression::simplify(&self, &vortex_array::dtype::DTy
1281212812

1281312813
pub fn vortex_array::expr::Expression::simplify_untyped(&self) -> vortex_error::VortexResult<vortex_array::expr::Expression>
1281412814

12815-
pub fn vortex_array::expr::Expression::try_optimize(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>
12816-
1281712815
pub fn vortex_array::expr::Expression::try_optimize_recursive(&self, &vortex_array::dtype::DType) -> vortex_error::VortexResult<core::option::Option<vortex_array::expr::Expression>>
1281812816

1281912817
impl core::clone::Clone for vortex_array::expr::Expression

vortex-array/src/expr/optimize.rs

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,22 @@ impl Expression {
2929
/// 2. `simplify` - type-aware simplifications
3030
/// 3. `reduce` - abstract reduction rules via `ReduceNode`/`ReduceCtx`
3131
pub fn optimize(&self, scope: &DType) -> VortexResult<Expression> {
32+
let cache = SimplifyCache {
33+
scope,
34+
dtype_cache: RefCell::new(HashMap::new()),
35+
};
3236
Ok(self
3337
.clone()
34-
.try_optimize(scope)?
38+
.try_optimize(scope, &cache)?
3539
.unwrap_or_else(|| self.clone()))
3640
}
3741

3842
/// Try to optimize the root expression node only, returning None if no optimizations applied.
39-
pub fn try_optimize(&self, scope: &DType) -> VortexResult<Option<Expression>> {
40-
let cache = SimplifyCache {
41-
scope,
42-
dtype_cache: RefCell::new(HashMap::new()),
43-
};
43+
fn try_optimize(
44+
&self,
45+
scope: &DType,
46+
cache: &SimplifyCache<'_>,
47+
) -> VortexResult<Option<Expression>> {
4448
let reduce_ctx = ExpressionReduceCtx {
4549
scope: scope.clone(),
4650
};
@@ -67,7 +71,7 @@ impl Expression {
6771
}
6872

6973
// Try simplify (typed)
70-
if let Some(simplified) = current.scalar_fn().simplify(&current, &cache)? {
74+
if let Some(simplified) = current.scalar_fn().simplify(&current, cache)? {
7175
current = simplified;
7276
changed = true;
7377
any_optimizations = true;
@@ -114,11 +118,28 @@ impl Expression {
114118

115119
/// Try to optimize the entire expression tree recursively.
116120
pub fn try_optimize_recursive(&self, scope: &DType) -> VortexResult<Option<Expression>> {
121+
let cache = SimplifyCache {
122+
scope,
123+
dtype_cache: RefCell::new(HashMap::new()),
124+
};
125+
let result = self.try_optimize_recursive_inner(scope, &cache)?;
126+
127+
// Apply the between optimization once at the top level only.
128+
// TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
129+
// to CNF?
130+
Ok(Some(find_between(result.unwrap_or_else(|| self.clone()))))
131+
}
132+
133+
fn try_optimize_recursive_inner(
134+
&self,
135+
scope: &DType,
136+
cache: &SimplifyCache<'_>,
137+
) -> VortexResult<Option<Expression>> {
117138
let mut current = self.clone();
118139
let mut any_optimizations = false;
119140

120141
// First optimize the root
121-
if let Some(optimized) = current.clone().try_optimize(scope)? {
142+
if let Some(optimized) = current.clone().try_optimize(scope, cache)? {
122143
current = optimized;
123144
any_optimizations = true;
124145
}
@@ -127,7 +148,7 @@ impl Expression {
127148
let mut new_children = Vec::with_capacity(current.children().len());
128149
let mut any_child_optimized = false;
129150
for child in current.children().iter() {
130-
if let Some(optimized) = child.try_optimize_recursive(scope)? {
151+
if let Some(optimized) = child.try_optimize_recursive_inner(scope, cache)? {
131152
new_children.push(optimized);
132153
any_child_optimized = true;
133154
} else {
@@ -140,15 +161,11 @@ impl Expression {
140161
any_optimizations = true;
141162

142163
// After updating children, try to optimize root again
143-
if let Some(optimized) = current.clone().try_optimize(scope)? {
164+
if let Some(optimized) = current.clone().try_optimize(scope, cache)? {
144165
current = optimized;
145166
}
146167
}
147168

148-
// TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
149-
// to CNF?
150-
let current = find_between(current);
151-
152169
if any_optimizations {
153170
Ok(Some(current))
154171
} else {
@@ -294,3 +311,73 @@ impl ReduceCtx for ExpressionReduceCtx {
294311
}))
295312
}
296313
}
314+
315+
#[cfg(test)]
316+
#[expect(clippy::cast_possible_truncation)]
317+
mod tests {
318+
use std::time::Instant;
319+
320+
use vortex_error::VortexResult;
321+
322+
use crate::dtype::DType;
323+
use crate::dtype::Nullability;
324+
use crate::dtype::PType;
325+
use crate::dtype::StructFields;
326+
use crate::expr::Expression;
327+
use crate::expr::eq;
328+
use crate::expr::get_item;
329+
use crate::expr::lit;
330+
use crate::expr::or;
331+
use crate::expr::root;
332+
333+
fn build_large_or_chain(n: usize) -> Expression {
334+
let base = eq(get_item("x", root()), lit(0i32));
335+
(1..n).fold(base, |acc, i| or(acc, eq(get_item("x", root()), lit(i as i32))))
336+
}
337+
338+
fn struct_scope() -> DType {
339+
DType::Struct(
340+
StructFields::new(
341+
["x"].into(),
342+
vec![DType::Primitive(PType::I32, Nullability::NonNullable)],
343+
),
344+
Nullability::NonNullable,
345+
)
346+
}
347+
348+
#[test]
349+
fn optimize_large_or_chain_does_not_hang() -> VortexResult<()> {
350+
let expr = build_large_or_chain(200);
351+
let scope = struct_scope();
352+
353+
let start = Instant::now();
354+
let _result = expr.optimize_recursive(&scope)?;
355+
let elapsed = start.elapsed();
356+
357+
// This should complete in well under a second. Before the fix, 200 ORs could take
358+
// many seconds due to per-node cache recreation and repeated find_between calls.
359+
assert!(
360+
elapsed.as_secs() < 5,
361+
"optimize_recursive took {elapsed:?} for 200 ORs — regression detected"
362+
);
363+
Ok(())
364+
}
365+
366+
#[test]
367+
fn optimize_or_chain_correctness() -> VortexResult<()> {
368+
// Verify the optimizer still produces correct results for a small OR chain.
369+
let expr = or(
370+
eq(get_item("x", root()), lit(1i32)),
371+
eq(get_item("x", root()), lit(2i32)),
372+
);
373+
let scope = struct_scope();
374+
let optimized = expr.optimize_recursive(&scope)?;
375+
376+
// The expression should still reference column "x" and both literals.
377+
let s = optimized.to_string();
378+
assert!(s.contains("$.x"), "expected $.x in {s}");
379+
assert!(s.contains("1i32") || s.contains('1'), "expected 1 in {s}");
380+
assert!(s.contains("2i32") || s.contains('2'), "expected 2 in {s}");
381+
Ok(())
382+
}
383+
}

0 commit comments

Comments
 (0)