@@ -10,11 +10,6 @@ use vortex_session::VortexSession;
1010use vortex_utils:: aliases:: hash_map:: HashMap ;
1111
1212use super :: relation:: Relation ;
13- use crate :: aggregate_fn:: fns:: all_nan:: AllNan ;
14- use crate :: aggregate_fn:: fns:: all_non_nan:: AllNonNan ;
15- use crate :: aggregate_fn:: fns:: all_non_null:: AllNonNull ;
16- use crate :: aggregate_fn:: fns:: all_null:: AllNull ;
17- use crate :: aggregate_fn:: fns:: nan_count:: NanCount ;
1813use crate :: dtype:: DType ;
1914use crate :: dtype:: Field ;
2015use crate :: dtype:: FieldName ;
@@ -23,18 +18,11 @@ use crate::dtype::FieldPathSet;
2318use crate :: expr:: Expression ;
2419use crate :: expr:: StatsCatalog ;
2520use crate :: expr:: analysis:: referenced_field_paths;
26- use crate :: expr:: eq;
2721use crate :: expr:: get_item;
28- use crate :: expr:: lit;
2922use crate :: expr:: root;
3023use crate :: expr:: stats:: Stat ;
31- use crate :: expr:: traversal:: NodeExt ;
32- use crate :: expr:: traversal:: Transformed ;
33- use crate :: scalar:: Scalar ;
34- use crate :: scalar_fn:: EmptyOptions ;
35- use crate :: scalar_fn:: ScalarFnVTableExt ;
36- use crate :: scalar_fn:: fns:: stat:: StatFn ;
37- use crate :: scalar_fn:: internal:: row_count:: RowCount ;
24+ use crate :: stats:: bind:: StatBinder ;
25+ use crate :: stats:: bind:: bind_stats;
3826
3927pub type RequiredStats = Relation < FieldPath , Stat > ;
4028
@@ -146,146 +134,54 @@ pub fn checked_pruning_expr_with_session(
146134 return Ok ( None ) ;
147135 } ;
148136
149- lower_stat_fns ( predicate, scope, available_stats)
150- }
151-
152- fn lower_stat_fns (
153- predicate : Expression ,
154- scope : & DType ,
155- available_stats : & FieldPathSet ,
156- ) -> VortexResult < Option < ( Expression , RequiredStats ) > > {
157- let mut required_stats = Relation :: new ( ) ;
158- let mut missing_stat = false ;
159- let lowered = predicate
160- . transform_down ( |expr| {
161- if !expr. is :: < StatFn > ( ) {
162- return Ok ( Transformed :: no ( expr) ) ;
163- }
164-
165- if let Some ( lowered) =
166- lower_stat_fn ( & expr, scope, available_stats, & mut required_stats) ?
167- {
168- return Ok ( Transformed :: yes ( lowered) ) ;
169- }
170-
171- missing_stat = true ;
172- let dtype = expr. return_dtype ( scope) ?;
173- Ok ( Transformed :: yes ( null_expr ( dtype) ) )
174- } ) ?
175- . into_inner ( ) ;
176-
177- if missing_stat {
137+ let mut binder = RequiredStatsBinder {
138+ scope,
139+ available_stats,
140+ required_stats : Relation :: new ( ) ,
141+ } ;
142+ let Some ( lowered) = bind_stats ( predicate, & mut binder) ? else {
178143 return Ok ( None ) ;
179- }
144+ } ;
180145
181- Ok ( Some ( ( lowered, required_stats) ) )
146+ Ok ( Some ( ( lowered, binder . required_stats ) ) )
182147}
183148
184- fn lower_stat_fn (
185- expr : & Expression ,
186- scope : & DType ,
187- available_stats : & FieldPathSet ,
188- required_stats : & mut RequiredStats ,
189- ) -> VortexResult < Option < Expression > > {
190- let options = expr. as_ :: < StatFn > ( ) ;
191- let aggregate_fn = options. aggregate_fn ( ) ;
192- let input = expr. child ( 0 ) ;
193- let input_dtype = input. return_dtype ( scope) ?;
194-
195- if aggregate_fn. is :: < AllNan > ( ) {
196- if !has_nans ( & input_dtype) {
197- return Ok ( Some ( lit ( false ) ) ) ;
198- }
199- return lower_stat_ref (
200- input,
201- Stat :: NaNCount ,
202- scope,
203- available_stats,
204- required_stats,
205- )
206- . map ( |stat| stat. map ( |stat| eq ( stat, row_count_expr ( ) ) ) ) ;
207- }
208-
209- if aggregate_fn. is :: < AllNonNan > ( ) {
210- if !has_nans ( & input_dtype) {
211- return Ok ( Some ( lit ( true ) ) ) ;
212- }
213- return lower_stat_ref (
214- input,
215- Stat :: NaNCount ,
216- scope,
217- available_stats,
218- required_stats,
219- )
220- . map ( |stat| stat. map ( |stat| eq ( stat, lit ( 0u64 ) ) ) ) ;
221- }
149+ struct RequiredStatsBinder < ' a > {
150+ scope : & ' a DType ,
151+ available_stats : & ' a FieldPathSet ,
152+ required_stats : RequiredStats ,
153+ }
222154
223- if aggregate_fn. is :: < NanCount > ( ) && !has_nans ( & input_dtype) {
224- return Ok ( Some ( lit ( 0u64 ) ) ) ;
155+ impl StatBinder for RequiredStatsBinder < ' _ > {
156+ fn scope ( & self ) -> & DType {
157+ self . scope
225158 }
226159
227- if aggregate_fn. is :: < AllNull > ( ) {
228- return lower_stat_ref (
229- input,
230- Stat :: NullCount ,
231- scope,
232- available_stats,
233- required_stats,
234- )
235- . map ( |stat| stat. map ( |stat| eq ( stat, row_count_expr ( ) ) ) ) ;
236- }
160+ fn bind_stat (
161+ & mut self ,
162+ input : & Expression ,
163+ stat : Stat ,
164+ _stat_dtype : & DType ,
165+ ) -> VortexResult < Option < Expression > > {
166+ let field_paths = referenced_field_paths ( input, self . scope ) ?;
167+ let Some ( field_path) = field_paths. iter ( ) . exactly_one ( ) . ok ( ) else {
168+ return Ok ( None ) ;
169+ } ;
170+ let stat_path = field_path. clone ( ) . push ( stat. name ( ) ) ;
171+ if !self . available_stats . contains ( & stat_path) {
172+ return Ok ( None ) ;
173+ }
237174
238- if aggregate_fn. is :: < AllNonNull > ( ) {
239- return lower_stat_ref (
240- input,
241- Stat :: NullCount ,
242- scope,
243- available_stats,
244- required_stats,
245- )
246- . map ( |stat| stat. map ( |stat| eq ( stat, lit ( 0u64 ) ) ) ) ;
175+ self . required_stats . insert ( field_path. clone ( ) , stat) ;
176+ Ok ( Some ( get_item (
177+ field_path_stat_field_name ( field_path, stat) ,
178+ root ( ) ,
179+ ) ) )
247180 }
248181
249- let Some ( stat) = Stat :: from_aggregate_fn ( aggregate_fn) else {
250- return Ok ( None ) ;
251- } ;
252-
253- lower_stat_ref ( input, stat, scope, available_stats, required_stats)
254- }
255-
256- fn lower_stat_ref (
257- input : & Expression ,
258- stat : Stat ,
259- scope : & DType ,
260- available_stats : & FieldPathSet ,
261- required_stats : & mut RequiredStats ,
262- ) -> VortexResult < Option < Expression > > {
263- let field_paths = referenced_field_paths ( input, scope) ?;
264- let Some ( field_path) = field_paths. iter ( ) . exactly_one ( ) . ok ( ) else {
265- return Ok ( None ) ;
266- } ;
267- let stat_path = field_path. clone ( ) . push ( stat. name ( ) ) ;
268- if !available_stats. contains ( & stat_path) {
269- return Ok ( None ) ;
182+ fn missing_stat ( & mut self , _dtype : DType ) -> VortexResult < Option < Expression > > {
183+ Ok ( None )
270184 }
271-
272- required_stats. insert ( field_path. clone ( ) , stat) ;
273- Ok ( Some ( get_item (
274- field_path_stat_field_name ( field_path, stat) ,
275- root ( ) ,
276- ) ) )
277- }
278-
279- fn row_count_expr ( ) -> Expression {
280- RowCount . new_expr ( EmptyOptions , [ ] )
281- }
282-
283- fn null_expr ( dtype : DType ) -> Expression {
284- lit ( Scalar :: null ( dtype. as_nullable ( ) ) )
285- }
286-
287- fn has_nans ( dtype : & DType ) -> bool {
288- matches ! ( dtype, DType :: Primitive ( ptype, _) if ptype. is_float( ) )
289185}
290186
291187#[ cfg( test) ]
0 commit comments