1919
2020package org .apache .spark .sql
2121
22+ import org .apache .spark .sql .catalyst .expressions .AttributeReference
23+ import org .apache .spark .sql .catalyst .optimizer .{BuildLeft , BuildRight }
24+ import org .apache .spark .sql .catalyst .plans .Inner
25+ import org .apache .spark .sql .comet .{CometBroadcastHashJoinExec , CometHashJoinExec , CometSortMergeJoinExec }
26+ import org .apache .spark .sql .execution .LocalTableScanExec
27+ import org .apache .spark .sql .execution .SparkPlan
28+ import org .apache .spark .sql .execution .joins .{BroadcastHashJoinExec , ShuffledHashJoinExec , SortMergeJoinExec }
29+ import org .apache .spark .sql .types .StringType
30+
31+ import org .apache .comet .{CometConf , CometExplainInfo }
32+ import org .apache .comet .serde .OperatorOuterClass
33+
2234class CometCollationSuite extends CometTestBase {
2335
2436 // Queries that group, sort, or shuffle on a non-default collated string must fall back to
@@ -29,6 +41,8 @@ class CometCollationSuite extends CometTestBase {
2941 " unsupported hash partitioning data type for columnar shuffle"
3042 private val rangeShuffleCollationReason =
3143 " unsupported range partitioning data type for columnar shuffle"
44+ private val joinKeyCollationReason =
45+ " unsupported non-default collated string join keys"
3246
3347 test(" listagg DISTINCT with utf8_lcase collation (issue #1947)" ) {
3448 checkSparkAnswerAndFallbackReason(
@@ -66,4 +80,171 @@ class CometCollationSuite extends CometTestBase {
6680 checkSparkAnswerAndOperator(" SELECT DISTINCT _1 FROM tbl ORDER BY _1" )
6781 }
6882 }
83+
84+ // ---- Join collation guards (issue #4051) ----------------------------------------
85+ //
86+ // Comet's native join compares keys byte-by-byte, so 'a' and 'A' would not match
87+ // under utf8_lcase, producing wrong results. The converters must reject any join
88+ // whose keys carry a non-default collation.
89+ //
90+ // End-to-end SQL cannot reach the join converter today: higher-level guards
91+ // (CometScanRule, Collate-expression serialization, #4035 shuffle guard) short-circuit
92+ // first. The tests below bypass those guards by constructing physical-plan operators
93+ // directly and calling convert() — the contract is that convert() returns None for
94+ // collated keys.
95+
96+ private def collatedKey (name : String ): AttributeReference =
97+ AttributeReference (name, StringType (" UTF8_LCASE" ), nullable = false )()
98+
99+ private def placeholderChildOp (): OperatorOuterClass .Operator =
100+ OperatorOuterClass .Operator .newBuilder().build()
101+
102+ // Ensure converters are on so that None from convert() means the collation guard fired,
103+ // not that the join type is disabled.
104+ private def withJoinConvertersEnabled (f : => Unit ): Unit =
105+ withSQLConf(
106+ CometConf .COMET_EXEC_HASH_JOIN_ENABLED .key -> " true" ,
107+ CometConf .COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED .key -> " true" ,
108+ CometConf .COMET_EXEC_SORT_MERGE_JOIN_ENABLED .key -> " true" ) {
109+ f
110+ }
111+
112+ private def assertFallbackReason (plan : SparkPlan , expectedReason : String ): Unit = {
113+ val reasons = plan.getTagValue(CometExplainInfo .EXTENSION_INFO ).getOrElse(Set .empty[String ])
114+ assert(
115+ reasons.contains(expectedReason),
116+ s " Expected fallback reason ' $expectedReason' on ${plan.nodeName}, got: $reasons" )
117+ }
118+
119+ test(" CometBroadcastHashJoinExec rejects non-default collated join keys" ) {
120+ withJoinConvertersEnabled {
121+ val left = collatedKey(" l" )
122+ val right = collatedKey(" r" )
123+ val join = BroadcastHashJoinExec (
124+ leftKeys = Seq (left),
125+ rightKeys = Seq (right),
126+ joinType = Inner ,
127+ buildSide = BuildRight ,
128+ condition = None ,
129+ left = LocalTableScanExec (Seq (left), Nil , None ),
130+ right = LocalTableScanExec (Seq (right), Nil , None ))
131+
132+ val builder = OperatorOuterClass .Operator .newBuilder()
133+ val result =
134+ CometBroadcastHashJoinExec .convert(
135+ join,
136+ builder,
137+ placeholderChildOp(),
138+ placeholderChildOp())
139+
140+ assert(
141+ result.isEmpty,
142+ " CometBroadcastHashJoinExec.convert must reject non-default collated join keys " +
143+ " (issue #4051): native byte equality cannot match values that compare equal " +
144+ " under utf8_lcase. Got a non-empty proto: " + result)
145+ assertFallbackReason(join, joinKeyCollationReason)
146+ }
147+ }
148+
149+ test(" CometHashJoinExec rejects non-default collated join keys" ) {
150+ withJoinConvertersEnabled {
151+ val left = collatedKey(" l" )
152+ val right = collatedKey(" r" )
153+ val join = ShuffledHashJoinExec (
154+ leftKeys = Seq (left),
155+ rightKeys = Seq (right),
156+ joinType = Inner ,
157+ buildSide = BuildLeft ,
158+ condition = None ,
159+ left = LocalTableScanExec (Seq (left), Nil , None ),
160+ right = LocalTableScanExec (Seq (right), Nil , None ))
161+
162+ val builder = OperatorOuterClass .Operator .newBuilder()
163+ val result =
164+ CometHashJoinExec .convert(join, builder, placeholderChildOp(), placeholderChildOp())
165+
166+ assert(
167+ result.isEmpty,
168+ " CometHashJoinExec.convert must reject non-default collated join keys (issue " +
169+ " #4051): native byte equality cannot match values that compare equal under " +
170+ " utf8_lcase. Got a non-empty proto: " + result)
171+ assertFallbackReason(join, joinKeyCollationReason)
172+ }
173+ }
174+
175+ test(" CometBroadcastHashJoinExec still accepts default UTF8_BINARY string keys" ) {
176+ withJoinConvertersEnabled {
177+ val left = AttributeReference (" l" , StringType , nullable = false )()
178+ val right = AttributeReference (" r" , StringType , nullable = false )()
179+ val join = BroadcastHashJoinExec (
180+ leftKeys = Seq (left),
181+ rightKeys = Seq (right),
182+ joinType = Inner ,
183+ buildSide = BuildRight ,
184+ condition = None ,
185+ left = LocalTableScanExec (Seq (left), Nil , None ),
186+ right = LocalTableScanExec (Seq (right), Nil , None ))
187+
188+ val builder = OperatorOuterClass .Operator .newBuilder()
189+ val result =
190+ CometBroadcastHashJoinExec .convert(
191+ join,
192+ builder,
193+ placeholderChildOp(),
194+ placeholderChildOp())
195+
196+ assert(
197+ result.isDefined,
198+ " CometBroadcastHashJoinExec.convert must continue to accept default UTF8_BINARY " +
199+ " string keys; the collation guard for #4051 must not over-block." )
200+ }
201+ }
202+
203+ test(" CometSortMergeJoinExec rejects non-default collated join keys" ) {
204+ withJoinConvertersEnabled {
205+ val left = collatedKey(" l" )
206+ val right = collatedKey(" r" )
207+ val join = SortMergeJoinExec (
208+ leftKeys = Seq (left),
209+ rightKeys = Seq (right),
210+ joinType = Inner ,
211+ condition = None ,
212+ left = LocalTableScanExec (Seq (left), Nil , None ),
213+ right = LocalTableScanExec (Seq (right), Nil , None ))
214+
215+ val builder = OperatorOuterClass .Operator .newBuilder()
216+ val result =
217+ CometSortMergeJoinExec .convert(join, builder, placeholderChildOp(), placeholderChildOp())
218+
219+ assert(
220+ result.isEmpty,
221+ " CometSortMergeJoinExec.convert must reject non-default collated join keys " +
222+ " (issue #4051): supportedSortMergeJoinEqualType must check collation. Got a " +
223+ " non-empty proto: " + result)
224+ assertFallbackReason(join, joinKeyCollationReason)
225+ }
226+ }
227+
228+ test(" CometSortMergeJoinExec still accepts default UTF8_BINARY string keys" ) {
229+ withJoinConvertersEnabled {
230+ val left = AttributeReference (" l" , StringType , nullable = false )()
231+ val right = AttributeReference (" r" , StringType , nullable = false )()
232+ val join = SortMergeJoinExec (
233+ leftKeys = Seq (left),
234+ rightKeys = Seq (right),
235+ joinType = Inner ,
236+ condition = None ,
237+ left = LocalTableScanExec (Seq (left), Nil , None ),
238+ right = LocalTableScanExec (Seq (right), Nil , None ))
239+
240+ val builder = OperatorOuterClass .Operator .newBuilder()
241+ val result =
242+ CometSortMergeJoinExec .convert(join, builder, placeholderChildOp(), placeholderChildOp())
243+
244+ assert(
245+ result.isDefined,
246+ " CometSortMergeJoinExec.convert must continue to accept default UTF8_BINARY " +
247+ " string keys; the collation guard for #4051 must not over-block." )
248+ }
249+ }
69250}
0 commit comments