@@ -959,6 +959,83 @@ async fn window_using_aggregates() -> Result<()> {
959959 Ok ( ( ) )
960960}
961961
962+ #[ tokio:: test]
963+ async fn window_aggregates_with_filter ( ) -> Result < ( ) > {
964+ // Define a small in-memory table to make expected values clear
965+ let ts: Int32Array = [ 1 , 2 , 3 , 4 , 5 ] . into_iter ( ) . collect ( ) ;
966+ let val: Int32Array = [ -3 , -2 , 1 , 4 , -1 ] . into_iter ( ) . collect ( ) ;
967+ let batch = RecordBatch :: try_from_iter ( vec ! [
968+ ( "ts" , Arc :: new( ts) as _) ,
969+ ( "val" , Arc :: new( val) as _) ,
970+ ] ) ?;
971+
972+ let ctx = SessionContext :: new ( ) ;
973+ ctx. register_batch ( "t" , batch) ?;
974+
975+ let df = ctx. table ( "t" ) . await ?;
976+
977+ // Build filtered window aggregates over ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
978+ let mut exprs = vec ! [
979+ ( datafusion_functions_aggregate:: sum:: sum_udaf( ) , "sum_pos" ) ,
980+ (
981+ datafusion_functions_aggregate:: average:: avg_udaf( ) ,
982+ "avg_pos" ,
983+ ) ,
984+ (
985+ datafusion_functions_aggregate:: min_max:: min_udaf( ) ,
986+ "min_pos" ,
987+ ) ,
988+ (
989+ datafusion_functions_aggregate:: min_max:: max_udaf( ) ,
990+ "max_pos" ,
991+ ) ,
992+ (
993+ datafusion_functions_aggregate:: count:: count_udaf( ) ,
994+ "cnt_pos" ,
995+ ) ,
996+ ]
997+ . into_iter ( )
998+ . map ( |( func, alias) | {
999+ let w = WindowFunction :: new (
1000+ WindowFunctionDefinition :: AggregateUDF ( func) ,
1001+ vec ! [ col( "val" ) ] ,
1002+ ) ;
1003+
1004+ Expr :: from ( w)
1005+ . order_by ( vec ! [ col( "ts" ) . sort( true , true ) ] )
1006+ . window_frame ( WindowFrame :: new_bounds (
1007+ WindowFrameUnits :: Rows ,
1008+ WindowFrameBound :: Preceding ( ScalarValue :: UInt64 ( None ) ) ,
1009+ WindowFrameBound :: CurrentRow ,
1010+ ) )
1011+ . filter ( col ( "val" ) . gt ( lit ( 0 ) ) )
1012+ . build ( )
1013+ . unwrap ( )
1014+ . alias ( alias)
1015+ } )
1016+ . collect :: < Vec < _ > > ( ) ;
1017+ exprs. extend_from_slice ( & [ col ( "ts" ) , col ( "val" ) ] ) ;
1018+
1019+ let results = df. select ( exprs) ?. collect ( ) . await ?;
1020+
1021+ assert_snapshot ! (
1022+ batches_to_string( & results) ,
1023+ @r###"
1024+ +---------+---------+---------+---------+---------+----+-----+
1025+ | sum_pos | avg_pos | min_pos | max_pos | cnt_pos | ts | val |
1026+ +---------+---------+---------+---------+---------+----+-----+
1027+ | | | | | 0 | 1 | -3 |
1028+ | | | | | 0 | 2 | -2 |
1029+ | 1 | 1.0 | 1 | 1 | 1 | 3 | 1 |
1030+ | 5 | 2.5 | 1 | 4 | 2 | 4 | 4 |
1031+ | 5 | 2.5 | 1 | 4 | 2 | 5 | -1 |
1032+ +---------+---------+---------+---------+---------+----+-----+
1033+ "###
1034+ ) ;
1035+
1036+ Ok ( ( ) )
1037+ }
1038+
9621039// Test issue: https://github.com/apache/datafusion/issues/10346
9631040#[ tokio:: test]
9641041async fn test_select_over_aggregate_schema ( ) -> Result < ( ) > {
0 commit comments