Skip to content

Commit fbeeb86

Browse files
JanKaul and claude committed
fix multi-pair join.on producing edges with stale columns
A single Join.on that mixes equi-pairs spanning different extracted node pairs was cloning the whole on-list into every new edge, so reconstructed joins referenced columns missing from their schemas. The resulting multi-edge structure also formed cycles that IK84 cannot process and the denormalize step panicked on. Group pairs by their (sorted) node-pair, so each edge carries only its own equi-pairs. When a new edge would close a cycle, demote its pairs to the existing side-channel filter list (re-applied as a Filter above the reordered join). When a prior recursive call already added an edge between the same two nodes, extend its on-list instead of adding a parallel edge. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a40a555 commit fbeeb86

2 files changed

Lines changed: 211 additions & 19 deletions

File tree

datafusion/optimizer/src/reorder_join/join_graph.rs

Lines changed: 107 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,60 @@ impl JoinGraph {
219219
pub(crate) fn get_edge(&self, key: EdgeId) -> Option<&Edge> {
220220
self.edges.get(key)
221221
}
222+
223+
/// Returns the id of an edge directly connecting `a` and `b`, if one
224+
/// exists.
225+
fn find_edge_between(&self, a: NodeId, b: NodeId) -> Option<EdgeId> {
226+
let node_a = self.nodes.get(a)?;
227+
node_a.connections.iter().copied().find(|&eid| {
228+
self.edges
229+
.get(eid)
230+
.map(|e| e.nodes.contains(&b))
231+
.unwrap_or(false)
232+
})
233+
}
234+
235+
/// Appends `pairs` to the given edge's `on` list.
236+
fn extend_edge_on(&mut self, edge_id: EdgeId, pairs: Vec<(Expr, Expr)>) {
237+
if let Some(edge) = self.edges.get_mut(edge_id) {
238+
edge.on.extend(pairs);
239+
}
240+
}
241+
242+
/// Returns true if a path already connects `from` to `to`, treating
243+
/// edges as undirected. Used to detect cycles before adding a new
244+
/// edge; if a path exists, the new edge would close a cycle.
245+
fn path_exists(&self, from: NodeId, to: NodeId) -> bool {
246+
use std::collections::HashSet;
247+
if from == to {
248+
return true;
249+
}
250+
let mut visited: HashSet<NodeId> = HashSet::new();
251+
let mut stack: Vec<NodeId> = vec![from];
252+
while let Some(n) = stack.pop() {
253+
if !visited.insert(n) {
254+
continue;
255+
}
256+
if let Some(node) = self.nodes.get(n) {
257+
for &eid in &node.connections {
258+
if let Some(edge) = self.edges.get(eid) {
259+
for &neighbour in &edge.nodes {
260+
if neighbour == n {
261+
continue;
262+
}
263+
if neighbour == to {
264+
return true;
265+
}
266+
if !visited.contains(&neighbour) {
267+
stack.push(neighbour);
268+
}
269+
}
270+
}
271+
}
272+
}
273+
}
274+
false
275+
}
222276
}
223277

224278
/// Extracts the join subtree from a logical plan, separating it from wrapper operators.
@@ -339,18 +393,27 @@ fn flatten_joins_recursive(plan: LogicalPlan, join_graph: &mut JoinGraph) -> Res
339393
join_graph,
340394
)?;
341395

