@@ -28,7 +28,6 @@ use datafusion_expr::{
2828 Distinct , Expr , Filter , LogicalPlan , Projection , SubqueryAlias , Union ,
2929} ;
3030use log:: debug;
31- use std:: collections:: HashMap ;
3231use std:: sync:: Arc ;
3332
3433#[ derive( Default , Debug ) ]
@@ -76,17 +75,18 @@ struct UnionsToFilterRewriter;
7675impl TreeNodeRewriter for UnionsToFilterRewriter {
7776 type Node = LogicalPlan ;
7877
79- fn f_up ( & mut self , plan : LogicalPlan ) -> Result < Transformed < LogicalPlan > > {
80- match & plan {
81- LogicalPlan :: Distinct ( Distinct :: All ( input) ) => {
82- match try_rewrite_distinct_union ( input. as_ref ( ) . clone ( ) ) ? {
83- Some ( rewritten) => Ok ( Transformed :: yes ( rewritten) ) ,
84- None => Ok ( Transformed :: no ( plan) ) ,
85- }
86- }
87- _ => Ok ( Transformed :: no ( plan) ) ,
88- }
89- }
78+ fn f_up ( & mut self , plan : LogicalPlan ) -> Result < Transformed < LogicalPlan > > {
79+ match & plan {
80+ LogicalPlan :: Distinct ( Distinct :: All ( input) ) => {
81+ match try_rewrite_distinct_union ( input. as_ref ( ) . clone ( ) ) ? {
82+ Some ( rewritten) => Ok ( Transformed :: yes ( rewritten) ) ,
83+ None => Ok ( Transformed :: no ( plan) ) ,
84+ }
85+ }
86+ _ => Ok ( Transformed :: no ( plan) ) ,
87+ }
88+ }
89+ }
9090
9191fn try_rewrite_distinct_union ( plan : LogicalPlan ) -> Result < Option < LogicalPlan > > {
9292 let LogicalPlan :: Union ( Union { inputs, schema } ) = plan else {
@@ -102,8 +102,10 @@ fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>>
102102 return Ok ( None ) ;
103103 }
104104
105- let mut grouped: HashMap < GroupKey , Vec < Expr > > = HashMap :: new ( ) ;
106- let mut input_order: Vec < GroupKey > = Vec :: new ( ) ;
105+ // Use a Vec instead of HashMap: union branches are typically 2-10 entries,
106+ // so a linear scan with PartialEq is faster than recursively hashing entire
107+ // LogicalPlan subtrees (O(N * tree_size) hashing for every insert/lookup).
108+ let mut grouped: Vec < ( GroupKey , Vec < Expr > ) > = Vec :: new ( ) ;
107109 let mut transformed = false ;
108110
109111 for input in inputs {
@@ -115,12 +117,11 @@ fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>>
115117 source : branch. source ,
116118 wrappers : branch. wrappers ,
117119 } ;
118- if let Some ( conds) = grouped. get_mut ( & key) {
120+ if let Some ( ( _ , conds) ) = grouped. iter_mut ( ) . find ( | ( k , _ ) | k == & key) {
119121 conds. push ( branch. predicate ) ;
120122 transformed = true ;
121123 } else {
122- input_order. push ( key. clone ( ) ) ;
123- grouped. insert ( key, vec ! [ branch. predicate] ) ;
124+ grouped. push ( ( key, vec ! [ branch. predicate] ) ) ;
124125 }
125126 }
126127
@@ -130,10 +131,7 @@ fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result<Option<LogicalPlan>>
130131 }
131132
132133 let mut builder: Option < LogicalPlanBuilder > = None ;
133- for key in input_order {
134- let predicates = grouped
135- . remove ( & key)
136- . expect ( "grouped predicates should exist for every source" ) ;
134+ for ( key, predicates) in grouped {
137135 let combined =
138136 disjunction ( predicates) . expect ( "union branches always provide predicates" ) ;
139137 let branch = LogicalPlanBuilder :: from ( key. source )
@@ -203,7 +201,7 @@ fn extract_branch(plan: LogicalPlan) -> Result<Option<UnionBranch>> {
203201 Ok ( None )
204202 }
205203 other => Ok ( Some ( UnionBranch {
206- source : strip_passthrough_nodes ( other. clone ( ) ) ,
204+ source : strip_passthrough_nodes ( other) ,
207205 predicate : Expr :: Literal (
208206 datafusion_common:: ScalarValue :: Boolean ( Some ( true ) ) ,
209207 None ,
@@ -213,13 +211,13 @@ fn extract_branch(plan: LogicalPlan) -> Result<Option<UnionBranch>> {
213211 }
214212}
215213
216- #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
214+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
217215struct GroupKey {
218216 source : LogicalPlan ,
219217 wrappers : Vec < Wrapper > ,
220218}
221219
222- #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
220+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
223221enum Wrapper {
224222 Projection {
225223 expr : Vec < Expr > ,
@@ -268,6 +266,10 @@ fn wrap_branch(mut plan: LogicalPlan, wrappers: &[Wrapper]) -> Result<LogicalPla
268266 Arc :: clone ( schema) ,
269267 ) ?)
270268 }
269+ // SubqueryAlias::try_new recomputes the schema from the new input.
270+ // This is safe because the source table is unchanged; only the
271+ // filter predicate differs, so the recomputed schema matches the
272+ // original one stored in peel_wrappers.
271273 Wrapper :: SubqueryAlias { alias, .. } => LogicalPlan :: SubqueryAlias (
272274 SubqueryAlias :: try_new ( Arc :: new ( plan) , alias. clone ( ) ) ?,
273275 ) ,
@@ -276,15 +278,17 @@ fn wrap_branch(mut plan: LogicalPlan, wrappers: &[Wrapper]) -> Result<LogicalPla
276278 Ok ( plan)
277279}
278280
279- fn strip_passthrough_nodes ( plan : LogicalPlan ) -> LogicalPlan {
280- match plan {
281- LogicalPlan :: Projection ( Projection { input, .. } ) => {
282- strip_passthrough_nodes ( Arc :: unwrap_or_clone ( input) )
283- }
284- LogicalPlan :: SubqueryAlias ( SubqueryAlias { input, .. } ) => {
285- strip_passthrough_nodes ( Arc :: unwrap_or_clone ( input) )
286- }
287- other => other,
281+ fn strip_passthrough_nodes ( mut plan : LogicalPlan ) -> LogicalPlan {
282+ loop {
283+ plan = match plan {
284+ LogicalPlan :: Projection ( Projection { input, .. } ) => {
285+ Arc :: unwrap_or_clone ( input)
286+ }
287+ LogicalPlan :: SubqueryAlias ( SubqueryAlias { input, .. } ) => {
288+ Arc :: unwrap_or_clone ( input)
289+ }
290+ other => return other,
291+ } ;
288292 }
289293}
290294
0 commit comments