Skip to content

Commit b5b86b1

Browse files
authored
reject non-default collated string join keys in Comet hash join and sort-merge join (#4095)
1 parent 9a20f29 commit b5b86b1

2 files changed

Lines changed: 197 additions & 3 deletions

File tree

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,6 +1716,12 @@ trait CometHashJoin {
17161716
return None
17171717
}
17181718

1719+
val joinKeys = join.leftKeys ++ join.rightKeys
1720+
if (joinKeys.exists(key => isStringCollationType(key.dataType))) {
1721+
withInfo(join, "unsupported non-default collated string join keys")
1722+
return None
1723+
}
1724+
17191725
val condition = join.condition.map { cond =>
17201726
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
17211727
if (condProto.isEmpty) {
@@ -1757,7 +1763,7 @@ trait CometHashJoin {
17571763
condition.foreach(joinBuilder.setCondition)
17581764
Some(builder.setHashJoin(joinBuilder).build())
17591765
} else {
1760-
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
1766+
val allExprs: Seq[Expression] = joinKeys
17611767
withInfo(join, allExprs: _*)
17621768
None
17631769
}
@@ -2078,8 +2084,14 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
20782084
}
20792085
}
20802086

2087+
val joinKeys = join.leftKeys ++ join.rightKeys
2088+
if (joinKeys.exists(key => isStringCollationType(key.dataType))) {
2089+
withInfo(join, "unsupported non-default collated string join keys")
2090+
return None
2091+
}
2092+
20812093
// Checks if the join keys are supported by DataFusion SortMergeJoin.
2082-
val errorMsgs = join.leftKeys.flatMap { key =>
2094+
val errorMsgs = joinKeys.flatMap { key =>
20832095
if (!supportedSortMergeJoinEqualType(key.dataType)) {
20842096
Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
20852097
} else {
@@ -2111,7 +2123,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
21112123
condition.map(joinBuilder.setCondition)
21122124
Some(builder.setSortMergeJoin(joinBuilder).build())
21132125
} else {
2114-
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
2126+
val allExprs: Seq[Expression] = joinKeys
21152127
withInfo(join, allExprs: _*)
21162128
None
21172129
}
@@ -2136,6 +2148,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
21362148
* Returns true if given datatype is supported as a key in DataFusion sort merge join.
21372149
*/
21382150
private def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
2151+
case st: StringType if isStringCollationType(st) => false
21392152
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
21402153
_: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
21412154
true

spark/src/test/spark-4.0/org/apache/spark/sql/CometCollationSuite.scala

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,18 @@
1919

2020
package org.apache.spark.sql
2121

22+
import org.apache.spark.sql.catalyst.expressions.AttributeReference
23+
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
24+
import org.apache.spark.sql.catalyst.plans.Inner
25+
import org.apache.spark.sql.comet.{CometBroadcastHashJoinExec, CometHashJoinExec, CometSortMergeJoinExec}
26+
import org.apache.spark.sql.execution.LocalTableScanExec
27+
import org.apache.spark.sql.execution.SparkPlan
28+
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
29+
import org.apache.spark.sql.types.StringType
30+
31+
import org.apache.comet.{CometConf, CometExplainInfo}
32+
import org.apache.comet.serde.OperatorOuterClass
33+
2234
class CometCollationSuite extends CometTestBase {
2335

2436
// Queries that group, sort, or shuffle on a non-default collated string must fall back to
@@ -29,6 +41,8 @@ class CometCollationSuite extends CometTestBase {
2941
"unsupported hash partitioning data type for columnar shuffle"
3042
private val rangeShuffleCollationReason =
3143
"unsupported range partitioning data type for columnar shuffle"
44+
private val joinKeyCollationReason =
45+
"unsupported non-default collated string join keys"
3246

3347
test("listagg DISTINCT with utf8_lcase collation (issue #1947)") {
3448
checkSparkAnswerAndFallbackReason(
@@ -66,4 +80,171 @@ class CometCollationSuite extends CometTestBase {
6680
checkSparkAnswerAndOperator("SELECT DISTINCT _1 FROM tbl ORDER BY _1")
6781
}
6882
}
83+
84+
// ---- Join collation guards (issue #4051) ----------------------------------------
85+
//
86+
// Comet's native join compares keys byte-by-byte, so 'a' and 'A' would not match
87+
// under utf8_lcase, producing wrong results. The converters must reject any join
88+
// whose keys carry a non-default collation.
89+
//
90+
// End-to-end SQL cannot reach the join converter today: higher-level guards
91+
// (CometScanRule, Collate-expression serialization, #4035 shuffle guard) short-circuit
92+
// first. The tests below bypass those guards by constructing physical-plan operators
93+
// directly and calling convert() — the contract is that convert() returns None for
94+
// collated keys.
95+
96+
private def collatedKey(name: String): AttributeReference =
97+
AttributeReference(name, StringType("UTF8_LCASE"), nullable = false)()
98+
99+
private def placeholderChildOp(): OperatorOuterClass.Operator =
100+
OperatorOuterClass.Operator.newBuilder().build()
101+
102+
// Ensure converters are on so that None from convert() means the collation guard fired,
103+
// not that the join type is disabled.
104+
private def withJoinConvertersEnabled(f: => Unit): Unit =
105+
withSQLConf(
106+
CometConf.COMET_EXEC_HASH_JOIN_ENABLED.key -> "true",
107+
CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.key -> "true",
108+
CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.key -> "true") {
109+
f
110+
}
111+
112+
private def assertFallbackReason(plan: SparkPlan, expectedReason: String): Unit = {
113+
val reasons = plan.getTagValue(CometExplainInfo.EXTENSION_INFO).getOrElse(Set.empty[String])
114+
assert(
115+
reasons.contains(expectedReason),
116+
s"Expected fallback reason '$expectedReason' on ${plan.nodeName}, got: $reasons")
117+
}
118+
119+
test("CometBroadcastHashJoinExec rejects non-default collated join keys") {
120+
withJoinConvertersEnabled {
121+
val left = collatedKey("l")
122+
val right = collatedKey("r")
123+
val join = BroadcastHashJoinExec(
124+
leftKeys = Seq(left),
125+
rightKeys = Seq(right),
126+
joinType = Inner,
127+
buildSide = BuildRight,
128+
condition = None,
129+
left = LocalTableScanExec(Seq(left), Nil, None),
130+
right = LocalTableScanExec(Seq(right), Nil, None))
131+
132+
val builder = OperatorOuterClass.Operator.newBuilder()
133+
val result =
134+
CometBroadcastHashJoinExec.convert(
135+
join,
136+
builder,
137+
placeholderChildOp(),
138+
placeholderChildOp())
139+
140+
assert(
141+
result.isEmpty,
142+
"CometBroadcastHashJoinExec.convert must reject non-default collated join keys " +
143+
"(issue #4051): native byte equality cannot match values that compare equal " +
144+
"under utf8_lcase. Got a non-empty proto: " + result)
145+
assertFallbackReason(join, joinKeyCollationReason)
146+
}
147+
}
148+
149+
test("CometHashJoinExec rejects non-default collated join keys") {
150+
withJoinConvertersEnabled {
151+
val left = collatedKey("l")
152+
val right = collatedKey("r")
153+
val join = ShuffledHashJoinExec(
154+
leftKeys = Seq(left),
155+
rightKeys = Seq(right),
156+
joinType = Inner,
157+
buildSide = BuildLeft,
158+
condition = None,
159+
left = LocalTableScanExec(Seq(left), Nil, None),
160+
right = LocalTableScanExec(Seq(right), Nil, None))
161+
162+
val builder = OperatorOuterClass.Operator.newBuilder()
163+
val result =
164+
CometHashJoinExec.convert(join, builder, placeholderChildOp(), placeholderChildOp())
165+
166+
assert(
167+
result.isEmpty,
168+
"CometHashJoinExec.convert must reject non-default collated join keys (issue " +
169+
"#4051): native byte equality cannot match values that compare equal under " +
170+
"utf8_lcase. Got a non-empty proto: " + result)
171+
assertFallbackReason(join, joinKeyCollationReason)
172+
}
173+
}
174+
175+
test("CometBroadcastHashJoinExec still accepts default UTF8_BINARY string keys") {
176+
withJoinConvertersEnabled {
177+
val left = AttributeReference("l", StringType, nullable = false)()
178+
val right = AttributeReference("r", StringType, nullable = false)()
179+
val join = BroadcastHashJoinExec(
180+
leftKeys = Seq(left),
181+
rightKeys = Seq(right),
182+
joinType = Inner,
183+
buildSide = BuildRight,
184+
condition = None,
185+
left = LocalTableScanExec(Seq(left), Nil, None),
186+
right = LocalTableScanExec(Seq(right), Nil, None))
187+
188+
val builder = OperatorOuterClass.Operator.newBuilder()
189+
val result =
190+
CometBroadcastHashJoinExec.convert(
191+
join,
192+
builder,
193+
placeholderChildOp(),
194+
placeholderChildOp())
195+
196+
assert(
197+
result.isDefined,
198+
"CometBroadcastHashJoinExec.convert must continue to accept default UTF8_BINARY " +
199+
"string keys; the collation guard for #4051 must not over-block.")
200+
}
201+
}
202+
203+
test("CometSortMergeJoinExec rejects non-default collated join keys") {
204+
withJoinConvertersEnabled {
205+
val left = collatedKey("l")
206+
val right = collatedKey("r")
207+
val join = SortMergeJoinExec(
208+
leftKeys = Seq(left),
209+
rightKeys = Seq(right),
210+
joinType = Inner,
211+
condition = None,
212+
left = LocalTableScanExec(Seq(left), Nil, None),
213+
right = LocalTableScanExec(Seq(right), Nil, None))
214+
215+
val builder = OperatorOuterClass.Operator.newBuilder()
216+
val result =
217+
CometSortMergeJoinExec.convert(join, builder, placeholderChildOp(), placeholderChildOp())
218+
219+
assert(
220+
result.isEmpty,
221+
"CometSortMergeJoinExec.convert must reject non-default collated join keys " +
222+
"(issue #4051): supportedSortMergeJoinEqualType must check collation. Got a " +
223+
"non-empty proto: " + result)
224+
assertFallbackReason(join, joinKeyCollationReason)
225+
}
226+
}
227+
228+
test("CometSortMergeJoinExec still accepts default UTF8_BINARY string keys") {
229+
withJoinConvertersEnabled {
230+
val left = AttributeReference("l", StringType, nullable = false)()
231+
val right = AttributeReference("r", StringType, nullable = false)()
232+
val join = SortMergeJoinExec(
233+
leftKeys = Seq(left),
234+
rightKeys = Seq(right),
235+
joinType = Inner,
236+
condition = None,
237+
left = LocalTableScanExec(Seq(left), Nil, None),
238+
right = LocalTableScanExec(Seq(right), Nil, None))
239+
240+
val builder = OperatorOuterClass.Operator.newBuilder()
241+
val result =
242+
CometSortMergeJoinExec.convert(join, builder, placeholderChildOp(), placeholderChildOp())
243+
244+
assert(
245+
result.isDefined,
246+
"CometSortMergeJoinExec.convert must continue to accept default UTF8_BINARY " +
247+
"string keys; the collation guard for #4051 must not over-block.")
248+
}
249+
}
69250
}

0 commit comments

Comments
 (0)