@@ -29,18 +29,22 @@ impl Expression {
2929 /// 2. `simplify` - type-aware simplifications
3030 /// 3. `reduce` - abstract reduction rules via `ReduceNode`/`ReduceCtx`
3131 pub fn optimize ( & self , scope : & DType ) -> VortexResult < Expression > {
32+ let cache = SimplifyCache {
33+ scope,
34+ dtype_cache : RefCell :: new ( HashMap :: new ( ) ) ,
35+ } ;
3236 Ok ( self
3337 . clone ( )
34- . try_optimize ( scope) ?
38+ . try_optimize ( scope, & cache ) ?
3539 . unwrap_or_else ( || self . clone ( ) ) )
3640 }
3741
3842 /// Try to optimize the root expression node only, returning None if no optimizations applied.
39- pub fn try_optimize ( & self , scope : & DType ) -> VortexResult < Option < Expression > > {
40- let cache = SimplifyCache {
41- scope,
42- dtype_cache : RefCell :: new ( HashMap :: new ( ) ) ,
43- } ;
43+ fn try_optimize (
44+ & self ,
45+ scope : & DType ,
46+ cache : & SimplifyCache < ' _ > ,
47+ ) -> VortexResult < Option < Expression > > {
4448 let reduce_ctx = ExpressionReduceCtx {
4549 scope : scope. clone ( ) ,
4650 } ;
@@ -67,7 +71,7 @@ impl Expression {
6771 }
6872
6973 // Try simplify (typed)
70- if let Some ( simplified) = current. scalar_fn ( ) . simplify ( & current, & cache) ? {
74+ if let Some ( simplified) = current. scalar_fn ( ) . simplify ( & current, cache) ? {
7175 current = simplified;
7276 changed = true ;
7377 any_optimizations = true ;
@@ -114,11 +118,28 @@ impl Expression {
114118
115119 /// Try to optimize the entire expression tree recursively.
116120 pub fn try_optimize_recursive ( & self , scope : & DType ) -> VortexResult < Option < Expression > > {
121+ let cache = SimplifyCache {
122+ scope,
123+ dtype_cache : RefCell :: new ( HashMap :: new ( ) ) ,
124+ } ;
125+ let result = self . try_optimize_recursive_inner ( scope, & cache) ?;
126+
127+ // Apply the between optimization once at the top level only.
128+ // TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
129+ // to CNF?
130+ Ok ( Some ( find_between ( result. unwrap_or_else ( || self . clone ( ) ) ) ) )
131+ }
132+
133+ fn try_optimize_recursive_inner (
134+ & self ,
135+ scope : & DType ,
136+ cache : & SimplifyCache < ' _ > ,
137+ ) -> VortexResult < Option < Expression > > {
117138 let mut current = self . clone ( ) ;
118139 let mut any_optimizations = false ;
119140
120141 // First optimize the root
121- if let Some ( optimized) = current. clone ( ) . try_optimize ( scope) ? {
142+ if let Some ( optimized) = current. clone ( ) . try_optimize ( scope, cache ) ? {
122143 current = optimized;
123144 any_optimizations = true ;
124145 }
@@ -127,7 +148,7 @@ impl Expression {
127148 let mut new_children = Vec :: with_capacity ( current. children ( ) . len ( ) ) ;
128149 let mut any_child_optimized = false ;
129150 for child in current. children ( ) . iter ( ) {
130- if let Some ( optimized) = child. try_optimize_recursive ( scope) ? {
151+ if let Some ( optimized) = child. try_optimize_recursive_inner ( scope, cache ) ? {
131152 new_children. push ( optimized) ;
132153 any_child_optimized = true ;
133154 } else {
@@ -140,15 +161,11 @@ impl Expression {
140161 any_optimizations = true ;
141162
142163 // After updating children, try to optimize root again
143- if let Some ( optimized) = current. clone ( ) . try_optimize ( scope) ? {
164+ if let Some ( optimized) = current. clone ( ) . try_optimize ( scope, cache ) ? {
144165 current = optimized;
145166 }
146167 }
147168
148- // TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
149- // to CNF?
150- let current = find_between ( current) ;
151-
152169 if any_optimizations {
153170 Ok ( Some ( current) )
154171 } else {
@@ -294,3 +311,73 @@ impl ReduceCtx for ExpressionReduceCtx {
294311 } ) )
295312 }
296313}
314+
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use std::time::Instant;

    use vortex_error::VortexResult;

    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::Expression;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::or;
    use crate::expr::root;

    /// Builds `x == 0 OR x == 1 OR ... OR x == n - 1` as a left-nested chain of `or` nodes.
    ///
    /// The bound is an `i32` so the loop index can be used as the literal directly, with no
    /// narrowing cast. For `n <= 1` only the `x == 0` comparison is returned.
    fn build_large_or_chain(n: i32) -> Expression {
        let base = eq(get_item("x", root()), lit(0i32));
        (1..n).fold(base, |acc, i| or(acc, eq(get_item("x", root()), lit(i))))
    }

    /// A non-nullable struct scope with a single non-nullable `i32` field named `x`.
    fn struct_scope() -> DType {
        DType::Struct(
            StructFields::new(
                ["x"].into(),
                vec![DType::Primitive(PType::I32, Nullability::NonNullable)],
            ),
            Nullability::NonNullable,
        )
    }

    #[test]
    fn optimize_large_or_chain_does_not_hang() -> VortexResult<()> {
        let expr = build_large_or_chain(200);
        let scope = struct_scope();

        let start = Instant::now();
        let _result = expr.optimize_recursive(&scope)?;
        let elapsed = start.elapsed();

        // This should complete in well under a second. Before the fix, 200 ORs could take
        // many seconds due to per-node cache recreation and repeated find_between calls.
        // The bound is deliberately generous so slow CI machines do not trip it.
        assert!(
            elapsed < Duration::from_secs(5),
            "optimize_recursive took {elapsed:?} for 200 ORs — regression detected"
        );
        Ok(())
    }

    #[test]
    fn optimize_or_chain_correctness() -> VortexResult<()> {
        // Verify the optimizer still produces correct results for a small OR chain.
        let expr = or(
            eq(get_item("x", root()), lit(1i32)),
            eq(get_item("x", root()), lit(2i32)),
        );
        let scope = struct_scope();
        let optimized = expr.optimize_recursive(&scope)?;

        // Loose substring checks on the rendered expression: it should still mention
        // column "x" and both literal values, whatever shape the optimizer produced.
        let s = optimized.to_string();
        assert!(s.contains("$.x"), "expected $.x in {s}");
        assert!(s.contains('1'), "expected 1 in {s}");
        assert!(s.contains('2'), "expected 2 in {s}");
        Ok(())
    }
}
0 commit comments