Skip to content

Commit 17c968a

Browse files
feat: Support FILTER clause in aggregate window functions
1 parent abb9b85 commit 17c968a

27 files changed

Lines changed: 439 additions & 76 deletions

File tree

datafusion-examples/examples/advanced_udwf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
200200
window_frame: window_function.params.window_frame,
201201
null_treatment: window_function.params.null_treatment,
202202
distinct: window_function.params.distinct,
203+
filter: window_function.params.filter,
203204
},
204205
}))
205206
};

datafusion/common/src/tree_node.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,48 @@ impl<
990990
}
991991
}
992992

993+
impl<
994+
'a,
995+
T: 'a,
996+
C0: TreeNodeContainer<'a, T>,
997+
C1: TreeNodeContainer<'a, T>,
998+
C2: TreeNodeContainer<'a, T>,
999+
C3: TreeNodeContainer<'a, T>,
1000+
> TreeNodeContainer<'a, T> for (C0, C1, C2, C3)
1001+
{
1002+
fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1003+
&'a self,
1004+
mut f: F,
1005+
) -> Result<TreeNodeRecursion> {
1006+
self.0
1007+
.apply_elements(&mut f)?
1008+
.visit_sibling(|| self.1.apply_elements(&mut f))?
1009+
.visit_sibling(|| self.2.apply_elements(&mut f))?
1010+
.visit_sibling(|| self.3.apply_elements(&mut f))
1011+
}
1012+
1013+
fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
1014+
self,
1015+
mut f: F,
1016+
) -> Result<Transformed<Self>> {
1017+
self.0
1018+
.map_elements(&mut f)?
1019+
.map_data(|new_c0| Ok((new_c0, self.1, self.2, self.3)))?
1020+
.transform_sibling(|(new_c0, c1, c2, c3)| {
1021+
c1.map_elements(&mut f)?
1022+
.map_data(|new_c1| Ok((new_c0, new_c1, c2, c3)))
1023+
})?
1024+
.transform_sibling(|(new_c0, new_c1, c2, c3)| {
1025+
c2.map_elements(&mut f)?
1026+
.map_data(|new_c2| Ok((new_c0, new_c1, new_c2, c3)))
1027+
})?
1028+
.transform_sibling(|(new_c0, new_c1, new_c2, c3)| {
1029+
c3.map_elements(&mut f)?
1030+
.map_data(|new_c3| Ok((new_c0, new_c1, new_c2, new_c3)))
1031+
})
1032+
}
1033+
}
1034+
9931035
/// [`TreeNodeRefContainer`] contains references to elements that a function can be
9941036
/// applied on. The elements of the container are siblings so the continuation rules are
9951037
/// similar to [`TreeNodeRecursion::visit_sibling`].
@@ -1065,6 +1107,27 @@ impl<
10651107
}
10661108
}
10671109

1110+
impl<
1111+
'a,
1112+
T: 'a,
1113+
C0: TreeNodeContainer<'a, T>,
1114+
C1: TreeNodeContainer<'a, T>,
1115+
C2: TreeNodeContainer<'a, T>,
1116+
C3: TreeNodeContainer<'a, T>,
1117+
> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3)
1118+
{
1119+
fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1120+
&self,
1121+
mut f: F,
1122+
) -> Result<TreeNodeRecursion> {
1123+
self.0
1124+
.apply_elements(&mut f)?
1125+
.visit_sibling(|| self.1.apply_elements(&mut f))?
1126+
.visit_sibling(|| self.2.apply_elements(&mut f))?
1127+
.visit_sibling(|| self.3.apply_elements(&mut f))
1128+
}
1129+
}
1130+
10681131
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
10691132
pub trait TreeNodeIterator: Iterator {
10701133
/// Apples `f` to each item in this iterator

datafusion/core/src/physical_planner.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,7 @@ pub fn create_window_expr_with_name(
16501650
window_frame,
16511651
null_treatment,
16521652
distinct,
1653+
filter,
16531654
},
16541655
} = window_fun.as_ref();
16551656
let physical_args =
@@ -1669,6 +1670,13 @@ pub fn create_window_expr_with_name(
16691670
let window_frame = Arc::new(window_frame.clone());
16701671
let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls)
16711672
== NullTreatment::IgnoreNulls;
1673+
let physical_filter = match filter {
1674+
Some(f) => {
1675+
Some(create_physical_expr(f, logical_schema, execution_props)?)
1676+
}
1677+
None => None,
1678+
};
1679+
16721680
windows::create_window_expr(
16731681
fun,
16741682
name,
@@ -1679,6 +1687,7 @@ pub fn create_window_expr_with_name(
16791687
physical_schema,
16801688
ignore_nulls,
16811689
*distinct,
1690+
physical_filter,
16821691
)
16831692
}
16841693
other => plan_err!("Invalid window expression '{other:?}'"),

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
289289
&extended_schema,
290290
false,
291291
false,
292+
None,
292293
)?;
293294
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
294295
vec![window_expr],
@@ -662,6 +663,7 @@ async fn run_window_test(
662663
&extended_schema,
663664
false,
664665
false,
666+
None,
665667
)?],
666668
exec1,
667669
false,
@@ -681,6 +683,7 @@ async fn run_window_test(
681683
&extended_schema,
682684
false,
683685
false,
686+
None,
684687
)?],
685688
exec2,
686689
search_mode.clone(),

