
Commit 25e29fc

[SPARK-56402][SS] Apply rangeScan API in stream-stream join format version 4
### What changes were proposed in this pull request?

This PR proposes to apply the rangeScan API in stream-stream join format version 4, which improves scanning for matching rows in time interval joins and for eviction.

The main idea for eviction is to scan the secondary index over [the end timestamp of the previous scan + 1, new end timestamp], where it previously scanned [None, new end timestamp]. The old scan had to go through tombstones left by prior evictions (until compaction happens); with this change we can skip those tombstones.

The idea for time interval joins is straightforward: we know the timestamp range of matching rows and use it to scope the scan. Previously we scanned all timestamps within the key from RocksDB and applied a filter afterwards. We move the filtering duty into RocksDB to get the same benefit as above (skipping tombstones).

### Why are the changes needed?

This change gives RocksDB a hint about the exact range to scan, greatly reducing the chance of reading tombstones.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UTs, and existing UTs.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude 4.6 Opus

Closes #55267 from HeartSaVioR/SPARK-56402-on-top-of-SPARK-56369.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
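The eviction narrowing described above can be modeled in isolation. The sketch below is illustrative only (the object and method names are hypothetical, not the Spark API): it computes the inclusive timestamp range the secondary-index scan should cover, given the previous scan's end timestamp.

```scala
// Illustrative model of the eviction scan-range narrowing: scan
// [previous end timestamp + 1, new end timestamp] instead of
// [beginning, new end timestamp]. All names are hypothetical.
object EvictionScanRange {
  // Returns the inclusive timestamp range to scan for eviction, or None when
  // nothing new can be eligible. prevEnd is the end timestamp of the previous
  // eviction scan; None means no prior scan is known (e.g. the first batch).
  def scanRange(prevEnd: Option[Long], newEnd: Long): Option[(Long, Long)] =
    prevEnd match {
      case Some(p) if p >= newEnd => None        // watermark did not advance
      case Some(p) => Some((p + 1, newEnd))      // skip the already-evicted range
      case None => Some((Long.MinValue, newEnd)) // no known lower bound
    }
}
```

The first case is what lets the scan skip tombstones: everything at or below the previous end timestamp was already deleted, so re-reading that range only encounters deletion markers.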
1 parent 5b72273 commit 25e29fc

File tree

8 files changed: +422 additions, -31 deletions

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 5 additions & 0 deletions
@@ -6870,6 +6870,11 @@
         "joinSide=<joinSide>, storeVersion=<storeVersion>, partitionId=<partitionId>.",
         "Enable <configKey> as a workaround to skip null values."
       ]
+    },
+    "RANGE_SCAN_TIMESTAMP_OUT_OF_RANGE" : {
+      "message" : [
+        "Range scan returned a row with timestamp <timestamp> outside the expected range [<minTimestamp>, <maxTimestamp>]."
+      ]
     }
   },
   "sqlState" : "XXKST"

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala

Lines changed: 18 additions & 10 deletions
@@ -675,7 +675,7 @@ case class StreamingSymmetricHashJoinExec(
     private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)

     private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
-      case Some(JoinStateKeyWatermarkPredicate(expr, _)) =>
+      case Some(JoinStateKeyWatermarkPredicate(expr, _, _)) =>
         // inputSchema can be empty as expr should only have BoundReferences and does not require
         // the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]].
         Predicate.create(expr, Seq.empty).eval _
@@ -684,7 +684,7 @@ case class StreamingSymmetricHashJoinExec(
     }

     private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
-      case Some(JoinStateValueWatermarkPredicate(expr, _)) =>
+      case Some(JoinStateValueWatermarkPredicate(expr, _, _)) =>
         Predicate.create(expr, inputAttributes).eval _
       case _ =>
         Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
@@ -905,21 +905,25 @@ case class StreamingSymmetricHashJoinExec(
      */
     def removeOldState(): Long = {
       stateWatermarkPredicate match {
-        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
+        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
           joinStateManager match {
             case s: SupportsEvictByCondition =>
               s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)

             case s: SupportsEvictByTimestamp =>
-              s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
+              s.evictByTimestamp(
+                watermarkMsToStateTimestamp(stateWatermark),
+                prevStateWatermark.map(watermarkMsToStateTimestamp))
           }
-        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
+        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
           joinStateManager match {
             case s: SupportsEvictByCondition =>
               s.evictByValueCondition(stateValueWatermarkPredicateFunc)

             case s: SupportsEvictByTimestamp =>
-              s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
+              s.evictByTimestamp(
+                watermarkMsToStateTimestamp(stateWatermark),
+                prevStateWatermark.map(watermarkMsToStateTimestamp))
           }
         case _ => 0L
       }
@@ -937,21 +941,25 @@ case class StreamingSymmetricHashJoinExec(
      */
     def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
       stateWatermarkPredicate match {
-        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
+        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
           joinStateManager match {
             case s: SupportsEvictByCondition =>
               s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)

             case s: SupportsEvictByTimestamp =>
-              s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
+              s.evictAndReturnByTimestamp(
+                watermarkMsToStateTimestamp(stateWatermark),
+                prevStateWatermark.map(watermarkMsToStateTimestamp))
           }
-        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
+        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
           joinStateManager match {
             case s: SupportsEvictByCondition =>
               s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)

             case s: SupportsEvictByTimestamp =>
-              s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
+              s.evictAndReturnByTimestamp(
+                watermarkMsToStateTimestamp(stateWatermark),
+                prevStateWatermark.map(watermarkMsToStateTimestamp))
           }
         case _ => Iterator.empty
       }

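The diff above threads the optional previous watermark through `watermarkMsToStateTimestamp` via `Option.map`, so a missing lower bound simply propagates as `None`. A minimal standalone sketch of that pattern, with an assumed identity conversion standing in for the real one:

```scala
// Sketch of building the (endTimestamp, startTimestamp) argument pair passed
// to evictByTimestamp in the diff above. watermarkMsToStateTimestamp here is
// an assumed identity stand-in, not the real Spark conversion.
object EvictArgsSketch {
  def watermarkMsToStateTimestamp(ms: Long): Long = ms // assumed stand-in

  // A None previous watermark propagates: no lower bound reaches the manager.
  def evictArgs(
      stateWatermark: Long,
      prevStateWatermark: Option[Long]): (Long, Option[Long]) =
    (watermarkMsToStateTimestamp(stateWatermark),
      prevStateWatermark.map(watermarkMsToStateTimestamp))
}
```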
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala

Lines changed: 41 additions & 6 deletions
@@ -45,13 +45,37 @@ object StreamingSymmetricHashJoinHelper extends Logging {
     def desc: String
     override def toString: String = s"$desc: $expr"
   }
-  /** Predicate for watermark on state keys */
-  case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long)
+  /**
+   * Predicate for watermark on state keys.
+   *
+   * @param stateWatermark Current batch's eviction watermark. Entries with timestamp
+   *                       at or below this value are eligible for eviction in this batch.
+   * @param prevStateWatermark Previous batch's eviction watermark, i.e. the watermark
+   *                           used for filtering late events in the current batch.
+   *                           Entries with timestamp at or below this value were already
+   *                           evicted in prior batches, so the effective range of entries
+   *                           to evict in this batch is `(prevStateWatermark, stateWatermark]`.
+   *                           State manager implementations can leverage this lower bound
+   *                           to optimize eviction (e.g. narrowing the scan range to skip
+   *                           already-evicted entries). `None` means we do not have a known
+   *                           lower bound (e.g. the first batch after restart), in which
+   *                           case eviction must consider all entries up to `stateWatermark`.
+   */
+  case class JoinStateKeyWatermarkPredicate(
+      expr: Expression,
+      stateWatermark: Long,
+      prevStateWatermark: Option[Long] = None)
     extends JoinStateWatermarkPredicate {
     def desc: String = "key predicate"
   }
-  /** Predicate for watermark on state values */
-  case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long)
+  /**
+   * Predicate for watermark on state values. See [[JoinStateKeyWatermarkPredicate]] for
+   * the semantics of `stateWatermark` and `prevStateWatermark`.
+   */
+  case class JoinStateValueWatermarkPredicate(
+      expr: Expression,
+      stateWatermark: Long,
+      prevStateWatermark: Option[Long] = None)
     extends JoinStateWatermarkPredicate {
     def desc: String = "value predicate"
   }
@@ -185,6 +209,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
       rightKeys: Seq[Expression],
       condition: Option[Expression],
       eventTimeWatermarkForEviction: Option[Long],
+      eventTimeWatermarkForLateEvents: Option[Long],
       useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = {

     // Perform assertions against multiple event time columns in the same DataFrame. This method
@@ -215,20 +240,30 @@ object StreamingSymmetricHashJoinHelper extends Logging {
       expr.map { e =>
         // watermarkExpression only provides the expression when eventTimeWatermarkForEviction
         // is defined
-        JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get)
+        JoinStateKeyWatermarkPredicate(
+          e,
+          eventTimeWatermarkForEviction.get,
+          eventTimeWatermarkForLateEvents)
       }
     } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
       val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
         attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
         attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
         condition,
         eventTimeWatermarkForEviction)
+      val prevStateValueWatermark = eventTimeWatermarkForLateEvents.flatMap { _ =>
+        StreamingJoinHelper.getStateValueWatermark(
+          attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
+          attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
+          condition,
+          eventTimeWatermarkForLateEvents)
+      }
       val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey))
       val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark)
       expr.map { e =>
         // watermarkExpression only provides the expression when eventTimeWatermarkForEviction
         // is defined
-        JoinStateValueWatermarkPredicate(e, stateValueWatermark.get)
+        JoinStateValueWatermarkPredicate(e, stateValueWatermark.get, prevStateValueWatermark)
       }
     } else {
       None
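Both predicate case classes above gain `prevStateWatermark` with a default of `None`, so existing two-argument construction sites keep compiling unchanged. A self-contained sketch of that design choice, with a simplified `String` field standing in for the `Expression`:

```scala
// Simplified model of the predicate case classes in the diff above: adding an
// optional third field with a default keeps old two-argument call sites
// source-compatible. Names and field types here are illustrative.
final case class KeyWatermarkPredicateSketch(
    desc: String, // stands in for the Expression field
    stateWatermark: Long,
    prevStateWatermark: Option[Long] = None)

object PredicateCompatDemo {
  val legacy = KeyWatermarkPredicateSketch("key predicate", 100L) // old-style call site
  val narrowed = KeyWatermarkPredicateSketch("key predicate", 100L, Some(40L))
}
```

Note that a default parameter preserves source compatibility but not binary compatibility; inside a single module, as here, that distinction does not matter.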
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 104 additions & 13 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
+import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, RangeScanBoundaryUtils, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
 import org.apache.spark.util.NextIterator
@@ -184,15 +184,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
 trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
   import SymmetricHashJoinStateManager._

-  /** Evict the state by timestamp. Returns the number of values evicted. */
-  def evictByTimestamp(endTimestamp: Long): Long
+  /**
+   * Evict the state by timestamp. Returns the number of values evicted.
+   *
+   * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
+   * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
+   *   assumed to have been evicted already (e.g. from the previous batch). When provided,
+   *   the scan starts from startTimestamp + 1.
+   */
+  def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long

   /**
    * Evict the state by timestamp and return the evicted key-value pairs.
    *
    * It is caller's responsibility to consume the whole iterator.
+   *
+   * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
+   * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
+   *   assumed to have been evicted already (e.g. from the previous batch). When provided,
+   *   the scan starts from startTimestamp + 1.
    */
-  def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+  def evictAndReturnByTimestamp(
+      endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair]
 }

 /**
@@ -519,11 +532,11 @@ class SymmetricHashJoinStateManagerV4(
     }
   }

-  override def evictByTimestamp(endTimestamp: Long): Long = {
+  override def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long = {
     require(hasEventTime,
       "evictByTimestamp requires event time; secondary index was not populated")
     var removed = 0L
-    tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
+    tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).foreach { evicted =>
       val key = evicted.key
       val timestamp = evicted.timestamp
       val numValues = evicted.numValues
@@ -537,12 +550,13 @@ class SymmetricHashJoinStateManagerV4(
     removed
   }

-  override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = {
+  override def evictAndReturnByTimestamp(
+      endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] = {
     require(hasEventTime,
       "evictAndReturnByTimestamp requires event time; secondary index was not populated")
     val reusableKeyToValuePair = KeyToValuePair()

-    tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
+    tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).flatMap { evicted =>
       val key = evicted.key
       val timestamp = evicted.timestamp
       val values = keyWithTsToValues.get(key, timestamp)
@@ -663,14 +677,33 @@ class SymmetricHashJoinStateManagerV4(

   /**
    * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
-   * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
+   * When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient
+   * range access; falls back to prefixScan otherwise to stay within the key's scope.
+   *
+   * When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are
+   * filtered out so both code paths produce identical results.
    */
   def getValuesInRange(
       key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
     val reusableGetValuesResult = new GetValuesResult()
+    // Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires
+    // an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue.
+    val useRangeScan = maxTs < Long.MaxValue

     new NextIterator[GetValuesResult] {
-      private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
+      private val iter = if (useRangeScan) {
+        // startKey must be copied because the second createKeyRow call below reuses
+        // the same projection buffer and would otherwise overwrite its contents.
+        // endKey does not need a copy: rangeScanWithMultiValues encodes both bounds
+        // to independent byte arrays eagerly at call time, and the scope of endKey
+        // ends with the call of rangeScanWithMultiValues.
+        val startKey = createKeyRow(key, minTs).copy()
+        // rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1
+        val endKey = Some(createKeyRow(key, maxTs + 1))
+        stateStore.rangeScanWithMultiValues(Some(startKey), endKey, colFamilyName)
+      } else {
+        stateStore.prefixScanWithMultiValues(key, colFamilyName)
+      }

       private var currentTs = -1L
       private var pastUpperBound = false
@@ -697,6 +730,11 @@ class SymmetricHashJoinStateManagerV4(
         val unsafeRowPair = iter.next()
         val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)

+        if (useRangeScan && (ts < minTs || ts > maxTs)) {
+          throw StateStoreErrors.streamStreamJoinRangeScanTimestampOutOfRange(
+            ts, minTs, maxTs)
+        }
+
         if (ts > maxTs) {
           pastUpperBound = true
           getNext()
@@ -773,6 +811,8 @@ class SymmetricHashJoinStateManagerV4(
     isInternal = true
   )

+  // Returns an UnsafeRow backed by a reused projection buffer. Callers that need to
+  // hold the row beyond the immediate state store call must invoke copy() on the result.
   private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
     TimestampKeyStateEncoder.attachTimestamp(
       attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
@@ -788,9 +828,60 @@ class SymmetricHashJoinStateManagerV4(

   case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int)

-  // NOTE: This assumes we consume the whole iterator to trigger completion.
-  def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
-    val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
+  // Reusable default key row for scan boundary construction; see
+  // [[RangeScanBoundaryUtils]] for rationale. Safe to reuse because createKeyRow
+  // only reads this row (via BoundReference evaluations) and writes to the
+  // projection's own internal buffer. Correctness relies on real stored entries
+  // never having internally-null key fields, which is preserved by join-key
+  // expressions being evaluated via the user's expression encoder. Preserve this
+  // invariant if you change how entries are written.
+  private lazy val defaultKey: UnsafeRow = RangeScanBoundaryUtils.defaultUnsafeRow(keySchema)
+
+  /**
+   * Build a scan boundary row for rangeScan. The TsWithKeyTypeStore uses
+   * TimestampAsPrefixKeyStateEncoder, which encodes the row as [timestamp][key_fields].
+   * We need a full-schema row (not just the timestamp) because the encoder expects all
+   * key columns to be present. Default values are used for the key fields since only the
+   * timestamp matters for ordering in the prefix encoder.
+   */
+  private def createScanBoundaryRow(timestamp: Long): UnsafeRow = {
+    createKeyRow(defaultKey, timestamp).copy()
+  }
+
+  /**
+   * Scan keys eligible for eviction within the timestamp range.
+   *
+   * This assumes we consume the whole iterator to trigger completion.
+   *
+   * @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are
+   *   eligible for eviction.
+   * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp
+   *   are assumed to have been evicted already. The scan starts from startTimestamp + 1.
+   */
+  def scanEvictedKeys(
+      endTimestamp: Long,
+      startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = {
+    // If startTimestamp == Long.MaxValue, everything has already been evicted;
+    // nothing can match, so return immediately.
+    if (startTimestamp.contains(Long.MaxValue)) {
+      return Iterator.empty
+    }
+
+    // rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive.
+    // startTimestamp is exclusive (already evicted), so we seek from st + 1.
+    val startKeyRow = startTimestamp.map { st =>
+      createScanBoundaryRow(st + 1)
+    }
+    // endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound.
+    // When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is
+    // safe because rangeScanWithMultiValues with no end key uses the column-family prefix
+    // as the upper bound, naturally scoping the scan within this column family.
+    val endKeyRow = if (endTimestamp < Long.MaxValue) {
+      Some(createScanBoundaryRow(endTimestamp + 1))
+    } else {
+      None
+    }
+    val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName)
     new NextIterator[EvictedKeysResult]() {
       var currentKeyRow: UnsafeRow = null
       var currentEventTime: Long = -1L