Skip to content

Commit d8a829f

Browse files
feat: Support FILTER clause in aggregate window functions
WIP
1 parent abb9b85 commit d8a829f

28 files changed

Lines changed: 446 additions & 107 deletions

File tree

datafusion-examples/examples/advanced_udwf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
200200
window_frame: window_function.params.window_frame,
201201
null_treatment: window_function.params.null_treatment,
202202
distinct: window_function.params.distinct,
203+
filter: window_function.params.filter,
203204
},
204205
}))
205206
};

datafusion/common/src/tree_node.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,48 @@ impl<
990990
}
991991
}
992992

993+
impl<
994+
'a,
995+
T: 'a,
996+
C0: TreeNodeContainer<'a, T>,
997+
C1: TreeNodeContainer<'a, T>,
998+
C2: TreeNodeContainer<'a, T>,
999+
C3: TreeNodeContainer<'a, T>,
1000+
> TreeNodeContainer<'a, T> for (C0, C1, C2, C3)
1001+
{
1002+
fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1003+
&'a self,
1004+
mut f: F,
1005+
) -> Result<TreeNodeRecursion> {
1006+
self.0
1007+
.apply_elements(&mut f)?
1008+
.visit_sibling(|| self.1.apply_elements(&mut f))?
1009+
.visit_sibling(|| self.2.apply_elements(&mut f))?
1010+
.visit_sibling(|| self.3.apply_elements(&mut f))
1011+
}
1012+
1013+
fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
1014+
self,
1015+
mut f: F,
1016+
) -> Result<Transformed<Self>> {
1017+
self.0
1018+
.map_elements(&mut f)?
1019+
.map_data(|new_c0| Ok((new_c0, self.1, self.2, self.3)))?
1020+
.transform_sibling(|(new_c0, c1, c2, c3)| {
1021+
c1.map_elements(&mut f)?
1022+
.map_data(|new_c1| Ok((new_c0, new_c1, c2, c3)))
1023+
})?
1024+
.transform_sibling(|(new_c0, new_c1, c2, c3)| {
1025+
c2.map_elements(&mut f)?
1026+
.map_data(|new_c2| Ok((new_c0, new_c1, new_c2, c3)))
1027+
})?
1028+
.transform_sibling(|(new_c0, new_c1, new_c2, c3)| {
1029+
c3.map_elements(&mut f)?
1030+
.map_data(|new_c3| Ok((new_c0, new_c1, new_c2, new_c3)))
1031+
})
1032+
}
1033+
}
1034+
9931035
/// [`TreeNodeRefContainer`] contains references to elements that a function can be
9941036
/// applied on. The elements of the container are siblings so the continuation rules are
9951037
/// similar to [`TreeNodeRecursion::visit_sibling`].
@@ -1065,6 +1107,27 @@ impl<
10651107
}
10661108
}
10671109

1110+
impl<
1111+
'a,
1112+
T: 'a,
1113+
C0: TreeNodeContainer<'a, T>,
1114+
C1: TreeNodeContainer<'a, T>,
1115+
C2: TreeNodeContainer<'a, T>,
1116+
C3: TreeNodeContainer<'a, T>,
1117+
> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3)
1118+
{
1119+
fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1120+
&self,
1121+
mut f: F,
1122+
) -> Result<TreeNodeRecursion> {
1123+
self.0
1124+
.apply_elements(&mut f)?
1125+
.visit_sibling(|| self.1.apply_elements(&mut f))?
1126+
.visit_sibling(|| self.2.apply_elements(&mut f))?
1127+
.visit_sibling(|| self.3.apply_elements(&mut f))
1128+
}
1129+
}
1130+
10681131
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
10691132
pub trait TreeNodeIterator: Iterator {
10701133
/// Applies `f` to each item in this iterator

datafusion/core/src/physical_planner.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,7 @@ pub fn create_window_expr_with_name(
16501650
window_frame,
16511651
null_treatment,
16521652
distinct,
1653+
filter,
16531654
},
16541655
} = window_fun.as_ref();
16551656
let physical_args =
@@ -1669,6 +1670,11 @@ pub fn create_window_expr_with_name(
16691670
let window_frame = Arc::new(window_frame.clone());
16701671
let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls)
16711672
== NullTreatment::IgnoreNulls;
1673+
let physical_filter = filter
1674+
.as_ref()
1675+
.map(|f| create_physical_expr(f, logical_schema, execution_props))
1676+
.transpose()?;
1677+
16721678
windows::create_window_expr(
16731679
fun,
16741680
name,
@@ -1679,6 +1685,7 @@ pub fn create_window_expr_with_name(
16791685
physical_schema,
16801686
ignore_nulls,
16811687
*distinct,
1688+
physical_filter,
16821689
)
16831690
}
16841691
other => plan_err!("Invalid window expression '{other:?}'"),

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
289289
&extended_schema,
290290
false,
291291
false,
292+
None,
292293
)?;
293294
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
294295
vec![window_expr],
@@ -662,6 +663,7 @@ async fn run_window_test(
662663
&extended_schema,
663664
false,
664665
false,
666+
None,
665667
)?],
666668
exec1,
667669
false,
@@ -681,6 +683,7 @@ async fn run_window_test(
681683
&extended_schema,
682684
false,
683685
false,
686+
None,
684687
)?],
685688
exec2,
686689
search_mode.clone(),