342-
// Process each equijoin predicate to find which nodes it connects
396+
// Group each equi-pair by which two nodes it connects. A
397+
// single `Join.on` can mix pairs that span different node-
398+
// pairs (e.g. an outer join in a bushy plan whose `on`
399+
// contains keys from disjoint sub-trees); putting all pairs
400+
// on every edge produces edges that reference columns
401+
// missing from their endpoints' schemas, and the resulting
402+
// multi-edge structure forms a cycle that IK84 can't
403+
// process.
404+
use std::collections::HashMap;
405+
let mut pairs_by_node_pair: HashMap<(NodeId, NodeId), Vec<(Expr, Expr)>> =
406+
HashMap::new();
407+
let mut insertion_order: Vec<(NodeId, NodeId)> = Vec::new();
408+
343409
for (left_key, right_key) in &join.on {
344-
// Extract column references from both join keys
345410
let left_columns = left_key.column_refs();
346411
let right_columns = right_key.column_refs();
347412

348-
// Filter nodes by checking which ones contain the columns from each expression
349413
let matching_nodes: Vec<NodeId> = join_graph
350414
.nodes()
351415
.filter_map(|(node_id, node)| {
352416
let schema = node.plan.schema();
353-
// Check if this node's schema contains columns from either left or right key
354417
let has_left =
355418
check_all_columns_from_schema(&left_columns, schema.as_ref())
356419
.unwrap_or(false);
@@ -359,8 +422,6 @@ fn flatten_joins_recursive(plan: LogicalPlan, join_graph: &mut JoinGraph) -> Res
359422
schema.as_ref(),
360423
)
361424
.unwrap_or(false);
362-
363-
// Include node if it contains columns from either key (but not both, as that would be invalid)
364425
if (has_left && !has_right) || (!has_left && has_right) {
365426
Some(node_id)
366427
} else {
@@ -369,7 +430,6 @@ fn flatten_joins_recursive(plan: LogicalPlan, join_graph: &mut JoinGraph) -> Res
369430
})
370431
.collect();
371432

372-
// We should have exactly two nodes: one with left_key columns, one with right_key columns
373433
if matching_nodes.len() != 2 {
374434
return plan_err!(
375435
"Could not find exactly two nodes for join predicate: {} = {} (found {} nodes)",
@@ -379,21 +439,49 @@ fn flatten_joins_recursive(plan: LogicalPlan, join_graph: &mut JoinGraph) -> Res
379439
);
380440
}
381441

382-
let node_id_a = matching_nodes[0];
383-
let node_id_b = matching_nodes[1];
442+
let mut endpoints = [matching_nodes[0], matching_nodes[1]];
443+
endpoints.sort();
444+
let key = (endpoints[0], endpoints[1]);
445+
if !pairs_by_node_pair.contains_key(&key) {
446+
insertion_order.push(key);
447+
}
448+
pairs_by_node_pair
449+
.entry(key)
450+
.or_default()
451+
.push((left_key.clone(), right_key.clone()));
452+
}
453+
454+
for (node_a, node_b) in insertion_order {
455+
let pairs = pairs_by_node_pair.remove(&(node_a, node_b)).unwrap();
384456

385-
// Add an edge if one doesn't exist yet
386-
if let Some(node_a) = join_graph.get_node(node_id_a)
387-
&& node_a.connection_with(node_id_b, join_graph).is_none()
457+
// If a prior recursive call already connected these two
458+
// nodes by an edge, merge our pairs into it instead of
459+
// adding a parallel edge.
460+
if let Some(existing_edge_id) =
461+
join_graph.find_edge_between(node_a, node_b)
388462
{
389-
join_graph.add_edge(
390-
node_id_a,
391-
node_id_b,
392-
join.on.clone(),
393-
join.join_type,
394-
join.null_equality,
395-
);
463+
join_graph.extend_edge_on(existing_edge_id, pairs);
464+
continue;
465+
}
466+
467+
// Cycle check: adding this edge would close a cycle.
468+
// IK84 needs a tree, so demote the equi-pairs of this
469+
// group to side-channel filter conjuncts; they'll be
470+
// re-applied as a Filter above the reordered join.
471+
if join_graph.path_exists(node_a, node_b) {
472+
for (l, r) in pairs {
473+
join_graph.add_filter(l.eq(r));
474+
}
475+
continue;
396476
}
477+
478+
join_graph.add_edge(
479+
node_a,
480+
node_b,
481+
pairs,
482+
join.join_type,
483+
join.null_equality,
484+
);
397485
}
398486

399487
Ok(())

