@@ -5,18 +5,36 @@ use std::cell::RefCell;
55use std:: iter;
66
77use itertools:: Itertools ;
8+ use vortex_error:: VortexResult ;
9+ use vortex_session:: VortexSession ;
810use vortex_utils:: aliases:: hash_map:: HashMap ;
911
1012use super :: relation:: Relation ;
13+ use crate :: aggregate_fn:: fns:: all_nan:: AllNan ;
14+ use crate :: aggregate_fn:: fns:: all_non_nan:: AllNonNan ;
15+ use crate :: aggregate_fn:: fns:: all_non_null:: AllNonNull ;
16+ use crate :: aggregate_fn:: fns:: all_null:: AllNull ;
17+ use crate :: aggregate_fn:: fns:: nan_count:: NanCount ;
18+ use crate :: dtype:: DType ;
1119use crate :: dtype:: Field ;
1220use crate :: dtype:: FieldName ;
1321use crate :: dtype:: FieldPath ;
1422use crate :: dtype:: FieldPathSet ;
1523use crate :: expr:: Expression ;
1624use crate :: expr:: StatsCatalog ;
25+ use crate :: expr:: analysis:: referenced_field_paths;
26+ use crate :: expr:: eq;
1727use crate :: expr:: get_item;
28+ use crate :: expr:: lit;
1829use crate :: expr:: root;
1930use crate :: expr:: stats:: Stat ;
31+ use crate :: expr:: traversal:: NodeExt ;
32+ use crate :: expr:: traversal:: Transformed ;
33+ use crate :: scalar:: Scalar ;
34+ use crate :: scalar_fn:: EmptyOptions ;
35+ use crate :: scalar_fn:: ScalarFnVTableExt ;
36+ use crate :: scalar_fn:: fns:: stat:: StatFn ;
37+ use crate :: scalar_fn:: internal:: row_count:: RowCount ;
2038
2139pub type RequiredStats = Relation < FieldPath , Stat > ;
2240
@@ -113,6 +131,163 @@ pub fn checked_pruning_expr(
113131 Some ( ( expr, relation) )
114132}
115133
134+ /// Build a pruning expression using session-registered stats rewrite rules.
135+ ///
136+ /// The returned expression is lowered to the same stats-table field references as
137+ /// [`checked_pruning_expr`]. If a rewrite asks for a stat that is not present in
138+ /// `available_stats`, this returns `Ok(None)`.
139+ pub fn checked_pruning_expr_with_session (
140+ expr : & Expression ,
141+ scope : & DType ,
142+ available_stats : & FieldPathSet ,
143+ session : & VortexSession ,
144+ ) -> VortexResult < Option < ( Expression , RequiredStats ) > > {
145+ let Some ( predicate) = expr. falsify ( scope, session) ? else {
146+ return Ok ( None ) ;
147+ } ;
148+
149+ lower_stat_fns ( predicate, scope, available_stats)
150+ }
151+
152+ fn lower_stat_fns (
153+ predicate : Expression ,
154+ scope : & DType ,
155+ available_stats : & FieldPathSet ,
156+ ) -> VortexResult < Option < ( Expression , RequiredStats ) > > {
157+ let mut required_stats = Relation :: new ( ) ;
158+ let mut missing_stat = false ;
159+ let lowered = predicate
160+ . transform_down ( |expr| {
161+ if !expr. is :: < StatFn > ( ) {
162+ return Ok ( Transformed :: no ( expr) ) ;
163+ }
164+
165+ if let Some ( lowered) =
166+ lower_stat_fn ( & expr, scope, available_stats, & mut required_stats) ?
167+ {
168+ return Ok ( Transformed :: yes ( lowered) ) ;
169+ }
170+
171+ missing_stat = true ;
172+ let dtype = expr. return_dtype ( scope) ?;
173+ Ok ( Transformed :: yes ( null_expr ( dtype) ) )
174+ } ) ?
175+ . into_inner ( ) ;
176+
177+ if missing_stat {
178+ return Ok ( None ) ;
179+ }
180+
181+ Ok ( Some ( ( lowered, required_stats) ) )
182+ }
183+
184+ fn lower_stat_fn (
185+ expr : & Expression ,
186+ scope : & DType ,
187+ available_stats : & FieldPathSet ,
188+ required_stats : & mut RequiredStats ,
189+ ) -> VortexResult < Option < Expression > > {
190+ let options = expr. as_ :: < StatFn > ( ) ;
191+ let aggregate_fn = options. aggregate_fn ( ) ;
192+ let input = expr. child ( 0 ) ;
193+ let input_dtype = input. return_dtype ( scope) ?;
194+
195+ if aggregate_fn. is :: < AllNan > ( ) {
196+ if !has_nans ( & input_dtype) {
197+ return Ok ( Some ( lit ( false ) ) ) ;
198+ }
199+ return lower_stat_ref (
200+ input,
201+ Stat :: NaNCount ,
202+ scope,
203+ available_stats,
204+ required_stats,
205+ )
206+ . map ( |stat| stat. map ( |stat| eq ( stat, row_count_expr ( ) ) ) ) ;
207+ }
208+
209+ if aggregate_fn. is :: < AllNonNan > ( ) {
210+ if !has_nans ( & input_dtype) {
211+ return Ok ( Some ( lit ( true ) ) ) ;
212+ }
213+ return lower_stat_ref (
214+ input,
215+ Stat :: NaNCount ,
216+ scope,
217+ available_stats,
218+ required_stats,
219+ )
220+ . map ( |stat| stat. map ( |stat| eq ( stat, lit ( 0u64 ) ) ) ) ;
221+ }
222+
223+ if aggregate_fn. is :: < NanCount > ( ) && !has_nans ( & input_dtype) {
224+ return Ok ( Some ( lit ( 0u64 ) ) ) ;
225+ }
226+
227+ if aggregate_fn. is :: < AllNull > ( ) {
228+ return lower_stat_ref (
229+ input,
230+ Stat :: NullCount ,
231+ scope,
232+ available_stats,
233+ required_stats,
234+ )
235+ . map ( |stat| stat. map ( |stat| eq ( stat, row_count_expr ( ) ) ) ) ;
236+ }
237+
238+ if aggregate_fn. is :: < AllNonNull > ( ) {
239+ return lower_stat_ref (
240+ input,
241+ Stat :: NullCount ,
242+ scope,
243+ available_stats,
244+ required_stats,
245+ )
246+ . map ( |stat| stat. map ( |stat| eq ( stat, lit ( 0u64 ) ) ) ) ;
247+ }
248+
249+ let Some ( stat) = Stat :: from_aggregate_fn ( aggregate_fn) else {
250+ return Ok ( None ) ;
251+ } ;
252+
253+ lower_stat_ref ( input, stat, scope, available_stats, required_stats)
254+ }
255+
256+ fn lower_stat_ref (
257+ input : & Expression ,
258+ stat : Stat ,
259+ scope : & DType ,
260+ available_stats : & FieldPathSet ,
261+ required_stats : & mut RequiredStats ,
262+ ) -> VortexResult < Option < Expression > > {
263+ let field_paths = referenced_field_paths ( input, scope) ?;
264+ let Some ( field_path) = field_paths. iter ( ) . exactly_one ( ) . ok ( ) else {
265+ return Ok ( None ) ;
266+ } ;
267+ let stat_path = field_path. clone ( ) . push ( stat. name ( ) ) ;
268+ if !available_stats. contains ( & stat_path) {
269+ return Ok ( None ) ;
270+ }
271+
272+ required_stats. insert ( field_path. clone ( ) , stat) ;
273+ Ok ( Some ( get_item (
274+ field_path_stat_field_name ( field_path, stat) ,
275+ root ( ) ,
276+ ) ) )
277+ }
278+
279+ fn row_count_expr ( ) -> Expression {
280+ RowCount . new_expr ( EmptyOptions , [ ] )
281+ }
282+
283+ fn null_expr ( dtype : DType ) -> Expression {
284+ lit ( Scalar :: null ( dtype. as_nullable ( ) ) )
285+ }
286+
287+ fn has_nans ( dtype : & DType ) -> bool {
288+ matches ! ( dtype, DType :: Primitive ( ptype, _) if ptype. is_float( ) )
289+ }
290+
116291#[ cfg( test) ]
117292mod tests {
118293 use rstest:: fixture;
0 commit comments