Skip to content

Commit 73f4713

Browse files
committed
Add group join physical optimizer
1 parent 2f2fe8f commit 73f4713

5 files changed

Lines changed: 724 additions & 0 deletions

File tree

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`GroupJoinOptimizer`] replaces an `AggregateExec` directly above a
19+
//! `HashJoinExec` with a fused [`GroupJoinExec`] when the aggregate's GROUP BY
20+
//! keys match the join's equi-join keys.
21+
//!
22+
//! Based on: Moerkotte & Neumann, "Accelerating Queries with Group-By and Join
23+
//! by Groupjoin", PVLDB 4(11), 2011.
24+
25+
use std::sync::Arc;
26+
27+
use datafusion_common::config::ConfigOptions;
28+
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
29+
use datafusion_common::{JoinType, Result};
30+
use datafusion_physical_expr::physical_exprs_equal;
31+
use datafusion_physical_plan::ExecutionPlan;
32+
use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode};
33+
use datafusion_physical_plan::joins::HashJoinExec;
34+
use datafusion_physical_plan::joins::group_join::GroupJoinExec;
35+
use datafusion_physical_plan::projection::ProjectionExec;
36+
37+
use crate::PhysicalOptimizerRule;
38+
39+
/// Replaces `AggregateExec(HashJoinExec)` with a fused `GroupJoinExec` when:
40+
///
41+
/// 1. The aggregate mode is `Single`, `SinglePartitioned`, or `Partial`
42+
/// 2. The aggregate has at least one aggregate expression (not just DISTINCT)
43+
/// 3. The aggregate has no GROUPING SETS
44+
/// 4. The input is a `HashJoinExec` (possibly through a `ProjectionExec`)
45+
/// 5. The join type is `Inner` or `Left`
46+
/// 6. The join has no residual filter (equi-join only)
47+
/// 7. The GROUP BY expressions exactly match the left join keys
48+
/// 8. All aggregate functions support `GroupsAccumulator`
49+
///
50+
/// This rule should run after `CombinePartialFinalAggregate` (which may
51+
/// collapse two-phase aggregation into Single mode) and after `JoinSelection`
52+
/// (which decides build/probe sides).
53+
#[derive(Default, Debug)]
54+
pub struct GroupJoinOptimizer {}
55+
56+
impl GroupJoinOptimizer {
57+
/// Create a new `GroupJoinOptimizer`.
58+
pub fn new() -> Self {
59+
Self {}
60+
}
61+
}
62+
63+
impl PhysicalOptimizerRule for GroupJoinOptimizer {
64+
fn optimize(
65+
&self,
66+
plan: Arc<dyn ExecutionPlan>,
67+
_config: &ConfigOptions,
68+
) -> Result<Arc<dyn ExecutionPlan>> {
69+
plan.transform_down(|plan| {
70+
let Some(agg_exec) = plan.downcast_ref::<AggregateExec>() else {
71+
return Ok(Transformed::no(plan));
72+
};
73+
74+
if !matches!(
75+
agg_exec.mode(),
76+
AggregateMode::Single
77+
| AggregateMode::SinglePartitioned
78+
| AggregateMode::Partial
79+
) {
80+
return Ok(Transformed::no(plan));
81+
}
82+
83+
// Must have actual aggregate functions (not just GROUP BY for DISTINCT)
84+
let aggr_exprs = agg_exec.aggr_expr();
85+
if aggr_exprs.is_empty() {
86+
return Ok(Transformed::no(plan));
87+
}
88+
89+
// No GROUPING SETS
90+
if agg_exec.group_expr().groups().len() > 1 {
91+
return Ok(Transformed::no(plan));
92+
}
93+
94+
// Find HashJoinExec (possibly through a ProjectionExec)
95+
let input = agg_exec.input();
96+
let hash_join: &HashJoinExec;
97+
if let Some(hj) = input.downcast_ref::<HashJoinExec>() {
98+
hash_join = hj;
99+
} else if let Some(proj) = input.downcast_ref::<ProjectionExec>() {
100+
if let Some(hj) = proj.input().downcast_ref::<HashJoinExec>() {
101+
hash_join = hj;
102+
} else {
103+
return Ok(Transformed::no(plan));
104+
}
105+
} else {
106+
return Ok(Transformed::no(plan));
107+
};
108+
109+
// Inner and Left joins
110+
if !matches!(
111+
hash_join.join_type(),
112+
JoinType::Inner | JoinType::Left
113+
) {
114+
return Ok(Transformed::no(plan));
115+
}
116+
117+
// No residual join filter (equi-join only)
118+
if hash_join.filter().is_some() {
119+
return Ok(Transformed::no(plan));
120+
}
121+
122+
// GROUP BY keys must exactly match left join keys
123+
let group_exprs: Vec<_> = agg_exec
124+
.group_expr()
125+
.expr()
126+
.iter()
127+
.map(|(expr, _)| Arc::clone(expr))
128+
.collect();
129+
130+
let join_on = hash_join.on();
131+
let left_join_keys: Vec<_> =
132+
join_on.iter().map(|(l, _)| Arc::clone(l)).collect();
133+
134+
if group_exprs.len() != left_join_keys.len() {
135+
return Ok(Transformed::no(plan));
136+
}
137+
138+
if !physical_exprs_equal(&group_exprs, &left_join_keys) {
139+
return Ok(Transformed::no(plan));
140+
}
141+
142+
// All aggregates must support GroupsAccumulator
143+
for agg in aggr_exprs {
144+
if !agg.groups_accumulator_supported() {
145+
return Ok(Transformed::no(plan));
146+
}
147+
}
148+
149+
// For Inner joins, skip if any aggregate has a literal argument
150+
// (e.g., COUNT(*) rewritten as count(Int64(1))). These queries
151+
// don't benefit enough from GroupJoin to justify changing the plan.
152+
if *hash_join.join_type() == JoinType::Inner {
153+
let has_literal_arg = aggr_exprs.iter().any(|agg| {
154+
agg.expressions().iter().any(|expr| {
155+
expr.as_ref()
156+
.downcast_ref::<datafusion_physical_expr::expressions::Literal>()
157+
.is_some()
158+
})
159+
});
160+
if has_literal_arg {
161+
return Ok(Transformed::no(plan));
162+
}
163+
}
164+
165+
// All preconditions met — create GroupJoinExec
166+
let group_by_with_names: Vec<_> = agg_exec
167+
.group_expr()
168+
.expr()
169+
.iter()
170+
.map(|(expr, name)| (Arc::clone(expr), name.clone()))
171+
.collect();
172+
173+
let group_join = GroupJoinExec::try_new(
174+
Arc::clone(hash_join.left()),
175+
Arc::clone(hash_join.right()),
176+
join_on.to_vec(),
177+
*hash_join.join_type(),
178+
group_by_with_names,
179+
aggr_exprs.to_vec(),
180+
)?;
181+
182+
Ok(Transformed::yes(
183+
Arc::new(group_join) as Arc<dyn ExecutionPlan>
184+
))
185+
})
186+
.data()
187+
}
188+
189+
fn name(&self) -> &str {
190+
"group_join"
191+
}
192+
193+
fn schema_check(&self) -> bool {
194+
false // Schema changes (aggregate output differs from join output)
195+
}
196+
}