datafusion/optimizer/src/reorder_join/left_deep_join_plan.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,4 +760,108 @@ mod tests {
760760

761761
Ok(())
762762
}
763+
764+
/// Walk a plan and verify every `Join.on` column is resolvable in the
765+
/// join's own left or right schema.
766+
fn assert_join_on_columns_resolvable(plan: &LogicalPlan) {
767+
if let LogicalPlan::Join(j) = plan {
768+
let left = j.left.schema();
769+
let right = j.right.schema();
770+
for (lk, rk) in &j.on {
771+
for c in lk.column_refs().into_iter().chain(rk.column_refs()) {
772+
let in_left = left.has_column(c)
773+
|| left.has_column_with_unqualified_name(&c.name);
774+
let in_right = right.has_column(c)
775+
|| right.has_column_with_unqualified_name(&c.name);
776+
assert!(
777+
in_left || in_right,
778+
"Join on-key references `{}` not present in either side.\n\
779+
Left fields: {:?}\n\
780+
Right fields: {:?}\n\
781+
Full plan:\n{}",
782+
c,
783+
left.fields().iter().map(|f| f.name()).collect::<Vec<_>>(),
784+
right.fields().iter().map(|f| f.name()).collect::<Vec<_>>(),
785+
plan.display_indent()
786+
);
787+
}
788+
}
789+
}
790+
for input in plan.inputs() {
791+
assert_join_on_columns_resolvable(input);
792+
}
793+
}
794+
795+
/// Regression test: one `Join` whose multi-pair `on` list spans
/// *different* already-extracted node pairs.
///
/// The outermost join is built with
/// `on = [(c_custkey, l_partkey), (o_orderkey, l_orderkey)]`: the first
/// pair links customer↔lineitem, the second links orders↔lineitem. After
/// reordering, each join may only carry equi-pairs whose columns exist in
/// its own left or right input; copying the full list onto every edge
/// yields joins that reference columns missing from their schemas.
#[test]
fn test_multi_pair_join_keys_split_to_correct_edges() -> Result<()> {
    let customer = scan_tpch_table("customer");
    let orders = scan_tpch_table("orders");
    let lineitem = scan_tpch_table("lineitem");

    // customer ⋈ orders on the customer key.
    let customer_orders = LogicalPlanBuilder::from(customer).join(
        orders,
        JoinType::Inner,
        (vec!["c_custkey"], vec!["o_custkey"]),
        None,
    )?;

    // (customer ⋈ orders) ⋈ lineitem with keys that hit two node pairs.
    let mixed_keys = (
        vec!["c_custkey", "o_orderkey"],
        vec!["l_partkey", "l_orderkey"],
    );
    let plan = customer_orders
        .join(lineitem, JoinType::Inner, mixed_keys, None)?
        .build()?;

    let optimized = optimal_left_deep_join_plan(plan, &TestCostEstimator)?;
    assert_join_on_columns_resolvable(&optimized);
    Ok(())
}
831+
832+
/// Regression test: the outermost `Join` has a multi-pair `on` whose two
/// equi-pairs connect *different* node pairs after extraction.
///
/// Without per-pair edge construction, the customer↔lineitem edge would
/// carry the orders↔lineitem keys (and vice versa), and the three
/// resulting edges would form a triangle (a cycle) that IK84 can't
/// process.
#[test]
fn test_multi_pair_on_creates_cycle_then_resolves() -> Result<()> {
    let customer = scan_tpch_table("customer");
    let orders = scan_tpch_table("orders");
    let lineitem = scan_tpch_table("lineitem");

    let customer_orders = LogicalPlanBuilder::from(customer).join(
        orders,
        JoinType::Inner,
        (vec!["c_custkey"], vec!["o_custkey"]),
        None,
    )?;

    // Outermost join `on = [(c_custkey, l_orderkey), (o_orderkey, l_partkey)]`:
    // one pair is customer↔lineitem, the other is orders↔lineitem.
    let triangle_keys = (
        vec!["c_custkey", "o_orderkey"],
        vec!["l_orderkey", "l_partkey"],
    );
    let plan = customer_orders
        .join(lineitem, JoinType::Inner, triangle_keys, None)?
        .build()?;

    let optimized = optimal_left_deep_join_plan(plan, &TestCostEstimator)?;
    assert_join_on_columns_resolvable(&optimized);
    Ok(())
}
763867
}

0 commit comments

Comments
 (0)