datafusion/core/tests/physical_optimizer/enforce_sorting.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3686,6 +3686,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> {
36863686
input_schema.as_ref(),
36873687
false,
36883688
false,
3689+
None,
36893690
)?;
36903691
let window_exec = if window_expr.uses_bounded_memory() {
36913692
Arc::new(BoundedWindowAggExec::try_new(

datafusion/core/tests/physical_optimizer/test_utils.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ pub fn bounded_window_exec_with_partition(
266266
schema.as_ref(),
267267
false,
268268
false,
269+
None,
269270
)
270271
.unwrap();
271272

datafusion/core/tests/physical_optimizer/window_optimize.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod test {
4949
&partition,
5050
&[],
5151
Arc::new(frame),
52+
None,
5253
);
5354

5455
let bounded_agg_exec = BoundedWindowAggExec::try_new(

datafusion/expr/src/expr.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,8 @@ pub struct WindowFunctionParams {
12301230
pub order_by: Vec<Sort>,
12311231
/// Window frame
12321232
pub window_frame: WindowFrame,
1233+
/// Optional filter expression (FILTER (WHERE ...))
1234+
pub filter: Option<Box<Expr>>,
12331235
/// Specifies how NULL value is treated: ignore or respect
12341236
pub null_treatment: Option<NullTreatment>,
12351237
/// Distinct flag
@@ -1247,6 +1249,7 @@ impl WindowFunction {
12471249
partition_by: Vec::default(),
12481250
order_by: Vec::default(),
12491251
window_frame: WindowFrame::new(None),
1252+
filter: None,
12501253
null_treatment: None,
12511254
distinct: false,
12521255
},
@@ -2388,6 +2391,7 @@ impl NormalizeEq for Expr {
23882391
window_frame: self_window_frame,
23892392
partition_by: self_partition_by,
23902393
order_by: self_order_by,
2394+
filter: self_filter,
23912395
null_treatment: self_null_treatment,
23922396
distinct: self_distinct,
23932397
},
@@ -2400,13 +2404,19 @@ impl NormalizeEq for Expr {
24002404
window_frame: other_window_frame,
24012405
partition_by: other_partition_by,
24022406
order_by: other_order_by,
2407+
filter: other_filter,
24032408
null_treatment: other_null_treatment,
24042409
distinct: other_distinct,
24052410
},
24062411
} = other.as_ref();
24072412

24082413
self_fun.name() == other_fun.name()
24092414
&& self_window_frame == other_window_frame
2415+
&& match (self_filter, other_filter) {
2416+
(Some(a), Some(b)) => a.normalize_eq(b),
2417+
(None, None) => true,
2418+
_ => false,
2419+
}
24102420
&& self_null_treatment == other_null_treatment
24112421
&& self_args.len() == other_args.len()
24122422
&& self_args
@@ -2658,12 +2668,14 @@ impl HashNode for Expr {
26582668
partition_by: _,
26592669
order_by: _,
26602670
window_frame,
2671+
filter,
26612672
null_treatment,
26622673
distinct,
26632674
},
26642675
} = window_fun.as_ref();
26652676
fun.hash(state);
26662677
window_frame.hash(state);
2678+
filter.hash(state);
26672679
null_treatment.hash(state);
26682680
distinct.hash(state);
26692681
}
@@ -2967,6 +2979,7 @@ impl Display for SchemaDisplay<'_> {
29672979
partition_by,
29682980
order_by,
29692981
window_frame,
2982+
filter,
29702983
null_treatment,
29712984
distinct,
29722985
} = params;
@@ -2993,6 +3006,10 @@ impl Display for SchemaDisplay<'_> {
29933006
write!(f, " {null_treatment}")?;
29943007
}
29953008