datafusion/physical-optimizer/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub mod enforce_distribution;
3131
pub mod enforce_sorting;
3232
pub mod ensure_coop;
3333
pub mod filter_pushdown;
34+
pub mod group_join;
3435
pub mod join_selection;
3536
pub mod limit_pushdown;
3637
pub mod limit_pushdown_past_window;

datafusion/physical-optimizer/src/optimizer.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use crate::enforce_distribution::EnforceDistribution;
2626
use crate::enforce_sorting::EnforceSorting;
2727
use crate::ensure_coop::EnsureCooperative;
2828
use crate::filter_pushdown::FilterPushdown;
29+
use crate::group_join::GroupJoinOptimizer;
2930
use crate::join_selection::JoinSelection;
3031
use crate::limit_pushdown::LimitPushdown;
3132
use crate::limited_distinct_aggregation::LimitedDistinctAggregation;
@@ -177,6 +178,11 @@ impl PhysicalOptimizer {
177178
Arc::new(EnforceDistribution::new()),
178179
// The CombinePartialFinalAggregate rule should be applied after the EnforceDistribution rule
179180
Arc::new(CombinePartialFinalAggregate::new()),
181+
// GroupJoinOptimizer fuses Aggregate+HashJoin into GroupJoinExec
182+
// when GROUP BY keys match join keys. Runs after
183+
// CombinePartialFinalAggregate (which creates SinglePartitioned
184+
// aggregates) and JoinSelection (which decides build/probe sides).
185+
Arc::new(GroupJoinOptimizer::new()),
180186
// The EnforceSorting rule is for adding essential local sorting to satisfy the required
181187
// ordering. Please make sure that the whole plan tree is determined before this rule.
182188
// Note that one should always run this rule after running the EnforceDistribution rule

0 commit comments

Comments
 (0)