Skip to content

Commit ff0b1c8

Browse files
test: Add DataFrame API test for FILTER clause on aggregate window functions
1 parent 930c841 commit ff0b1c8

1 file changed

Lines changed: 77 additions & 0 deletions

File tree

  • datafusion/core/tests/dataframe

datafusion/core/tests/dataframe/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,83 @@ async fn window_using_aggregates() -> Result<()> {
959959
Ok(())
960960
}
961961

962+
#[tokio::test]
963+
async fn window_aggregates_with_filter() -> Result<()> {
964+
// Define a small in-memory table to make expected values clear
965+
let ts: Int32Array = [1, 2, 3, 4, 5].into_iter().collect();
966+
let val: Int32Array = [-3, -2, 1, 4, -1].into_iter().collect();
967+
let batch = RecordBatch::try_from_iter(vec![
968+
("ts", Arc::new(ts) as _),
969+
("val", Arc::new(val) as _),
970+
])?;
971+
972+
let ctx = SessionContext::new();
973+
ctx.register_batch("t", batch)?;
974+
975+
let df = ctx.table("t").await?;
976+
977+
// Build filtered window aggregates over ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
978+
let mut exprs = vec![
979+
(datafusion_functions_aggregate::sum::sum_udaf(), "sum_pos"),
980+
(
981+
datafusion_functions_aggregate::average::avg_udaf(),
982+
"avg_pos",
983+
),
984+
(
985+
datafusion_functions_aggregate::min_max::min_udaf(),
986+
"min_pos",
987+
),
988+
(
989+
datafusion_functions_aggregate::min_max::max_udaf(),
990+
"max_pos",
991+
),
992+
(
993+
datafusion_functions_aggregate::count::count_udaf(),
994+
"cnt_pos",
995+
),
996+
]
997+
.into_iter()
998+
.map(|(func, alias)| {
999+
let w = WindowFunction::new(
1000+
WindowFunctionDefinition::AggregateUDF(func),
1001+
vec![col("val")],
1002+
);
1003+
1004+
Expr::from(w)
1005+
.order_by(vec![col("ts").sort(true, true)])
1006+
.window_frame(WindowFrame::new_bounds(
1007+
WindowFrameUnits::Rows,
1008+
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1009+
WindowFrameBound::CurrentRow,
1010+
))
1011+
.filter(col("val").gt(lit(0)))
1012+
.build()
1013+
.unwrap()
1014+
.alias(alias)
1015+
})
1016+
.collect::<Vec<_>>();
1017+
exprs.extend_from_slice(&[col("ts"), col("val")]);
1018+
1019+
let results = df.select(exprs)?.collect().await?;
1020+
1021+
assert_snapshot!(
1022+
batches_to_string(&results),
1023+
@r###"
1024+
+---------+---------+---------+---------+---------+----+-----+
1025+
| sum_pos | avg_pos | min_pos | max_pos | cnt_pos | ts | val |
1026+
+---------+---------+---------+---------+---------+----+-----+
1027+
| | | | | 0 | 1 | -3 |
1028+
| | | | | 0 | 2 | -2 |
1029+
| 1 | 1.0 | 1 | 1 | 1 | 3 | 1 |
1030+
| 5 | 2.5 | 1 | 4 | 2 | 4 | 4 |
1031+
| 5 | 2.5 | 1 | 4 | 2 | 5 | -1 |
1032+
+---------+---------+---------+---------+---------+----+-----+
1033+
"###
1034+
);
1035+
1036+
Ok(())
1037+
}
1038+
9621039
// Test issue: https://github.com/apache/datafusion/issues/10346
9631040
#[tokio::test]
9641041
async fn test_select_over_aggregate_schema() -> Result<()> {

0 commit comments

Comments
 (0)