3009+
if let Some(filter) = filter {
3010+
write!(f, " FILTER (WHERE {filter})")?;
3011+
}
3012+
29963013
if !partition_by.is_empty() {
29973014
write!(
29983015
f,
@@ -3370,6 +3387,7 @@ impl Display for Expr {
33703387
partition_by,
33713388
order_by,
33723389
window_frame,
3390+
filter,
33733391
null_treatment,
33743392
distinct,
33753393
} = params;
@@ -3380,6 +3398,10 @@ impl Display for Expr {
33803398
write!(f, "{nt}")?;
33813399
}
33823400

3401+
if let Some(fe) = filter {
3402+
write!(f, " FILTER (WHERE {fe})")?;
3403+
}
3404+
33833405
if !partition_by.is_empty() {
33843406
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
33853407
}

datafusion/expr/src/expr_fn.rs

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use crate::expr::{
2121
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery,
22-
Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, WindowFunctionParams,
22+
Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction,
2323
};
2424
use crate::function::{
2525
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
@@ -832,23 +832,16 @@ impl ExprFuncBuilder {
832832
udaf.params.null_treatment = null_treatment;
833833
Expr::AggregateFunction(udaf)
834834
}
835-
ExprFuncKind::Window(WindowFunction {
836-
fun,
837-
params: WindowFunctionParams { args, .. },
838-
}) => {
835+
ExprFuncKind::Window(mut udwf) => {
839836
let has_order_by = order_by.as_ref().map(|o| !o.is_empty());
840-
Expr::from(WindowFunction {
841-
fun,
842-
params: WindowFunctionParams {
843-
args,
844-
partition_by: partition_by.unwrap_or_default(),
845-
order_by: order_by.unwrap_or_default(),
846-
window_frame: window_frame
847-
.unwrap_or_else(|| WindowFrame::new(has_order_by)),
848-
null_treatment,
849-
distinct,
850-
},
851-
})
837+
udwf.params.partition_by = partition_by.unwrap_or_default();
838+
udwf.params.order_by = order_by.unwrap_or_default();
839+
udwf.params.window_frame =
840+
window_frame.unwrap_or_else(|| WindowFrame::new(has_order_by));
841+
udwf.params.filter = filter.map(Box::new);
842+
udwf.params.null_treatment = null_treatment;
843+
udwf.params.distinct = distinct;
844+
Expr::WindowFunction(Box::new(udwf))
852845
}
853846
};
854847

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2469,6 +2469,20 @@ impl Window {
24692469
window_func_dependencies.extend(new_deps);
24702470
}
24712471

2472+
// Validate that FILTER, if present, is only used with aggregate window functions
2473+
if let Some(e) = window_expr.iter().find(|e| {
2474+
matches!(
2475+
e,
2476+
Expr::WindowFunction(wf)
2477+
if !matches!(wf.fun, WindowFunctionDefinition::AggregateUDF(_))
2478+
&& wf.params.filter.is_some()
2479+
)
2480+
}) {
2481+
return plan_err!(
2482+
"FILTER clause can only be used with aggregate window functions. Found in '{e}'"
2483+
);
2484+
}
2485+
24722486
Self::try_new_with_schema(
24732487
window_expr,
24742488
input,

0 commit comments

Comments
 (0)