@@ -47,7 +47,8 @@ use arrow_schema::FieldRef;
4747use datafusion_common:: stats:: Precision ;
4848use datafusion_common:: tree_node:: TreeNodeRecursion ;
4949use datafusion_common:: {
50- Constraint , Constraints , Result , ScalarValue , assert_eq_or_internal_err, not_impl_err,
50+ Constraint , Constraints , Result , ScalarValue , assert_eq_or_internal_err,
51+ internal_err, not_impl_err,
5152} ;
5253use datafusion_execution:: TaskContext ;
5354use datafusion_expr:: { Accumulator , Aggregate } ;
@@ -893,6 +894,47 @@ impl AggregateExec {
893894 & self . filter_expr
894895 }
895896
897+ /// Returns the dynamic filter expression for this aggregate, if set.
898+ pub fn dynamic_filter_expr ( & self ) -> Option < & Arc < DynamicFilterPhysicalExpr > > {
899+ self . dynamic_filter . as_ref ( ) . map ( |df| & df. filter )
900+ }
901+
902+ /// Replace the dynamic filter expression. This method errors if the aggregate does not
903+ /// support dynamic filtering or if the filter expression is incompatible with this
904+ /// [`AggregateExec`].
905+ pub fn with_dynamic_filter_expr (
906+ mut self ,
907+ filter : Arc < DynamicFilterPhysicalExpr > ,
908+ ) -> Result < Self > {
909+ // If there is no dynamic filter state initialized via `try_new`, then
910+ // we can safely assume that the aggregate does not support dynamic filtering.
911+ let Some ( dyn_filter) = self . dynamic_filter . as_ref ( ) else {
912+ return internal_err ! ( "Aggregate does not support dynamic filtering" ) ;
913+ } ;
914+
915+ // Validate that the filter is compatible with the aggregation columns.
916+ let cols = self . cols_for_dynamic_filter ( & dyn_filter. supported_accumulators_info ) ;
917+ if cols. len ( ) != filter. children ( ) . len ( ) {
918+ return internal_err ! (
919+ "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns"
920+ ) ;
921+ }
922+ for ( col, child) in cols. iter ( ) . zip ( filter. children ( ) ) {
923+ if !col. eq ( child) {
924+ return internal_err ! (
925+ "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}"
926+ ) ;
927+ }
928+ }
929+
930+ // Overwrite our filter
931+ self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
932+ filter,
933+ supported_accumulators_info : dyn_filter. supported_accumulators_info . clone ( ) ,
934+ } ) ) ;
935+ Ok ( self )
936+ }
937+
896938 /// Input plan
897939 pub fn input ( & self ) -> & Arc < dyn ExecutionPlan > {
898940 & self . input
@@ -1285,6 +1327,28 @@ impl AggregateExec {
12851327 }
12861328 }
12871329
1330+ // Collect column references for the dynamic filter expression from the supported accumulators.
1331+ fn cols_for_dynamic_filter (
1332+ & self ,
1333+ supported_accumulators_info : & [ PerAccumulatorDynFilter ] ,
1334+ ) -> Vec < Arc < dyn PhysicalExpr > > {
1335+ let all_cols: Vec < Arc < dyn PhysicalExpr > > = supported_accumulators_info
1336+ . iter ( )
1337+ . filter_map ( |info| {
1338+ // This should always be true due to how the supported accumulators
1339+ // are constructed. See `init_dynamic_filter` for more details.
1340+ if let [ arg] = & self . aggr_expr [ info. aggr_index ] . expressions ( ) . as_slice ( )
1341+ && arg. is :: < Column > ( )
1342+ {
1343+ return Some ( Arc :: clone ( arg) ) ;
1344+ }
1345+ None
1346+ } )
1347+ . collect ( ) ;
1348+ debug_assert ! ( all_cols. len( ) == supported_accumulators_info. len( ) ) ;
1349+ all_cols
1350+ }
1351+
12881352 /// Calculate scaled byte size based on row count ratio.
12891353 /// Returns `Precision::Absent` if input statistics are insufficient.
12901354 /// Returns `Precision::Inexact` with the scaled value otherwise.
@@ -2200,6 +2264,7 @@ mod tests {
22002264 use crate :: coalesce_partitions:: CoalescePartitionsExec ;
22012265 use crate :: common;
22022266 use crate :: common:: collect;
2267+ use crate :: empty:: EmptyExec ;
22032268 use crate :: execution_plan:: Boundedness ;
22042269 use crate :: expressions:: col;
22052270 use crate :: metrics:: MetricValue ;
@@ -2225,6 +2290,7 @@ mod tests {
22252290 use datafusion_functions_aggregate:: count:: count_udaf;
22262291 use datafusion_functions_aggregate:: first_last:: { first_value_udaf, last_value_udaf} ;
22272292 use datafusion_functions_aggregate:: median:: median_udaf;
2293+ use datafusion_functions_aggregate:: min_max:: min_udaf;
22282294 use datafusion_functions_aggregate:: sum:: sum_udaf;
22292295 use datafusion_physical_expr:: Partitioning ;
22302296 use datafusion_physical_expr:: PhysicalSortExpr ;
@@ -3846,13 +3912,10 @@ mod tests {
38463912 // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
38473913 let aggregates: Vec < Arc < AggregateFunctionExpr > > = vec ! [
38483914 Arc :: new(
3849- AggregateExprBuilder :: new(
3850- datafusion_functions_aggregate:: min_max:: min_udaf( ) ,
3851- vec![ col( "b" , & schema) ?] ,
3852- )
3853- . schema( Arc :: clone( & schema) )
3854- . alias( "MIN(b)" )
3855- . build( ) ?,
3915+ AggregateExprBuilder :: new( min_udaf( ) , vec![ col( "b" , & schema) ?] )
3916+ . schema( Arc :: clone( & schema) )
3917+ . alias( "MIN(b)" )
3918+ . build( ) ?,
38563919 ) ,
38573920 Arc :: new(
38583921 AggregateExprBuilder :: new( avg_udaf( ) , vec![ col( "b" , & schema) ?] )
@@ -3991,13 +4054,10 @@ mod tests {
39914054 // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
39924055 let aggregates: Vec < Arc < AggregateFunctionExpr > > = vec ! [
39934056 Arc :: new(
3994- AggregateExprBuilder :: new(
3995- datafusion_functions_aggregate:: min_max:: min_udaf( ) ,
3996- vec![ col( "b" , & schema) ?] ,
3997- )
3998- . schema( Arc :: clone( & schema) )
3999- . alias( "MIN(b)" )
4000- . build( ) ?,
4057+ AggregateExprBuilder :: new( min_udaf( ) , vec![ col( "b" , & schema) ?] )
4058+ . schema( Arc :: clone( & schema) )
4059+ . alias( "MIN(b)" )
4060+ . build( ) ?,
40014061 ) ,
40024062 Arc :: new(
40034063 AggregateExprBuilder :: new( avg_udaf( ) , vec![ col( "b" , & schema) ?] )
@@ -4945,4 +5005,118 @@ mod tests {
49455005
49465006 Ok ( ( ) )
49475007 }
5008+
/// Checks that [`AggregateExec::with_dynamic_filter_expr`] replaces the dynamic filter
/// the aggregate already holds.
#[test]
fn test_with_dynamic_filter() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    // A partial MIN aggregate supports dynamic filtering.
    let min_a = Arc::new(
        AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("min_a")
            .build()?,
    );
    let agg = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![min_a],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;

    // Part 1: a filter whose children match the aggregate's columns may take
    // the place of the existing dynamic filter.
    let replacement = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![col("a", &schema)?],
        lit(false),
    ));
    let agg = agg.with_dynamic_filter_expr(Arc::clone(&replacement))?;

    // The filter reported by the aggregate must now resolve to the new inner expression.
    let installed = agg
        .dynamic_filter_expr()
        .expect("should still have dynamic filter");
    let swapped = installed.current()?;
    assert_eq!(swapped.to_string(), lit(false).to_string());

    // Part 2: a filter that went through `PhysicalExpr::with_new_children` must
    // still be accepted when its new children are equivalent to the originals.
    let as_pexpr: Arc<dyn PhysicalExpr> =
        Arc::<DynamicFilterPhysicalExpr>::clone(&replacement);
    let remapped = as_pexpr.with_new_children(vec![col("a", &schema)?])?;
    let remapped_any = remapped as Arc<dyn std::any::Any + Send + Sync>;
    let remapped_df = match remapped_any.downcast::<DynamicFilterPhysicalExpr>() {
        Ok(df) => df,
        Err(_) => panic!("should be DynamicFilterPhysicalExpr after with_new_children"),
    };
    // The remapped filter is identical to the installed one, so there is nothing
    // further to compare; the absence of an error is the assertion here.
    let _agg = agg.with_dynamic_filter_expr(remapped_df)?;
    Ok(())
}
5061+
/// Checks that [`AggregateExec::with_dynamic_filter_expr`] fails for an aggregate
/// that does not support dynamic filtering.
#[test]
fn test_with_dynamic_filter_error_unsupported() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ];
    let schema = Arc::new(Schema::new(fields));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    // A grouped aggregate in Final mode does not support dynamic filters.
    let group_by = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
    let sum_b = Arc::new(
        AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("sum_b")
            .build()?,
    );
    let agg = AggregateExec::try_new(
        AggregateMode::Final,
        group_by,
        vec![sum_b],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;
    assert!(agg.dynamic_filter_expr().is_none());

    // With no dynamic filter state present, installing a filter must be rejected.
    let filter_children = vec![col("a", &schema)?];
    let df = Arc::new(DynamicFilterPhysicalExpr::new(filter_children, lit(true)));
    assert!(agg.with_dynamic_filter_expr(df).is_err());
    Ok(())
}
5094+
/// Checks that [`AggregateExec::with_dynamic_filter_expr`] fails when the filter's
/// child column does not match the column the aggregate operates on.
#[test]
fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    let min_a = Arc::new(
        AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("min_a")
            .build()?,
    );
    let agg = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![min_a],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;

    // The filter refers to a column ("bad", index 99) that differs from the
    // aggregate's own column "a", so it must be rejected.
    let df = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![Arc::new(Column::new("bad", 99)) as _],
        lit(true),
    ));
    assert!(agg.with_dynamic_filter_expr(df).is_err());
    Ok(())
}
49485122}
0 commit comments