Skip to content

Commit a1fcd0c

Browse files
committed
reflect review comments
1 parent b6ed9a6 commit a1fcd0c

2 files changed

Lines changed: 39 additions & 23 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper
3434
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
3535
import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
3636
import org.apache.spark.sql.internal.SQLConf
37-
import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType}
38-
import org.apache.spark.unsafe.types.UTF8String
37+
import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
3938
import org.apache.spark.util.NextIterator
4039

4140
/**
@@ -702,6 +701,7 @@ class SymmetricHashJoinStateManagerV4(
702701
}
703702

704703
private var currentTs = -1L
704+
private var pastUpperBound = false
705705
private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()
706706

707707
private def flushAccumulated(): GetValuesResult = {
@@ -719,16 +719,16 @@ class SymmetricHashJoinStateManagerV4(
719719

720720
@tailrec
721721
override protected def getNext(): GetValuesResult = {
722-
if (!iter.hasNext) {
722+
if (pastUpperBound || !iter.hasNext) {
723723
flushAccumulated()
724724
} else {
725725
val unsafeRowPair = iter.next()
726726
val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
727727

728-
// Filter out entries outside [minTs, maxTs]. This is essential when using
729-
// prefixScan (which returns all timestamps for the key) and serves as a
730-
// safety guard for rangeScan as well.
731-
if (ts < minTs || ts > maxTs) {
728+
if (ts > maxTs) {
729+
pastUpperBound = true
730+
getNext()
731+
} else if (ts < minTs) {
732732
getNext()
733733
} else if (currentTs == -1L || currentTs == ts) {
734734
currentTs = ts
@@ -819,22 +819,16 @@ class SymmetricHashJoinStateManagerV4(
819819
case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int)
820820

821821
private def defaultInternalRow(schema: StructType): InternalRow = {
822-
InternalRow.fromSeq(schema.map(f => defaultValueForType(f.dataType)))
822+
InternalRow.fromSeq(schema.map(f => Literal.default(f.dataType).value))
823823
}
824824

825-
private def defaultValueForType(dt: DataType): Any = dt match {
826-
case BooleanType => false
827-
case ByteType => 0.toByte
828-
case ShortType => 0.toShort
829-
case IntegerType | DateType => 0
830-
case LongType | TimestampType | TimestampNTZType => 0L
831-
case FloatType => 0.0f
832-
case DoubleType => 0.0
833-
case StringType => UTF8String.EMPTY_UTF8
834-
case BinaryType => Array.emptyByteArray
835-
case st: StructType => defaultInternalRow(st)
836-
case _ => null
837-
}
825+
/**
826+
* Reusable default key row for scan boundary construction. Safe to reuse because
827+
* createKeyRow only reads this row (via BoundReference evaluations) and writes to
828+
* the projection's own internal buffer.
829+
*/
830+
private lazy val defaultKey: UnsafeRow = UnsafeProjection.create(keySchema)
831+
.apply(defaultInternalRow(keySchema))
838832

839833
/**
840834
* Build a scan boundary row for rangeScan. The TsWithKeyTypeStore uses
@@ -844,8 +838,6 @@ class SymmetricHashJoinStateManagerV4(
844838
* timestamp matters for ordering in the prefix encoder.
845839
*/
846840
private def createScanBoundaryRow(timestamp: Long): UnsafeRow = {
847-
val defaultKey = UnsafeProjection.create(keySchema)
848-
.apply(defaultInternalRow(keySchema))
849841
createKeyRow(defaultKey, timestamp).copy()
850842
}
851843

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,30 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite
11581158
// startTimestamp=39 (exclusive) means entries >= 40 are scanned; endTimestamp=40 inclusive
11591159
assert(evictByTs.evictByTimestamp(40, Some(39)) === 1)
11601160
assert(get(40) === Seq(10, 50))
1161+
1162+
// --- overflow boundary: endTimestamp = Long.MaxValue ---
1163+
// Restore entries for a clean slate
1164+
Seq(20, 30, 40).foreach(append(40, _))
1165+
// endTimestamp=Long.MaxValue with no startTimestamp: evicts all entries
1166+
assert(evictByTs.evictByTimestamp(Long.MaxValue) === 5)
1167+
assert(get(40) === Seq.empty)
1168+
1169+
// --- overflow boundary: startTimestamp = Some(Long.MinValue) ---
1170+
Seq(10, 20, 30).foreach(append(40, _))
1171+
// startTimestamp=Long.MinValue (exclusive), endTimestamp=20 (inclusive):
1172+
// Long.MinValue is excluded per the contract (already evicted), so the scan
1173+
// starts from Long.MinValue + 1. Since no real entry has timestamp Long.MinValue,
1174+
// this effectively scans all entries up to endTimestamp.
1175+
assert(evictByTs.evictByTimestamp(20, Some(Long.MinValue)) === 2)
1176+
assert(get(40) === Seq(30))
1177+
1178+
// --- overflow boundary: startTimestamp = Some(Long.MaxValue) ---
1179+
Seq(10, 20).foreach(append(40, _))
1180+
// startTimestamp=Long.MaxValue (exclusive) means everything <= Long.MaxValue was already
1181+
// evicted. Since startKeyRow falls back to None, endTimestamp=50 bounds the scan.
1182+
// All remaining entries (10, 20, 30) have timestamps <= 50, so they are evicted.
1183+
assert(evictByTs.evictByTimestamp(50, Some(Long.MaxValue)) === 3)
1184+
assert(get(40) === Seq.empty)
11611185
}
11621186
}
11631187

0 commit comments

Comments
 (0)