Skip to content

Commit 6d2840a

Browse files
stefankandic authored and cloud-fan committed
[SPARK-56333][SQL] Use multiset intersection for NATURAL JOIN column matching
### What changes were proposed in this pull request?

Replace `distinct.filter(resolver)` with a multiset intersection (`canonicalizedIntersect`) when computing common columns for NATURAL JOIN. This preserves duplicate column multiplicity (each name appears `min(left count, right count)` times) instead of deduplicating.

### Why are the changes needed?

`distinct` drops duplicate column names, so `NATURAL JOIN` between relations with repeated column names (e.g., c1, c1, c2) silently loses join conditions. The multiset approach matches `Seq.intersect` semantics while still respecting `spark.sql.caseSensitive`. With this change, `NATURAL JOIN` behaves the same way as it did before #54858, except that it now respects the `spark.sql.caseSensitive` conf.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New golden file tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #55163 from stefankandic/unionMultiplicity.

Authored-by: Stefan Kandic <stefan.kandic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 66a5a89 commit 6d2840a

6 files changed

Lines changed: 80 additions & 4 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3632,9 +3632,8 @@ class Analyzer(
36323632
case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
36333633
if j.resolvedExceptNatural =>
36343634
// find common column names from both sides
3635-
val joinNames = left.output.map(_.name).distinct.filter { leftName =>
3636-
right.output.map(_.name).exists(resolver(leftName, _))
3637-
}
3635+
val joinNames = NaturalAndUsingJoinResolution.canonicalizedIntersect(
3636+
left.output.map(_.name), right.output.map(_.name))
36383637
val project = commonNaturalJoinProcessing(
36393638
left, right, joinType, joinNames, condition, hint)
36403639
j.getTagValue(LogicalPlan.PLAN_ID_TAG)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NaturalAndUsingJoinResolution.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20+
import scala.collection.mutable
21+
2022
import org.apache.spark.sql.catalyst.SQLConfHelper
2123
import org.apache.spark.sql.catalyst.expressions.{
2224
Alias,
@@ -82,6 +84,30 @@ object NaturalAndUsingJoinResolution extends DataTypeErrorsBase with SQLConfHelp
8284
(output, hiddenOutput, newCondition)
8385
}
8486

87+
/**
88+
* Computes a multiset intersection of two name sequences, respecting case sensitivity.
89+
* Preserves multiplicity: each name appears min(left count, right count) times.
90+
*/
91+
def canonicalizedIntersect(
92+
leftNames: Seq[String],
93+
rightNames: Seq[String]): Seq[String] = {
94+
val rightNameCounts = mutable.HashMap[String, Int]()
95+
for (name <- rightNames) {
96+
val key = conf.canonicalize(name)
97+
rightNameCounts(key) = rightNameCounts.getOrElse(key, 0) + 1
98+
}
99+
leftNames.filter { leftName =>
100+
val key = conf.canonicalize(leftName)
101+
val count = rightNameCounts.getOrElse(key, 0)
102+
if (count > 0) {
103+
rightNameCounts(key) = count - 1
104+
true
105+
} else {
106+
false
107+
}
108+
}
109+
}
110+
85111
/**
86112
* Returns resolved keys for joining based on the output of [[Join]]'s children or throws an
87113
* error if a key name doesn't exist.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,10 @@ class JoinResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
370370
private def getJoinNamesForNaturalJoin(
371371
leftNameScope: NameScope,
372372
rightNameScope: NameScope): Seq[String] = {
373-
leftNameScope.output.map(_.name).intersect(rightNameScope.output.map(_.name))
373+
NaturalAndUsingJoinResolution.canonicalizedIntersect(
374+
leftNameScope.output.map(_.name),
375+
rightNameScope.output.map(_.name)
376+
)
374377
}
375378

376379
/**

sql/core/src/test/resources/sql-tests/analyzer-results/natural-join.sql.out

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,31 @@ Project [ID#x]
692692
+- SubqueryAlias t2
693693
+- Project [1 AS id#x]
694694
+- OneRowRelation
695+
696+
697+
-- !query
698+
SELECT * FROM (SELECT 1 AS c1, 2 AS c1, 3 AS c2) NATURAL JOIN (SELECT 1 AS c1, 2 AS c1, 3 AS c2)
699+
-- !query analysis
700+
Project [c1#x, c1#x, c2#x, c1#x, c1#x]
701+
+- Project [c1#x, c1#x, c2#x, c1#x, c1#x]
702+
+- Join Inner, (((c1#x = c1#x) AND (c1#x = c1#x)) AND (c2#x = c2#x))
703+
:- SubqueryAlias __auto_generated_subquery_name
704+
: +- Project [1 AS c1#x, 2 AS c1#x, 3 AS c2#x]
705+
: +- OneRowRelation
706+
+- SubqueryAlias __auto_generated_subquery_name
707+
+- Project [1 AS c1#x, 2 AS c1#x, 3 AS c2#x]
708+
+- OneRowRelation
709+
710+
711+
-- !query
712+
SELECT * FROM (SELECT 1 AS c1, 2 AS c1, 3 AS c2) NATURAL JOIN (SELECT 1 AS c1, 3 AS c2)
713+
-- !query analysis
714+
Project [c1#x, c2#x, c1#x]
715+
+- Project [c1#x, c2#x, c1#x]
716+
+- Join Inner, ((c1#x = c1#x) AND (c2#x = c2#x))
717+
:- SubqueryAlias __auto_generated_subquery_name
718+
: +- Project [1 AS c1#x, 2 AS c1#x, 3 AS c2#x]
719+
: +- OneRowRelation
720+
+- SubqueryAlias __auto_generated_subquery_name
721+
+- Project [1 AS c1#x, 3 AS c2#x]
722+
+- OneRowRelation

sql/core/src/test/resources/sql-tests/inputs/natural-join.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,7 @@ SELECT * FROM nt1 natural join nt2 join nt3 on nt2.k = nt3.k;
7777
SELECT nt1.*, nt2.*, nt3.*, nt4.* FROM nt1 natural join nt2 natural join nt3 natural join nt4;
7878

7979
SELECT * FROM (SELECT 1 AS ID) t1 NATURAL JOIN (SELECT 1 AS id) t2;
80+
81+
SELECT * FROM (SELECT 1 AS c1, 2 AS c1, 3 AS c2) NATURAL JOIN (SELECT 1 AS c1, 2 AS c1, 3 AS c2);
82+
83+
SELECT * FROM (SELECT 1 AS c1, 2 AS c1, 3 AS c2) NATURAL JOIN (SELECT 1 AS c1, 3 AS c2);

sql/core/src/test/resources/sql-tests/results/natural-join.sql.out

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,3 +348,19 @@ SELECT * FROM (SELECT 1 AS ID) t1 NATURAL JOIN (SELECT 1 AS id) t2
348348
struct<ID:int>
349349
-- !query output
350350
1
351+
352+
353+
-- !query
354+
SELECT * FROM (SELECT 1 AS c1, 2 AS c1, 3 AS c2) NATURAL JOIN (SELECT 1 AS c1, 2 AS c1, 3 AS c2)
355+
-- !query schema
356+
struct<c1:int,c1:int,c2:int,c1:int,c1:int>
357+
-- !query output
358+
1 1 3 2 2
359+
360+
361+
-- !query
362+
SELECT * FROM (SELECT 1 AS c1, 2 AS c1, 3 AS c2) NATURAL JOIN (SELECT 1 AS c1, 3 AS c2)
363+
-- !query schema
364+
struct<c1:int,c2:int,c1:int>
365+
-- !query output
366+
1 3 2

0 commit comments

Comments
 (0)