datafusion/core/tests/physical_optimizer/enforce_sorting.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3686,6 +3686,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> {
36863686
input_schema.as_ref(),
36873687
false,
36883688
false,
3689+
None,
36893690
)?;
36903691
let window_exec = if window_expr.uses_bounded_memory() {
36913692
Arc::new(BoundedWindowAggExec::try_new(

datafusion/core/tests/physical_optimizer/test_utils.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ pub fn bounded_window_exec_with_partition(
266266
schema.as_ref(),
267267
false,
268268
false,
269+
None,
269270
)
270271
.unwrap();
271272

datafusion/core/tests/physical_optimizer/window_optimize.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod test {
4949
&partition,
5050
&[],
5151
Arc::new(frame),
52+
None,
5253
);
5354

5455
let bounded_agg_exec = BoundedWindowAggExec::try_new(

datafusion/expr/src/expr.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,8 @@ pub struct WindowFunctionParams {
12301230
pub order_by: Vec<Sort>,
12311231
/// Window frame
12321232
pub window_frame: WindowFrame,
1233+
/// Optional filter expression (FILTER (WHERE ...))
1234+
pub filter: Option<Box<Expr>>,
12331235
/// Specifies how NULL value is treated: ignore or respect
12341236
pub null_treatment: Option<NullTreatment>,
12351237
/// Distinct flag
@@ -1247,6 +1249,7 @@ impl WindowFunction {
12471249
partition_by: Vec::default(),
12481250
order_by: Vec::default(),
12491251
window_frame: WindowFrame::new(None),
1252+
filter: None,
12501253
null_treatment: None,
12511254
distinct: false,
12521255
},
@@ -2388,6 +2391,7 @@ impl NormalizeEq for Expr {
23882391
window_frame: self_window_frame,
23892392
partition_by: self_partition_by,
23902393
order_by: self_order_by,
2394+
filter: self_filter,
23912395
null_treatment: self_null_treatment,
23922396
distinct: self_distinct,
23932397
},
@@ -2400,13 +2404,19 @@ impl NormalizeEq for Expr {
24002404
window_frame: other_window_frame,
24012405
partition_by: other_partition_by,
24022406
order_by: other_order_by,
2407+
filter: other_filter,
24032408
null_treatment: other_null_treatment,
24042409
distinct: other_distinct,
24052410
},
24062411
} = other.as_ref();
24072412

24082413
self_fun.name() == other_fun.name()
24092414
&& self_window_frame == other_window_frame
2415+
&& match (self_filter, other_filter) {
2416+
(Some(a), Some(b)) => a.normalize_eq(b),
2417+
(None, None) => true,
2418+
_ => false,
2419+
}
24102420
&& self_null_treatment == other_null_treatment
24112421
&& self_args.len() == other_args.len()
24122422
&& self_args
@@ -2658,12 +2668,14 @@ impl HashNode for Expr {
26582668
partition_by: _,
26592669
order_by: _,
26602670
window_frame,
2671+
filter,
26612672
null_treatment,
26622673
distinct,
26632674
},
26642675
} = window_fun.as_ref();
26652676
fun.hash(state);
26662677
window_frame.hash(state);
2678+
filter.hash(state);
26672679
null_treatment.hash(state);
26682680
distinct.hash(state);
26692681
}
@@ -2967,6 +2979,7 @@ impl Display for SchemaDisplay<'_> {
29672979
partition_by,
29682980
order_by,
29692981
window_frame,
2982+
filter,
29702983
null_treatment,
29712984
distinct,
29722985
} = params;
@@ -2993,6 +3006,10 @@ impl Display for SchemaDisplay<'_> {
29933006
write!(f, " {null_treatment}")?;
29943007
}
29953008

3009+
if let Some(filter) = filter {
3010+
write!(f, " FILTER (WHERE {filter})")?;
3011+
}
3012+
29963013
if !partition_by.is_empty() {
29973014
write!(
29983015
f,
@@ -3370,6 +3387,7 @@ impl Display for Expr {
33703387
partition_by,
33713388
order_by,
33723389
window_frame,
3390+
filter,
33733391
null_treatment,
33743392
distinct,
33753393
} = params;
@@ -3380,6 +3398,10 @@ impl Display for Expr {
33803398
write!(f, "{nt}")?;
33813399
}
33823400

3401+
if let Some(fe) = filter {
3402+
write!(f, " FILTER (WHERE {fe})")?;
3403+
}
3404+
33833405
if !partition_by.is_empty() {
33843406
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
33853407
}

datafusion/expr/src/expr_fn.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
2020
use crate::expr::{
2121
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22-
Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, WindowFunctionParams,
22+
Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction,
23+
WindowFunctionDefinition, WindowFunctionParams,
2324
};
2425
use crate::function::{
2526
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
@@ -836,6 +837,14 @@ impl ExprFuncBuilder {
836837
fun,
837838
params: WindowFunctionParams { args, .. },
838839
}) => {
840+
// FILTER is only supported for aggregate window functions
841+
if filter.is_some()
842+
&& matches!(fun, WindowFunctionDefinition::WindowUDF(_))
843+
{
844+
return plan_err!(
845+
"FILTER clause is only permitted for aggregate window functions"
846+
);
847+
}
839848
let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
840849
Expr::from(WindowFunction {
841850
fun,
@@ -845,6 +854,7 @@ impl ExprFuncBuilder {
845854
order_by: order_by.unwrap_or_default(),
846855
window_frame: window_frame
847856
.unwrap_or_else(|| WindowFrame::new(has_order_by)),
857+
filter: filter.map(Box::new),
848858
null_treatment,
849859
distinct,
850860
},

datafusion/expr/src/planner.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ pub struct RawWindowExpr {
307307
pub partition_by: Vec<Expr>,
308308
pub order_by: Vec<SortExpr>,
309309
pub window_frame: WindowFrame,
310+
pub filter: Option<Box<Expr>>,
310311
pub null_treatment: Option<NullTreatment>,
311312
pub distinct: bool,
312313
}

0 commit comments

Comments
 (0)