Skip to content

Commit 5fe7219

Browse files
committed
Add group join physical optimizer
1 parent 2f2fe8f commit 5fe7219

7 files changed

Lines changed: 905 additions & 0 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/physical-optimizer/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ recursive = { workspace = true, optional = true }
5656
[dev-dependencies]
5757
datafusion-expr = { workspace = true }
5858
datafusion-functions = { workspace = true }
59+
datafusion-functions-aggregate = { workspace = true }
5960
datafusion-functions-window = { workspace = true }
6061
insta = { workspace = true }
6162
tokio = { workspace = true }
Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`GroupJoinOptimizer`] replaces an `AggregateExec` directly above a
19+
//! `HashJoinExec` with a fused [`GroupJoinExec`] when the aggregate's GROUP BY
20+
//! keys match the join's equi-join keys.
21+
//!
22+
//! Based on: Moerkotte & Neumann, "Accelerating Queries with Group-By and Join
23+
//! by Groupjoin", PVLDB 4(11), 2011.
24+
25+
use std::sync::Arc;
26+
27+
use datafusion_common::config::ConfigOptions;
28+
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
29+
use datafusion_common::{JoinType, Result};
30+
use datafusion_physical_expr::physical_exprs_equal;
31+
use datafusion_physical_plan::ExecutionPlan;
32+
use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode};
33+
use datafusion_physical_plan::joins::HashJoinExec;
34+
use datafusion_physical_plan::joins::group_join::GroupJoinExec;
35+
use datafusion_physical_plan::projection::ProjectionExec;
36+
37+
use crate::PhysicalOptimizerRule;
38+
39+
/// Physical optimizer rule that fuses an `AggregateExec` sitting on top of a
/// `HashJoinExec` into a single `GroupJoinExec`.
///
/// A plan node is only considered for the rewrite when all of the following
/// hold:
///
/// 1. The aggregate runs in `Single`, `SinglePartitioned`, or `Partial` mode
/// 2. At least one aggregate expression is present (a bare GROUP BY, as used
///    for DISTINCT, is left alone)
/// 3. The aggregate does not use GROUPING SETS
/// 4. The aggregate's input is a `HashJoinExec` (possibly through a
///    `ProjectionExec`)
/// 5. The join type is `Inner` or `Left`
/// 6. The join is a pure equi-join, i.e. it carries no residual filter
/// 7. The GROUP BY expressions are exactly the left-side join keys
/// 8. Every aggregate function supports `GroupsAccumulator`
///
/// The rule is meant to run after `CombinePartialFinalAggregate` (which may
/// collapse a two-phase aggregation into `Single` mode) and after
/// `JoinSelection` (which picks the build and probe sides).
#[derive(Debug, Default)]
pub struct GroupJoinOptimizer {}
55+
56+
impl GroupJoinOptimizer {
57+
/// Create a new `GroupJoinOptimizer`.
58+
pub fn new() -> Self {
59+
Self {}
60+
}
61+
}
62+
63+
impl PhysicalOptimizerRule for GroupJoinOptimizer {
64+
fn optimize(
65+
&self,
66+
plan: Arc<dyn ExecutionPlan>,
67+
_config: &ConfigOptions,
68+
) -> Result<Arc<dyn ExecutionPlan>> {
69+
plan.transform_down(|plan| {
70+
let Some(agg_exec) = plan.downcast_ref::<AggregateExec>() else {
71+
return Ok(Transformed::no(plan));
72+
};
73+
74+
if !matches!(
75+
agg_exec.mode(),
76+
AggregateMode::Single
77+
| AggregateMode::SinglePartitioned
78+
| AggregateMode::Partial
79+
) {
80+
return Ok(Transformed::no(plan));
81+
}
82+
83+
// Must have actual aggregate functions (not just GROUP BY for DISTINCT)
84+
let aggr_exprs = agg_exec.aggr_expr();
85+
if aggr_exprs.is_empty() {
86+
return Ok(Transformed::no(plan));
87+
}
88+
89+
// No GROUPING SETS
90+
if agg_exec.group_expr().groups().len() > 1 {
91+
return Ok(Transformed::no(plan));
92+
}
93+
94+
// Find HashJoinExec (possibly through a ProjectionExec)
95+
let input = agg_exec.input();
96+
let hash_join: &HashJoinExec;
97+
if let Some(hj) = input.downcast_ref::<HashJoinExec>() {
98+
hash_join = hj;
99+
} else if let Some(proj) = input.downcast_ref::<ProjectionExec>() {
100+
if let Some(hj) = proj.input().downcast_ref::<HashJoinExec>() {
101+
hash_join = hj;
102+
} else {
103+
return Ok(Transformed::no(plan));
104+
}
105+
} else {
106+
return Ok(Transformed::no(plan));
107+
};
108+
109+
// Inner and Left joins
110+
if !matches!(
111+
hash_join.join_type(),
112+
JoinType::Inner | JoinType::Left
113+
) {
114+
return Ok(Transformed::no(plan));
115+
}
116+
117+
// No residual join filter (equi-join only)
118+
if hash_join.filter().is_some() {
119+
return Ok(Transformed::no(plan));
120+
}
121+
122+
// GROUP BY keys must exactly match left join keys
123+
let group_exprs: Vec<_> = agg_exec
124+
.group_expr()
125+
.expr()
126+
.iter()
127+
.map(|(expr, _)| Arc::clone(expr))
128+
.collect();
129+
130+
let join_on = hash_join.on();
131+
let left_join_keys: Vec<_> =
132+
join_on.iter().map(|(l, _)| Arc::clone(l)).collect();
133+
134+
if group_exprs.len() != left_join_keys.len() {
135+
return Ok(Transformed::no(plan));
136+
}
137+
138+
if !physical_exprs_equal(&group_exprs, &left_join_keys) {
139+
return Ok(Transformed::no(plan));
140+
}
141+
142+
// All aggregates must support GroupsAccumulator
143+
for agg in aggr_exprs {
144+
if !agg.groups_accumulator_supported() {
145+
return Ok(Transformed::no(plan));
146+
}
147+
}
148+
149+
// For Inner joins, skip if any aggregate has a literal argument
150+
// (e.g., COUNT(*) rewritten as count(Int64(1))). These queries
151+
// don't benefit enough from GroupJoin to justify changing the plan.
152+
if *hash_join.join_type() == JoinType::Inner {
153+
let has_literal_arg = aggr_exprs.iter().any(|agg| {
154+
agg.expressions().iter().any(|expr| {
155+
expr.as_ref()
156+
.downcast_ref::<datafusion_physical_expr::expressions::Literal>()
157+
.is_some()
158+
})
159+
});
160+
if has_literal_arg {
161+
return Ok(Transformed::no(plan));
162+
}
163+
}
164+
165+
// All preconditions met — create GroupJoinExec
166+
let group_by_with_names: Vec<_> = agg_exec
167+
.group_expr()
168+
.expr()
169+
.iter()
170+
.map(|(expr, name)| (Arc::clone(expr), name.clone()))
171+
.collect();
172+
173+
let group_join = GroupJoinExec::try_new(
174+
Arc::clone(hash_join.left()),
175+
Arc::clone(hash_join.right()),
176+
join_on.to_vec(),
177+
*hash_join.join_type(),
178+
group_by_with_names,
179+
aggr_exprs.to_vec(),
180+
)?;
181+
182+
Ok(Transformed::yes(
183+
Arc::new(group_join) as Arc<dyn ExecutionPlan>
184+
))
185+
})
186+
.data()
187+
}
188+
189+
fn name(&self) -> &str {
190+
"group_join"
191+
}
192+
193+
fn schema_check(&self) -> bool {
194+
false // Schema changes (aggregate output differs from join output)
195+
}
196+
}
197+
198+
#[cfg(test)]
199+
mod tests {
200+
use super::*;
201+
202+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
203+
use datafusion_common::NullEquality;
204+
use datafusion_functions_aggregate::count::count_udaf;
205+
use datafusion_physical_expr::PhysicalExprRef;
206+
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
207+
use datafusion_physical_expr::expressions::{col, lit};
208+
use datafusion_physical_plan::displayable;
209+
use datafusion_physical_plan::empty::EmptyExec;
210+
use datafusion_physical_plan::joins::PartitionMode;
211+
use datafusion_physical_plan::projection::ProjectionExec;
212+
use insta::assert_snapshot;
213+
214+
fn left_schema() -> SchemaRef {
215+
Arc::new(Schema::new(vec![
216+
Field::new("l_key", DataType::Int32, false),
217+
Field::new("l_value", DataType::Int32, true),
218+
]))
219+
}
220+
221+
fn right_schema() -> SchemaRef {
222+
Arc::new(Schema::new(vec![
223+
Field::new("r_key", DataType::Int32, false),
224+
Field::new("r_value", DataType::Int32, true),
225+
]))
226+
}
227+
228+
fn join(join_type: JoinType, left_key: &str) -> Result<Arc<HashJoinExec>> {
229+
let left_schema = left_schema();
230+
let right_schema = right_schema();
231+
let left = Arc::new(EmptyExec::new(Arc::clone(&left_schema)));
232+
let right = Arc::new(EmptyExec::new(Arc::clone(&right_schema)));
233+
234+
Ok(Arc::new(HashJoinExec::try_new(
235+
left,
236+
right,
237+
vec![(col(left_key, &left_schema)?, col("r_key", &right_schema)?)],
238+
None,
239+
&join_type,
240+
None,
241+
PartitionMode::CollectLeft,
242+
NullEquality::NullEqualsNull,
243+
false,
244+
)?))
245+
}
246+
247+
fn aggregate(
248+
input: Arc<dyn ExecutionPlan>,
249+
group_expr: PhysicalExprRef,
250+
aggr_expr: PhysicalExprRef,
251+
) -> Result<Arc<dyn ExecutionPlan>> {
252+
let input_schema = input.schema();
253+
let aggr_expr = Arc::new(
254+
AggregateExprBuilder::new(count_udaf(), vec![aggr_expr])
255+
.schema(Arc::clone(&input_schema))
256+
.alias("count_values")
257+
.build()?,
258+
);
259+
260+
Ok(Arc::new(AggregateExec::try_new(
261+
AggregateMode::Single,
262+
datafusion_physical_plan::aggregates::PhysicalGroupBy::new_single(vec![(
263+
group_expr,
264+
"l_key".to_string(),
265+
)]),
266+
vec![aggr_expr],
267+
vec![None],
268+
input,
269+
input_schema,
270+
)?))
271+
}
272+
273+
fn optimize(plan: Arc<dyn ExecutionPlan>) -> Result<String> {
274+
let optimized =
275+
GroupJoinOptimizer::new().optimize(plan, &ConfigOptions::new())?;
276+
Ok(displayable(optimized.as_ref()).indent(true).to_string())
277+
}
278+
279+
#[test]
280+
fn rewrites_aggregate_above_inner_hash_join() -> Result<()> {
281+
let join = join(JoinType::Inner, "l_key")?;
282+
let join_schema = join.schema();
283+
let plan = aggregate(
284+
join,
285+
col("l_key", &join_schema)?,
286+
col("r_value", &join_schema)?,
287+
)?;
288+
289+
assert_snapshot!(optimize(plan)?, @r"
290+
GroupJoinExec: join_type=Inner, on=[(l_key@0, r_key@0)], aggr=[count_values]
291+
EmptyExec
292+
EmptyExec
293+
");
294+
Ok(())
295+
}
296+
297+
#[test]
298+
fn rewrites_through_projection() -> Result<()> {
299+
let join = join(JoinType::Left, "l_key")?;
300+
let join_schema = join.schema();
301+
let projection = Arc::new(ProjectionExec::try_new(
302+
vec![
303+
(col("l_key", &join_schema)?, "l_key".to_string()),
304+
(col("r_value", &join_schema)?, "r_value".to_string()),
305+
],
306+
join,
307+
)?);
308+
let projection_schema = projection.schema();
309+
let plan = aggregate(
310+
projection,
311+
col("l_key", &projection_schema)?,
312+
col("r_value", &projection_schema)?,
313+
)?;
314+
315+
assert_snapshot!(optimize(plan)?, @r"
316+
GroupJoinExec: join_type=Left, on=[(l_key@0, r_key@0)], aggr=[count_values]
317+
EmptyExec
318+
EmptyExec
319+
");
320+
Ok(())
321+
}
322+
323+
#[test]
324+
fn does_not_rewrite_when_group_by_does_not_match_join_key() -> Result<()> {
325+
let join = join(JoinType::Inner, "l_key")?;
326+
let join_schema = join.schema();
327+
let plan = aggregate(
328+
join,
329+
col("l_value", &join_schema)?,
330+
col("r_value", &join_schema)?,
331+
)?;
332+
333+
assert_snapshot!(optimize(plan)?, @r"
334+
AggregateExec: mode=Single, gby=[l_value@1 as l_key], aggr=[count_values]
335+
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(l_key@0, r_key@0)], NullsEqual: true
336+
EmptyExec
337+
EmptyExec
338+
");
339+
Ok(())
340+
}
341+
342+
#[test]
343+
fn does_not_rewrite_unsupported_join_type() -> Result<()> {
344+
let join = join(JoinType::Right, "l_key")?;
345+
let join_schema = join.schema();
346+
let plan = aggregate(
347+
join,
348+
col("l_key", &join_schema)?,
349+
col("r_value", &join_schema)?,
350+
)?;
351+
352+
assert_snapshot!(optimize(plan)?, @r"
353+
AggregateExec: mode=Single, gby=[l_key@0 as l_key], aggr=[count_values]
354+
HashJoinExec: mode=CollectLeft, join_type=Right, on=[(l_key@0, r_key@0)], NullsEqual: true
355+
EmptyExec
356+
EmptyExec
357+
");
358+
Ok(())
359+
}
360+
361+
#[test]
362+
fn does_not_rewrite_inner_join_with_literal_aggregate_argument() -> Result<()> {
363+
let join = join(JoinType::Inner, "l_key")?;
364+
let join_schema = join.schema();
365+
let plan = aggregate(join, col("l_key", &join_schema)?, lit(1i64))?;
366+
367+
assert_snapshot!(optimize(plan)?, @r"
368+
AggregateExec: mode=Single, gby=[l_key@0 as l_key], aggr=[count_values]
369+
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(l_key@0, r_key@0)], NullsEqual: true
370+
EmptyExec
371+
EmptyExec
372+
");
373+
Ok(())
374+
}
375+
}

datafusion/physical-optimizer/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub mod enforce_distribution;
3131
pub mod enforce_sorting;
3232
pub mod ensure_coop;
3333
pub mod filter_pushdown;
34+
pub mod group_join;
3435
pub mod join_selection;
3536
pub mod limit_pushdown;
3637
pub mod limit_pushdown_past_window;

0 commit comments

Comments
 (0)