Skip to content

Commit b6ed9a6

Browse files
committed
Apply scan/scanWithMultiValues to stream-stream join V4
Use bounded scan ranges in stream-stream join V4 operators to narrow the iteration scope during eviction and value lookup:
- scanEvictedKeys (TsWithKeyTypeStore): use scanWithMultiValues (the rangeScanWithMultiValues state store call in this diff) with startKey derived from the previous batch's state watermark and endKey from the current eviction threshold. Thread the previous batch's state watermark (prevBatchStateWatermark; the prevStateWatermark field in this diff) through JoinStateWatermarkPredicate -> SupportsEvictByTimestamp.
- getValuesInRange (KeyWithTsToValuesStore): use scanWithMultiValues for bounded timestamp ranges, falling back to prefixScan for the full range (maxTs == Long.MaxValue).
Create default-valued boundary rows for the scan start/end keys to avoid a NullPointerException when the join key schema contains non-nullable fields (e.g. window structs).
1 parent 5fd1438 commit b6ed9a6

6 files changed

Lines changed: 281 additions & 35 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -663,7 +663,7 @@ case class StreamingSymmetricHashJoinExec(
663663
private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)
664664

665665
private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
666-
case Some(JoinStateKeyWatermarkPredicate(expr, _)) =>
666+
case Some(JoinStateKeyWatermarkPredicate(expr, _, _)) =>
667667
// inputSchema can be empty as expr should only have BoundReferences and does not require
668668
// the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]].
669669
Predicate.create(expr, Seq.empty).eval _
@@ -672,7 +672,7 @@ case class StreamingSymmetricHashJoinExec(
672672
}
673673

674674
private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
675-
case Some(JoinStateValueWatermarkPredicate(expr, _)) =>
675+
case Some(JoinStateValueWatermarkPredicate(expr, _, _)) =>
676676
Predicate.create(expr, inputAttributes).eval _
677677
case _ =>
678678
Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
@@ -893,21 +893,25 @@ case class StreamingSymmetricHashJoinExec(
893893
*/
894894
def removeOldState(): Long = {
895895
stateWatermarkPredicate match {
896-
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
896+
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
897897
joinStateManager match {
898898
case s: SupportsEvictByCondition =>
899899
s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)
900900

901901
case s: SupportsEvictByTimestamp =>
902-
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
902+
s.evictByTimestamp(
903+
watermarkMsToStateTimestamp(stateWatermark),
904+
prevStateWatermark.map(watermarkMsToStateTimestamp))
903905
}
904-
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
906+
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
905907
joinStateManager match {
906908
case s: SupportsEvictByCondition =>
907909
s.evictByValueCondition(stateValueWatermarkPredicateFunc)
908910

909911
case s: SupportsEvictByTimestamp =>
910-
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
912+
s.evictByTimestamp(
913+
watermarkMsToStateTimestamp(stateWatermark),
914+
prevStateWatermark.map(watermarkMsToStateTimestamp))
911915
}
912916
case _ => 0L
913917
}
@@ -925,21 +929,25 @@ case class StreamingSymmetricHashJoinExec(
925929
*/
926930
def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
927931
stateWatermarkPredicate match {
928-
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
932+
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
929933
joinStateManager match {
930934
case s: SupportsEvictByCondition =>
931935
s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)
932936

933937
case s: SupportsEvictByTimestamp =>
934-
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
938+
s.evictAndReturnByTimestamp(
939+
watermarkMsToStateTimestamp(stateWatermark),
940+
prevStateWatermark.map(watermarkMsToStateTimestamp))
935941
}
936-
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
942+
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
937943
joinStateManager match {
938944
case s: SupportsEvictByCondition =>
939945
s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)
940946

941947
case s: SupportsEvictByTimestamp =>
942-
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
948+
s.evictAndReturnByTimestamp(
949+
watermarkMsToStateTimestamp(stateWatermark),
950+
prevStateWatermark.map(watermarkMsToStateTimestamp))
943951
}
944952
case _ => Iterator.empty
945953
}

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,18 @@ object StreamingSymmetricHashJoinHelper extends Logging {
4646
override def toString: String = s"$desc: $expr"
4747
}
4848
/** Predicate for watermark on state keys */
49-
case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long)
49+
case class JoinStateKeyWatermarkPredicate(
50+
expr: Expression,
51+
stateWatermark: Long,
52+
prevStateWatermark: Option[Long] = None)
5053
extends JoinStateWatermarkPredicate {
5154
def desc: String = "key predicate"
5255
}
5356
/** Predicate for watermark on state values */
54-
case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long)
57+
case class JoinStateValueWatermarkPredicate(
58+
expr: Expression,
59+
stateWatermark: Long,
60+
prevStateWatermark: Option[Long] = None)
5561
extends JoinStateWatermarkPredicate {
5662
def desc: String = "value predicate"
5763
}
@@ -185,6 +191,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
185191
rightKeys: Seq[Expression],
186192
condition: Option[Expression],
187193
eventTimeWatermarkForEviction: Option[Long],
194+
eventTimeWatermarkForLateEvents: Option[Long],
188195
useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = {
189196

190197
// Perform assertions against multiple event time columns in the same DataFrame. This method
@@ -215,20 +222,30 @@ object StreamingSymmetricHashJoinHelper extends Logging {
215222
expr.map { e =>
216223
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
217224
// is defined
218-
JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get)
225+
JoinStateKeyWatermarkPredicate(
226+
e,
227+
eventTimeWatermarkForEviction.get,
228+
eventTimeWatermarkForLateEvents)
219229
}
220230
} else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
221231
val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
222232
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
223233
attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
224234
condition,
225235
eventTimeWatermarkForEviction)
236+
val prevStateValueWatermark = eventTimeWatermarkForLateEvents.flatMap { _ =>
237+
StreamingJoinHelper.getStateValueWatermark(
238+
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
239+
attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
240+
condition,
241+
eventTimeWatermarkForLateEvents)
242+
}
226243
val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey))
227244
val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark)
228245
expr.map { e =>
229246
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
230247
// is defined
231-
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get)
248+
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get, prevStateValueWatermark)
232249
}
233250
} else {
234251
None

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 106 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper
3434
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
3535
import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
3636
import org.apache.spark.sql.internal.SQLConf
37-
import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
37+
import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType}
38+
import org.apache.spark.unsafe.types.UTF8String
3839
import org.apache.spark.util.NextIterator
3940

4041
/**
@@ -184,15 +185,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
184185
trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
185186
import SymmetricHashJoinStateManager._
186187

187-
/** Evict the state by timestamp. Returns the number of values evicted. */
188-
def evictByTimestamp(endTimestamp: Long): Long
188+
/**
189+
* Evict the state by timestamp. Returns the number of values evicted.
190+
*
191+
* @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
192+
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
193+
* assumed to have been evicted already (e.g. from the previous batch). When provided,
194+
* the scan starts from startTimestamp + 1.
195+
*/
196+
def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long
189197

190198
/**
191199
* Evict the state by timestamp and return the evicted key-value pairs.
192200
*
193201
* It is caller's responsibility to consume the whole iterator.
202+
*
203+
* @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
204+
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
205+
* assumed to have been evicted already (e.g. from the previous batch). When provided,
206+
* the scan starts from startTimestamp + 1.
194207
*/
195-
def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
208+
def evictAndReturnByTimestamp(
209+
endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair]
196210
}
197211

198212
/**
@@ -519,11 +533,11 @@ class SymmetricHashJoinStateManagerV4(
519533
}
520534
}
521535

522-
override def evictByTimestamp(endTimestamp: Long): Long = {
536+
override def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long = {
523537
require(hasEventTime,
524538
"evictByTimestamp requires event time; secondary index was not populated")
525539
var removed = 0L
526-
tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
540+
tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).foreach { evicted =>
527541
val key = evicted.key
528542
val timestamp = evicted.timestamp
529543
val numValues = evicted.numValues
@@ -537,12 +551,13 @@ class SymmetricHashJoinStateManagerV4(
537551
removed
538552
}
539553

540-
override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = {
554+
override def evictAndReturnByTimestamp(
555+
endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] = {
541556
require(hasEventTime,
542557
"evictAndReturnByTimestamp requires event time; secondary index was not populated")
543558
val reusableKeyToValuePair = KeyToValuePair()
544559

545-
tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
560+
tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).flatMap { evicted =>
546561
val key = evicted.key
547562
val timestamp = evicted.timestamp
548563
val values = keyWithTsToValues.get(key, timestamp)
@@ -663,17 +678,30 @@ class SymmetricHashJoinStateManagerV4(
663678

664679
/**
665680
* Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
666-
* Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
681+
* When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient
682+
* range access; falls back to prefixScan otherwise to stay within the key's scope.
683+
*
684+
* When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are
685+
* filtered out so both code paths produce identical results.
667686
*/
668687
def getValuesInRange(
669688
key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
670689
val reusableGetValuesResult = new GetValuesResult()
690+
// Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires
691+
// an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue.
692+
val useRangeScan = maxTs < Long.MaxValue
671693

672694
new NextIterator[GetValuesResult] {
673-
private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
695+
private val iter = if (useRangeScan) {
696+
val startKey = createKeyRow(key, minTs).copy()
697+
// rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1
698+
val endKey = Some(createKeyRow(key, maxTs + 1))
699+
stateStore.rangeScanWithMultiValues(Some(startKey), endKey, colFamilyName)
700+
} else {
701+
stateStore.prefixScanWithMultiValues(key, colFamilyName)
702+
}
674703

675704
private var currentTs = -1L
676-
private var pastUpperBound = false
677705
private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()
678706

679707
private def flushAccumulated(): GetValuesResult = {
@@ -691,16 +719,16 @@ class SymmetricHashJoinStateManagerV4(
691719

692720
@tailrec
693721
override protected def getNext(): GetValuesResult = {
694-
if (pastUpperBound || !iter.hasNext) {
722+
if (!iter.hasNext) {
695723
flushAccumulated()
696724
} else {
697725
val unsafeRowPair = iter.next()
698726
val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
699727

700-
if (ts > maxTs) {
701-
pastUpperBound = true
702-
getNext()
703-
} else if (ts < minTs) {
728+
// Filter out entries outside [minTs, maxTs]. This is essential when using
729+
// prefixScan (which returns all timestamps for the key) and serves as a
730+
// safety guard for rangeScan as well.
731+
if (ts < minTs || ts > maxTs) {
704732
getNext()
705733
} else if (currentTs == -1L || currentTs == ts) {
706734
currentTs = ts
@@ -773,6 +801,8 @@ class SymmetricHashJoinStateManagerV4(
773801
isInternal = true
774802
)
775803

804+
// Returns an UnsafeRow backed by a reused projection buffer. Callers that need to
805+
// hold the row beyond the immediate state store call must invoke copy() on the result.
776806
private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
777807
TimestampKeyStateEncoder.attachTimestamp(
778808
attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
@@ -788,9 +818,66 @@ class SymmetricHashJoinStateManagerV4(
788818

789819
case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int)
790820

791-
// NOTE: This assumes we consume the whole iterator to trigger completion.
792-
def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
793-
val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
821+
private def defaultInternalRow(schema: StructType): InternalRow = {
822+
InternalRow.fromSeq(schema.map(f => defaultValueForType(f.dataType)))
823+
}
824+
825+
private def defaultValueForType(dt: DataType): Any = dt match {
826+
case BooleanType => false
827+
case ByteType => 0.toByte
828+
case ShortType => 0.toShort
829+
case IntegerType | DateType => 0
830+
case LongType | TimestampType | TimestampNTZType => 0L
831+
case FloatType => 0.0f
832+
case DoubleType => 0.0
833+
case StringType => UTF8String.EMPTY_UTF8
834+
case BinaryType => Array.emptyByteArray
835+
case st: StructType => defaultInternalRow(st)
836+
case _ => null
837+
}
838+
839+
/**
840+
* Build a scan boundary row for rangeScan. The TsWithKeyTypeStore uses
841+
* TimestampAsPrefixKeyStateEncoder, which encodes the row as [timestamp][key_fields].
842+
* We need a full-schema row (not just the timestamp) because the encoder expects all
843+
* key columns to be present. Default values are used for the key fields since only the
844+
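* timestamp matters for ordering in the prefix encoder. NOTE(review): this presumes the default-valued key never sorts after a real key sharing the boundary timestamp; types not handled by defaultValueForType yield null here -- confirm the encoded ordering (and non-nullable handling) for those.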
845+
*/
846+
private def createScanBoundaryRow(timestamp: Long): UnsafeRow = {
847+
val defaultKey = UnsafeProjection.create(keySchema)
848+
.apply(defaultInternalRow(keySchema))
849+
createKeyRow(defaultKey, timestamp).copy()
850+
}
851+
852+
/**
853+
* Scan keys eligible for eviction within the timestamp range.
854+
*
855+
* This assumes we consume the whole iterator to trigger completion.
856+
*
857+
* @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are
858+
* eligible for eviction.
859+
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp
860+
* are assumed to have been evicted already. The scan starts from startTimestamp + 1.
861+
*/
862+
def scanEvictedKeys(
863+
endTimestamp: Long,
864+
startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = {
865+
// rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive.
866+
// startTimestamp is exclusive (already evicted), so we seek from st + 1.
867+
val startKeyRow = startTimestamp.flatMap { st =>
868+
if (st < Long.MaxValue) Some(createScanBoundaryRow(st + 1))
869+
else None
870+
}
871+
// endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound.
872+
// When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is
873+
// safe because rangeScanWithMultiValues with no end key uses the column-family prefix
874+
// as the upper bound, naturally scoping the scan within this column family.
875+
val endKeyRow = if (endTimestamp < Long.MaxValue) {
876+
Some(createScanBoundaryRow(endTimestamp + 1))
877+
} else {
878+
None
879+
}
880+
val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName)
794881
new NextIterator[EvictedKeysResult]() {
795882
var currentKeyRow: UnsafeRow = null
796883
var currentEventTime: Long = -1L

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,13 +528,19 @@ class IncrementalExecution(
528528
case j: StreamingSymmetricHashJoinExec =>
529529
val iwLateEvents = inputWatermarkForLateEvents(j.stateInfo.get)
530530
val iwEviction = inputWatermarkForEviction(j.stateInfo.get)
531+
// Only use the late-events watermark as the scan lower bound when a previous
532+
// batch actually existed. In the very first batch the watermark propagation
533+
// yields Some(0) even though no state has been evicted yet, which would
534+
// incorrectly skip entries at timestamp 0.
535+
val prevBatchLateEventsWm =
536+
if (prevOffsetSeqMetadata.isDefined) iwLateEvents else None
531537
j.copy(
532538
eventTimeWatermarkForLateEvents = iwLateEvents,
533539
eventTimeWatermarkForEviction = iwEviction,
534540
stateWatermarkPredicates =
535541
StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
536542
j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
537-
iwEviction, !allowMultipleStatefulOperators)
543+
iwEviction, prevBatchLateEventsWm, !allowMultipleStatefulOperators)
538544
)
539545
}
540546
}

0 commit comments

Comments
 (0)