Skip to content

Commit 6f50ccf

Browse files
authored
feat: support TimestampType join keys in SortMergeJoin (#3986)
1 parent 06d2469 commit 6f50ccf

3 files changed

Lines changed: 181 additions & 11 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,25 @@ struct JoinParameters {
150150
pub join_type: DFJoinType,
151151
}
152152

153+
/// If `expr` evaluates to `Timestamp(_, Some(_))` against `schema`, wrap it in a
154+
/// metadata-only cast to `Timestamp(_, None)`. This is required because
155+
/// DataFusion's `SortMergeJoinExec` comparator only supports timezone-less
156+
/// timestamp types, while Spark's `TimestampType` serializes as
157+
/// `Timestamp(µs, "UTC")`. The cast preserves ordering on the same time unit.
158+
fn strip_timestamp_tz(
159+
expr: Arc<dyn PhysicalExpr>,
160+
schema: &Schema,
161+
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
162+
match expr.data_type(schema)? {
163+
DataType::Timestamp(unit, Some(_)) => Ok(Arc::new(CastExpr::new(
164+
expr,
165+
DataType::Timestamp(unit, None),
166+
None,
167+
))),
168+
_ => Ok(expr),
169+
}
170+
}
171+
153172
#[derive(Default)]
154173
pub struct BinaryExprOptions {
155174
pub is_integral_div: bool,
@@ -1727,10 +1746,23 @@ impl PhysicalPlanner {
17271746
let left = Arc::clone(&join_params.left.native_plan);
17281747
let right = Arc::clone(&join_params.right.native_plan);
17291748

1749+
let left_schema = left.schema();
1750+
let right_schema = right.schema();
1751+
let join_on = join_params
1752+
.join_on
1753+
.into_iter()
1754+
.map(|(l, r)| {
1755+
Ok((
1756+
strip_timestamp_tz(l, left_schema.as_ref())?,
1757+
strip_timestamp_tz(r, right_schema.as_ref())?,
1758+
))
1759+
})
1760+
.collect::<Result<Vec<_>, ExecutionError>>()?;
1761+
17301762
let join = Arc::new(SortMergeJoinExec::try_new(
17311763
Arc::clone(&left),
17321764
Arc::clone(&right),
1733-
join_params.join_on,
1765+
join_on,
17341766
join_params.join_filter,
17351767
join_params.join_type,
17361768
sort_options,

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
4343
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
4444
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
4545
import org.apache.spark.sql.internal.SQLConf
46-
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType}
46+
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType, TimestampType}
4747
import org.apache.spark.sql.vectorized.ColumnarBatch
4848
import org.apache.spark.util.SerializableConfiguration
4949
import org.apache.spark.util.io.ChunkedByteBuffer
@@ -2270,7 +2270,7 @@ object CometSortMergeJoinExec extends CometOperatorSerde[SortMergeJoinExec] {
22702270
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
22712271
_: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
22722272
true
2273-
case TimestampNTZType => true
2273+
case TimestampNTZType | _: TimestampType => true
22742274
case _ => false
22752275
}
22762276

spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala

Lines changed: 146 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.scalatest.Tag
2525
import org.apache.spark.sql.CometTestBase
2626
import org.apache.spark.sql.catalyst.TableIdentifier
2727
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
28-
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
28+
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometSortMergeJoinExec}
2929
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
3030
import org.apache.spark.sql.internal.SQLConf
3131

@@ -55,21 +55,159 @@ class CometJoinSuite extends CometTestBase {
5555
.toSeq)
5656
}
5757

58-
test("SortMergeJoin with unsupported key type should fall back to Spark") {
58+
test("SortMergeJoin with TimestampType key runs natively") {
5959
withSQLConf(
6060
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
6161
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
62-
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
62+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
63+
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
6364
withTable("t1", "t2") {
6465
sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
65-
sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')")
66+
sql(
67+
"INSERT OVERWRITE t1 VALUES " +
68+
"('a', timestamp'2019-01-01 11:11:11'), " +
69+
"('b', timestamp'2020-05-05 05:05:05')")
6670

6771
sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
68-
sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')")
72+
sql(
73+
"INSERT OVERWRITE t2 VALUES " +
74+
"('a', timestamp'2019-01-01 11:11:11'), " +
75+
"('c', timestamp'2021-07-07 07:07:07')")
76+
77+
checkSparkAnswerAndOperator(
78+
sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
79+
Seq(classOf[CometSortMergeJoinExec]))
80+
}
81+
}
82+
}
6983

70-
val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
71-
val (sparkPlan, cometPlan) = checkSparkAnswer(df)
72-
assert(sparkPlan.canonicalized === cometPlan.canonicalized)
84+
test("SortMergeJoin with TimestampType key supports outer joins") {
85+
withSQLConf(
86+
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
87+
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
88+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
89+
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
90+
withTable("t1", "t2") {
91+
sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
92+
sql(
93+
"INSERT OVERWRITE t1 VALUES " +
94+
"(1, timestamp'2019-01-01 11:11:11'), " +
95+
"(2, timestamp'2020-05-05 05:05:05'), " +
96+
"(3, timestamp'2021-07-07 07:07:07')")
97+
98+
sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
99+
sql(
100+
"INSERT OVERWRITE t2 VALUES " +
101+
"(10, timestamp'2019-01-01 11:11:11'), " +
102+
"(20, timestamp'2022-02-02 02:02:02')")
103+
104+
for (joinType <- Seq("LEFT OUTER", "RIGHT OUTER", "FULL OUTER")) {
105+
checkSparkAnswerAndOperator(
106+
sql(s"SELECT * FROM t1 $joinType JOIN t2 ON t1.time = t2.time"),
107+
Seq(classOf[CometSortMergeJoinExec]))
108+
}
109+
}
110+
}
111+
}
112+
113+
test("SortMergeJoin with composite (string, timestamp) key runs natively") {
114+
withSQLConf(
115+
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
116+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
117+
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
118+
withTable("t1", "t2") {
119+
sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
120+
sql(
121+
"INSERT OVERWRITE t1 VALUES " +
122+
"('a', timestamp'2019-01-01 11:11:11'), " +
123+
"('b', timestamp'2019-01-01 11:11:11'), " +
124+
"('a', timestamp'2020-05-05 05:05:05')")
125+
126+
sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
127+
sql(
128+
"INSERT OVERWRITE t2 VALUES " +
129+
"('a', timestamp'2019-01-01 11:11:11'), " +
130+
"('b', timestamp'2020-05-05 05:05:05'), " +
131+
"('a', timestamp'2020-05-05 05:05:05')")
132+
133+
checkSparkAnswerAndOperator(
134+
sql(
135+
"SELECT * FROM t1 JOIN t2 " +
136+
"ON t1.name = t2.name AND t1.time = t2.time"),
137+
Seq(classOf[CometSortMergeJoinExec]))
138+
}
139+
}
140+
}
141+
142+
test("SortMergeJoin with nullable TimestampType key runs natively") {
143+
withSQLConf(
144+
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
145+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
146+
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
147+
withTable("t1", "t2") {
148+
sql("CREATE TABLE t1(id INT, time TIMESTAMP) USING PARQUET")
149+
sql(
150+
"INSERT OVERWRITE t1 VALUES " +
151+
"(1, timestamp'2019-01-01 11:11:11'), " +
152+
"(2, CAST(NULL AS TIMESTAMP)), " +
153+
"(3, timestamp'2020-05-05 05:05:05')")
154+
155+
sql("CREATE TABLE t2(id INT, time TIMESTAMP) USING PARQUET")
156+
sql(
157+
"INSERT OVERWRITE t2 VALUES " +
158+
"(10, timestamp'2019-01-01 11:11:11'), " +
159+
"(20, CAST(NULL AS TIMESTAMP)), " +
160+
"(30, timestamp'2022-02-02 02:02:02')")
161+
162+
// Inner join: NULL = NULL must not match in Spark semantics.
163+
checkSparkAnswerAndOperator(
164+
sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
165+
Seq(classOf[CometSortMergeJoinExec]))
166+
167+
// Full outer join: NULL-keyed rows from both sides surface as unmatched.
168+
checkSparkAnswerAndOperator(
169+
sql("SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.time = t2.time"),
170+
Seq(classOf[CometSortMergeJoinExec]))
171+
}
172+
}
173+
}
174+
175+
test("SortMergeJoin with TimestampType key across mixed write-time session timezones") {
176+
// TimestampType is an instant (UTC microseconds); only the parsing of literal
177+
// strings depends on the session timezone. Writing each side under a different
178+
// session zone with wall-clock literals that resolve to the same UTC instant
179+
// must still produce a join match.
180+
withSQLConf(
181+
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
182+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
183+
SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
184+
withTable("t1", "t2") {
185+
// t1 written in America/Los_Angeles. 03:11:11 -0800 == 11:11:11 UTC.
186+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
187+
sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
188+
sql(
189+
"INSERT OVERWRITE t1 VALUES " +
190+
"('a', timestamp'2019-01-01 03:11:11'), " +
191+
"('b', timestamp'2020-05-04 22:05:05')")
192+
}
193+
194+
// t2 written in Asia/Tokyo. 20:11:11 +0900 == 11:11:11 UTC, so the 'a' row
195+
// shares a UTC instant with t1's 'a' row; the 'c' row matches nothing in t1.
196+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Tokyo") {
197+
sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
198+
sql(
199+
"INSERT OVERWRITE t2 VALUES " +
200+
"('a', timestamp'2019-01-01 20:11:11'), " +
201+
"('c', timestamp'2021-07-07 16:07:07')")
202+
}
203+
204+
// Read at a third session timezone to confirm the equality is on the
205+
// stored UTC instant rather than the displayed wall-clock value.
206+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
207+
checkSparkAnswerAndOperator(
208+
sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time"),
209+
Seq(classOf[CometSortMergeJoinExec]))
210+
}
73211
}
74212
}
75213
}

0 commit comments

Comments
 (0)