Skip to content

Commit 158dff8

Browse files
committed
Roll back the usage on stream-stream join, add test for range scan
1 parent ef84c49 commit 158dff8

2 files changed

Lines changed: 356 additions & 25 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.TaskContext
2727
import org.apache.spark.internal.Logging
2828
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID}
2929
import org.apache.spark.sql.catalyst.InternalRow
30-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
30+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
3131
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
3232
import org.apache.spark.sql.execution.metric.SQLMetric
3333
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
@@ -647,25 +647,17 @@ class SymmetricHashJoinStateManagerV4(
647647

648648
/**
649649
* Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
650-
*
651-
* When a bounded range is provided, leverages RocksDB's native seek and upper bound via
652-
* [[StateStore.scanWithMultiValues]] to avoid reading entries outside the range.
653-
* Falls back to [[StateStore.prefixScanWithMultiValues]] when the full range is requested.
650+
* Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
654651
*/
655652
def getValuesInRange(
656653
key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
657654
val reusableGetValuesResult = new GetValuesResult()
658655

659656
new NextIterator[GetValuesResult] {
660-
private val iter = if (minTs == Long.MinValue && maxTs == Long.MaxValue) {
661-
stateStore.prefixScanWithMultiValues(key, colFamilyName)
662-
} else {
663-
val startKeyRow = createKeyRow(key, minTs).copy()
664-
val endKeyRow = createKeyRow(key, maxTs + 1)
665-
stateStore.scanWithMultiValues(Some(startKeyRow), Some(endKeyRow), colFamilyName)
666-
}
657+
private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
667658

668659
private var currentTs = -1L
660+
private var pastUpperBound = false
669661
private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()
670662

671663
private def flushAccumulated(): GetValuesResult = {
@@ -683,13 +675,18 @@ class SymmetricHashJoinStateManagerV4(
683675

684676
@tailrec
685677
override protected def getNext(): GetValuesResult = {
686-
if (!iter.hasNext) {
678+
if (pastUpperBound || !iter.hasNext) {
687679
flushAccumulated()
688680
} else {
689681
val unsafeRowPair = iter.next()
690682
val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
691683

692-
if (currentTs == -1L || currentTs == ts) {
684+
if (ts > maxTs) {
685+
pastUpperBound = true
686+
getNext()
687+
} else if (ts < minTs) {
688+
getNext()
689+
} else if (currentTs == -1L || currentTs == ts) {
693690
currentTs = ts
694691
valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
695692
getNext()
@@ -773,18 +770,11 @@ class SymmetricHashJoinStateManagerV4(
773770
stateStore.remove(createKeyRow(key, timestamp), colFamilyName)
774771
}
775772

776-
private lazy val dummyKeyRow: UnsafeRow = {
777-
val defaultValues = keySchema.fields.map(f => Literal.default(f.dataType).eval())
778-
val projection = UnsafeProjection.create(keySchema)
779-
projection(new GenericInternalRow(defaultValues)).copy()
780-
}
781-
782773
case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int)
783774

784775
// NOTE: This assumes we consume the whole iterator to trigger completion.
785776
def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
786-
val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1)
787-
val evictIterator = stateStore.scanWithMultiValues(None, Some(endKeyRow), colFamilyName)
777+
val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
788778
new NextIterator[EvictedKeysResult]() {
789779
var currentKeyRow: UnsafeRow = null
790780
var currentEventTime: Long = -1L
@@ -799,19 +789,25 @@ class SymmetricHashJoinStateManagerV4(
799789
val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key)
800790

801791
if (keyRow == currentKeyRow && ts == currentEventTime) {
792+
// new value with same (key, ts)
802793
count += 1
803794
} else if (ts > endTimestamp) {
804-
// Safety check for boundary edge case: a small number of entries at exactly
805-
// endTimestamp + 1 may leak through the upper bound because the encoded end key
806-
// includes a join key suffix.
795+
// we found the timestamp beyond the range - we shouldn't continue further
807796
isBeyondUpperBound = true
797+
798+
// We don't need to construct the last (key, ts) into EvictedKeysResult - the code
799+
// after the loop will handle that if there is a leftover. For that reason, we do not reset the
800+
// current (key, ts) info here.
808801
} else if (currentKeyRow == null && currentEventTime == -1L) {
802+
// first value to process
809803
currentKeyRow = keyRow.copy()
810804
currentEventTime = ts
811805
count = 1
812806
} else {
807+
// construct the last (key, ts) into EvictedKeysResult
813808
ret = EvictedKeysResult(currentKeyRow, currentEventTime, count)
814809

810+
// register the next (key, ts) to process
815811
currentKeyRow = keyRow.copy()
816812
currentEventTime = ts
817813
count = 1
@@ -821,8 +817,10 @@ class SymmetricHashJoinStateManagerV4(
821817
if (ret != null) {
822818
ret
823819
} else if (count > 0) {
820+
// there is a final leftover (key, ts) to return
824821
ret = EvictedKeysResult(currentKeyRow, currentEventTime, count)
825822

823+
// we shouldn't continue further
826824
currentKeyRow = null
827825
currentEventTime = -1L
828826
count = 0

0 commit comments

Comments
 (0)