Skip to content

Commit 46bcc10

Browse files
xiedeyantuCopilot
andcommitted
Addressed
Co-authored-by: Copilot <copilot@github.com>
1 parent 4cf91eb commit 46bcc10

1 file changed

Lines changed: 37 additions & 33 deletions

File tree

datafusion/optimizer/src/unions_to_filter.rs

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ use datafusion_expr::{
2828
Distinct, Expr, Filter, LogicalPlan, Projection, SubqueryAlias, Union,
2929
};
3030
use log::debug;
31-
use std::collections::HashMap;
3231
use std::sync::Arc;
3332

3433
#[derive(Default, Debug)]
@@ -76,17 +75,18 @@ struct UnionsToFilterRewriter;
7675
impl TreeNodeRewriter for UnionsToFilterRewriter {
7776
type Node = LogicalPlan;
7877

79-
fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
80-
match &plan {
81-
LogicalPlan::Distinct(Distinct::All(input)) => {
82-
match try_rewrite_distinct_union(input.as_ref().clone())? {
83-
Some(rewritten) => Ok(Transformed::yes(rewritten)),
84-
None => Ok(Transformed::no(plan)),
85-
}
86-
}
87-
_ => Ok(Transformed::no(plan)),
88-
}
89-
}
78+
fn f_up(&mut self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
79+
match &plan {
80+
LogicalPlan::Distinct(Distinct::All(input)) => {
81+
match try_rewrite_distinct_union(input.as_ref().clone())? {
82+
Some(rewritten) => Ok(Transformed::yes(rewritten)),
83+
None => Ok(Transformed::no(plan)),
84+
}
85+
}
86+
_ => Ok(Transformed::no(plan)),
87+
}
88+
}
89+
}
9090

9191
fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>> {
9292
let LogicalPlan::Union(Union { inputs, schema }) = plan else {
@@ -102,8 +102,10 @@ fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>>
102102
return Ok(None);
103103
}
104104

105-
let mut grouped: HashMap<GroupKey, Vec<Expr>> = HashMap::new();
106-
let mut input_order: Vec<GroupKey> = Vec::new();
105+
// Use a Vec instead of HashMap: union branches are typically 2-10 entries,
106+
// so a linear scan with PartialEq is faster than recursively hashing entire
107+
// LogicalPlan subtrees (O(N * tree_size) hashing for every insert/lookup).
108+
let mut grouped: Vec<(GroupKey, Vec<Expr>)> = Vec::new();
107109
let mut transformed = false;
108110

109111
for input in inputs {
@@ -115,12 +117,11 @@ fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>>
115117
source: branch.source,
116118
wrappers: branch.wrappers,
117119
};
118-
if let Some(conds) = grouped.get_mut(&key) {
120+
if let Some((_, conds)) = grouped.iter_mut().find(|(k, _)| k == &key) {
119121
conds.push(branch.predicate);
120122
transformed = true;
121123
} else {
122-
input_order.push(key.clone());
123-
grouped.insert(key, vec![branch.predicate]);
124+
grouped.push((key, vec![branch.predicate]));
124125
}
125126
}
126127

@@ -130,10 +131,7 @@ fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>>
130131
}
131132

132133
let mut builder: Option<LogicalPlanBuilder> = None;
133-
for key in input_order {
134-
let predicates = grouped
135-
.remove(&key)
136-
.expect("grouped predicates should exist for every source");
134+
for (key, predicates) in grouped {
137135
let combined =
138136
disjunction(predicates).expect("union branches always provide predicates");
139137
let branch = LogicalPlanBuilder::from(key.source)
@@ -203,7 +201,7 @@ fn extract_branch(plan: LogicalPlan) -> Result<Option<UnionBranch>> {
203201
Ok(None)
204202
}
205203
other => Ok(Some(UnionBranch {
206-
source: strip_passthrough_nodes(other.clone()),
204+
source: strip_passthrough_nodes(other),
207205
predicate: Expr::Literal(
208206
datafusion_common::ScalarValue::Boolean(Some(true)),
209207
None,
@@ -213,13 +211,13 @@ fn extract_branch(plan: LogicalPlan) -> Result<Option<UnionBranch>> {
213211
}
214212
}
215213

216-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
214+
#[derive(Debug, Clone, PartialEq, Eq)]
217215
struct GroupKey {
218216
source: LogicalPlan,
219217
wrappers: Vec<Wrapper>,
220218
}
221219

222-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
220+
#[derive(Debug, Clone, PartialEq, Eq)]
223221
enum Wrapper {
224222
Projection {
225223
expr: Vec<Expr>,
@@ -268,6 +266,10 @@ fn wrap_branch(mut plan: LogicalPlan, wrappers: &[Wrapper]) -> Result<LogicalPla
268266
Arc::clone(schema),
269267
)?)
270268
}
269+
// SubqueryAlias::try_new recomputes the schema from the new input.
270+
// This is safe because the source table is unchanged; only the
271+
// filter predicate differs, so the recomputed schema matches the
272+
// original one stored in peel_wrappers.
271273
Wrapper::SubqueryAlias { alias, .. } => LogicalPlan::SubqueryAlias(
272274
SubqueryAlias::try_new(Arc::new(plan), alias.clone())?,
273275
),
@@ -276,15 +278,17 @@ fn wrap_branch(mut plan: LogicalPlan, wrappers: &[Wrapper]) -> Result<LogicalPla
276278
Ok(plan)
277279
}
278280

279-
fn strip_passthrough_nodes(plan: LogicalPlan) -> LogicalPlan {
280-
match plan {
281-
LogicalPlan::Projection(Projection { input, .. }) => {
282-
strip_passthrough_nodes(Arc::unwrap_or_clone(input))
283-
}
284-
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
285-
strip_passthrough_nodes(Arc::unwrap_or_clone(input))
286-
}
287-
other => other,
281+
fn strip_passthrough_nodes(mut plan: LogicalPlan) -> LogicalPlan {
282+
loop {
283+
plan = match plan {
284+
LogicalPlan::Projection(Projection { input, .. }) => {
285+
Arc::unwrap_or_clone(input)
286+
}
287+
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
288+
Arc::unwrap_or_clone(input)
289+
}
290+
other => return other,
291+
};
288292
}
289293
}
290294

0 commit comments

Comments
 (0)