@@ -25,7 +25,7 @@ import org.scalatest.Tag
2525import org .apache .spark .sql .CometTestBase
2626import org .apache .spark .sql .catalyst .TableIdentifier
2727import org .apache .spark .sql .catalyst .analysis .UnresolvedRelation
28- import org .apache .spark .sql .comet .{CometBroadcastExchangeExec , CometBroadcastHashJoinExec }
28+ import org .apache .spark .sql .comet .{CometBroadcastExchangeExec , CometBroadcastHashJoinExec , CometSortMergeJoinExec }
2929import org .apache .spark .sql .execution .adaptive .AQEShuffleReadExec
3030import org .apache .spark .sql .internal .SQLConf
3131
@@ -55,21 +55,159 @@ class CometJoinSuite extends CometTestBase {
5555 .toSeq)
5656 }
5757
58- test(" SortMergeJoin with unsupported key type should fall back to Spark " ) {
// Inner equi-join on a TIMESTAMP column. Both auto-broadcast thresholds are set to -1 and
// sort-merge join is preferred; the query is then handed to checkSparkAnswerAndOperator
// together with CometSortMergeJoinExec.
test("SortMergeJoin with TimestampType key runs natively") {
  val joinConfs = Seq(
    SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true")

  withSQLConf(joinConfs: _*) {
    withTable("t1", "t2") {
      // Only the 2019-01-01 11:11:11 timestamp is present on both sides.
      val leftRows = Seq(
        "('a', timestamp'2019-01-01 11:11:11')",
        "('b', timestamp'2020-05-05 05:05:05')")
      sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
      sql(leftRows.mkString("INSERT OVERWRITE t1 VALUES ", ", ", ""))

      val rightRows = Seq(
        "('a', timestamp'2019-01-01 11:11:11')",
        "('c', timestamp'2021-07-07 07:07:07')")
      sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
      sql(rightRows.mkString("INSERT OVERWRITE t2 VALUES ", ", ", ""))

      val joinQuery = "SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"
      checkSparkAnswerAndOperator(sql(joinQuery), Seq(classOf[CometSortMergeJoinExec]))
    }
  }
}
6983
70- val df = sql(" SELECT * FROM t1 JOIN t2 ON t1.time = t2.time" )
71- val (sparkPlan, cometPlan) = checkSparkAnswer(df)
72- assert(sparkPlan.canonicalized === cometPlan.canonicalized)
// Runs LEFT, RIGHT and FULL OUTER joins on a TIMESTAMP key over two small tables that share
// exactly one timestamp, passing each query to checkSparkAnswerAndOperator together with
// CometSortMergeJoinExec.
test("SortMergeJoin with TimestampType key supports outer joins") {
  val joinConfs = Seq(
    SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true")

  // Table name -> VALUES tuples, created and loaded in this order.
  val fixtures = Seq(
    "t1" -> Seq(
      "(1, timestamp'2019-01-01 11:11:11')",
      "(2, timestamp'2020-05-05 05:05:05')",
      "(3, timestamp'2021-07-07 07:07:07')"),
    "t2" -> Seq(
      "(10, timestamp'2019-01-01 11:11:11')",
      "(20, timestamp'2022-02-02 02:02:02')"))

  withSQLConf(joinConfs: _*) {
    withTable("t1", "t2") {
      fixtures.foreach { case (table, rows) =>
        sql(s"CREATE TABLE $table(id INT, time TIMESTAMP) USING PARQUET")
        sql(rows.mkString(s"INSERT OVERWRITE $table VALUES ", ", ", ""))
      }

      Seq("LEFT OUTER", "RIGHT OUTER", "FULL OUTER").foreach { joinType =>
        checkSparkAnswerAndOperator(
          sql(s"SELECT * FROM t1 $joinType JOIN t2 ON t1.time = t2.time"),
          Seq(classOf[CometSortMergeJoinExec]))
      }
    }
  }
}
112+
// Inner join on the pair (name, time). The data contains rows that agree on only one of the
// two key columns, so a match requires both columns to be equal.
test("SortMergeJoin with composite (string, timestamp) key runs natively") {
  val joinConfs = Seq(
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true")

  // Table name -> VALUES tuples, created and loaded in this order.
  val fixtures = Seq(
    "t1" -> Seq(
      "('a', timestamp'2019-01-01 11:11:11')",
      "('b', timestamp'2019-01-01 11:11:11')",
      "('a', timestamp'2020-05-05 05:05:05')"),
    "t2" -> Seq(
      "('a', timestamp'2019-01-01 11:11:11')",
      "('b', timestamp'2020-05-05 05:05:05')",
      "('a', timestamp'2020-05-05 05:05:05')"))

  withSQLConf(joinConfs: _*) {
    withTable("t1", "t2") {
      fixtures.foreach { case (table, rows) =>
        sql(s"CREATE TABLE $table(name STRING, time TIMESTAMP) USING PARQUET")
        sql(rows.mkString(s"INSERT OVERWRITE $table VALUES ", ", ", ""))
      }

      val compositeJoin =
        "SELECT * FROM t1 JOIN t2 ON t1.name = t2.name AND t1.time = t2.time"
      checkSparkAnswerAndOperator(sql(compositeJoin), Seq(classOf[CometSortMergeJoinExec]))
    }
  }
}
141+
// Joins two tables whose TIMESTAMP key column contains a NULL on each side, first with an
// inner join and then with a full outer join.
test("SortMergeJoin with nullable TimestampType key runs natively") {
  val joinConfs = Seq(
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true")

  withSQLConf(joinConfs: _*) {
    withTable("t1", "t2") {
      val leftRows = Seq(
        "(1, timestamp'2019-01-01 11:11:11')",
        "(2, CAST(NULL AS TIMESTAMP))",
        "(3, timestamp'2020-05-05 05:05:05')")
      sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
      sql(leftRows.mkString("INSERT OVERWRITE t1 VALUES ", ", ", ""))

      val rightRows = Seq(
        "(10, timestamp'2019-01-01 11:11:11')",
        "(20, CAST(NULL AS TIMESTAMP))",
        "(30, timestamp'2022-02-02 02:02:02')")
      sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
      sql(rightRows.mkString("INSERT OVERWRITE t2 VALUES ", ", ", ""))

      val expectedOperators = Seq(classOf[CometSortMergeJoinExec])

      // Inner join: under Spark semantics a NULL key never equals another NULL key.
      checkSparkAnswerAndOperator(
        sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
        expectedOperators)

      // Full outer join: the NULL-keyed rows of both tables appear as unmatched rows.
      checkSparkAnswerAndOperator(
        sql("SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.time = t2.time"),
        expectedOperators)
    }
  }
}
174+
test("SortMergeJoin with TimestampType key across mixed write-time session timezones") {
  // TimestampType is an instant (UTC microseconds); only the parsing of literal
  // strings depends on the session timezone. Writing each side under a different
  // session zone with wall-clock literals that resolve to the same UTC instant
  // must still produce a join match.
  withSQLConf(
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
    withTable("t1", "t2") {
      // t1 written in America/Los_Angeles:
      //   'a': 2019-01-01 03:11:11 -0800 == 2019-01-01 11:11:11 UTC
      //   'b': 2020-05-04 22:05:05 -0700 == 2020-05-05 05:05:05 UTC (daylight time)
      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
        sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
        sql(
          "INSERT OVERWRITE t1 VALUES " +
            "('a', timestamp'2019-01-01 03:11:11'), " +
            "('b', timestamp'2020-05-04 22:05:05')")
      }

      // t2 written in Asia/Tokyo:
      //   'a': 2019-01-01 20:11:11 +0900 == 2019-01-01 11:11:11 UTC
      //   'c': 2021-07-07 16:07:07 +0900 == 2021-07-07 07:07:07 UTC
      // Only the two 'a' rows share a UTC instant, so they are the single expected
      // match; t1's 'b' row and t2's 'c' row have no counterpart.
      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Tokyo") {
        sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
        sql(
          "INSERT OVERWRITE t2 VALUES " +
            "('a', timestamp'2019-01-01 20:11:11'), " +
            "('c', timestamp'2021-07-07 16:07:07')")
      }

      // Read at a third session timezone to confirm the equality is on the
      // stored UTC instant rather than the displayed wall-clock value.
      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
        checkSparkAnswerAndOperator(
          sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
          Seq(classOf[CometSortMergeJoinExec]))
      }
    }
  }
}
0 commit comments