@@ -219,6 +219,60 @@ impl JoinGraph {
219219 pub ( crate ) fn get_edge ( & self , key : EdgeId ) -> Option < & Edge > {
220220 self . edges . get ( key)
221221 }
222+
223+ /// Returns the id of an edge directly connecting `a` and `b`, if one
224+ /// exists.
225+ fn find_edge_between ( & self , a : NodeId , b : NodeId ) -> Option < EdgeId > {
226+ let node_a = self . nodes . get ( a) ?;
227+ node_a. connections . iter ( ) . copied ( ) . find ( |& eid| {
228+ self . edges
229+ . get ( eid)
230+ . map ( |e| e. nodes . contains ( & b) )
231+ . unwrap_or ( false )
232+ } )
233+ }
234+
235+ /// Appends `pairs` to the given edge's `on` list.
236+ fn extend_edge_on ( & mut self , edge_id : EdgeId , pairs : Vec < ( Expr , Expr ) > ) {
237+ if let Some ( edge) = self . edges . get_mut ( edge_id) {
238+ edge. on . extend ( pairs) ;
239+ }
240+ }
241+
242+ /// Returns true if a path already connects `from` to `to`, treating
243+ /// edges as undirected. Used to detect cycles before adding a new
244+ /// edge; if a path exists, the new edge would close a cycle.
245+ fn path_exists ( & self , from : NodeId , to : NodeId ) -> bool {
246+ use std:: collections:: HashSet ;
247+ if from == to {
248+ return true ;
249+ }
250+ let mut visited: HashSet < NodeId > = HashSet :: new ( ) ;
251+ let mut stack: Vec < NodeId > = vec ! [ from] ;
252+ while let Some ( n) = stack. pop ( ) {
253+ if !visited. insert ( n) {
254+ continue ;
255+ }
256+ if let Some ( node) = self . nodes . get ( n) {
257+ for & eid in & node. connections {
258+ if let Some ( edge) = self . edges . get ( eid) {
259+ for & neighbour in & edge. nodes {
260+ if neighbour == n {
261+ continue ;
262+ }
263+ if neighbour == to {
264+ return true ;
265+ }
266+ if !visited. contains ( & neighbour) {
267+ stack. push ( neighbour) ;
268+ }
269+ }
270+ }
271+ }
272+ }
273+ }
274+ false
275+ }
222276}
223277
224278/// Extracts the join subtree from a logical plan, separating it from wrapper operators.
@@ -339,18 +393,27 @@ fn flatten_joins_recursive(plan: LogicalPlan, join_graph: &mut JoinGraph) -> Res
339393 join_graph,
340394 ) ?;
341395
342- // Process each equijoin predicate to find which nodes it connects
396+ // Group each equi-pair by which two nodes it connects. A
397+ // single `Join.on` can mix pairs that span different node-
398+ // pairs (e.g. an outer join in a bushy plan whose `on`
399+ // contains keys from disjoint sub-trees); putting all pairs
400+ // on every edge produces edges that reference columns
401+ // missing from their endpoints' schemas, and the resulting
402+ // multi-edge structure forms a cycle that IK84 can't
403+ // process.
404+ use std:: collections:: HashMap ;
405+ let mut pairs_by_node_pair: HashMap < ( NodeId , NodeId ) , Vec < ( Expr , Expr ) > > =
406+ HashMap :: new ( ) ;
407+ let mut insertion_order: Vec < ( NodeId , NodeId ) > = Vec :: new ( ) ;
408+
343409 for ( left_key, right_key) in & join. on {
344- // Extract column references from both join keys
345410 let left_columns = left_key. column_refs ( ) ;
346411 let right_columns = right_key. column_refs ( ) ;
347412
348- // Filter nodes by checking which ones contain the columns from each expression
349413 let matching_nodes: Vec < NodeId > = join_graph
350414 . nodes ( )
351415 . filter_map ( |( node_id, node) | {
352416 let schema = node. plan . schema ( ) ;
353- // Check if this node's schema contains columns from either left or right key
354417 let has_left =
355418 check_all_columns_from_schema ( & left_columns, schema. as_ref ( ) )
356419 . unwrap_or ( false ) ;
@@ -359,8 +422,6 @@ fn flatten_joins_recursive(plan: LogicalPlan, join_graph: &mut JoinGraph) -> Res
359422 schema. as_ref ( ) ,
360423 )
361424 . unwrap_or ( false ) ;
362-
363- // Include node if it contains columns from either key (but not both, as that would be invalid)
364425 if ( has_left && !has_right) || ( !has_left && has_right) {
365426 Some ( node_id)
366427 } else {
@@ -369,7 +430,6 @@ fn flatten_joins_recursive(plan: LogicalPlan, join_graph: &mut JoinGraph) -> Res
369430 } )
370431 . collect ( ) ;
371432
372- // We should have exactly two nodes: one with left_key columns, one with right_key columns
373433 if matching_nodes. len ( ) != 2 {
374434 return plan_err ! (
375435 "Could not find exactly two nodes for join predicate: {} = {} (found {} nodes)" ,
@@ -379,21 +439,49 @@ fn flatten_joins_recursive(plan: LogicalPlan, join_graph: &mut JoinGraph) -> Res
379439 ) ;
380440 }
381441
382- let node_id_a = matching_nodes[ 0 ] ;
383- let node_id_b = matching_nodes[ 1 ] ;
442+ let mut endpoints = [ matching_nodes[ 0 ] , matching_nodes[ 1 ] ] ;
443+ endpoints. sort ( ) ;
444+ let key = ( endpoints[ 0 ] , endpoints[ 1 ] ) ;
445+ if !pairs_by_node_pair. contains_key ( & key) {
446+ insertion_order. push ( key) ;
447+ }
448+ pairs_by_node_pair
449+ . entry ( key)
450+ . or_default ( )
451+ . push ( ( left_key. clone ( ) , right_key. clone ( ) ) ) ;
452+ }
453+
454+ for ( node_a, node_b) in insertion_order {
455+ let pairs = pairs_by_node_pair. remove ( & ( node_a, node_b) ) . unwrap ( ) ;
384456
385- // Add an edge if one doesn't exist yet
386- if let Some ( node_a) = join_graph. get_node ( node_id_a)
387- && node_a. connection_with ( node_id_b, join_graph) . is_none ( )
457+ // If a prior recursive call already connected these two
458+ // nodes by an edge, merge our pairs into it instead of
459+ // adding a parallel edge.
460+ if let Some ( existing_edge_id) =
461+ join_graph. find_edge_between ( node_a, node_b)
388462 {
389- join_graph. add_edge (
390- node_id_a,
391- node_id_b,
392- join. on . clone ( ) ,
393- join. join_type ,
394- join. null_equality ,
395- ) ;
463+ join_graph. extend_edge_on ( existing_edge_id, pairs) ;
464+ continue ;
465+ }
466+
467+ // Cycle check: adding this edge would close a cycle.
468+ // IK84 needs a tree, so demote the equi-pairs of this
469+ // group to side-channel filter conjuncts; they'll be
470+ // re-applied as a Filter above the reordered join.
471+ if join_graph. path_exists ( node_a, node_b) {
472+ for ( l, r) in pairs {
473+ join_graph. add_filter ( l. eq ( r) ) ;
474+ }
475+ continue ;
396476 }
477+
478+ join_graph. add_edge (
479+ node_a,
480+ node_b,
481+ pairs,
482+ join. join_type ,
483+ join. null_equality ,
484+ ) ;
397485 }
398486
399487 Ok ( ( ) )
0 commit comments