Skip to content

Commit 137c25b

Browse files
committed
WIP introduce scan to scope the iteration in State Store
1 parent 73a272d commit 137c25b

6 files changed

Lines changed: 457 additions & 23 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.apache.spark.TaskContext
2727
import org.apache.spark.internal.Logging
2828
import org.apache.spark.internal.LogKeys.{END_INDEX, START_INDEX, STATE_STORE_ID}
2929
import org.apache.spark.sql.catalyst.InternalRow
30-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
30+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow}
3131
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
3232
import org.apache.spark.sql.execution.metric.SQLMetric
3333
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
@@ -647,17 +647,25 @@ class SymmetricHashJoinStateManagerV4(
647647

648648
/**
649649
* Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
650-
* Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
650+
*
651+
* When a bounded range is provided, leverages RocksDB's native seek and upper bound via
652+
* [[StateStore.scanWithMultiValues]] to avoid reading entries outside the range.
653+
* Falls back to [[StateStore.prefixScanWithMultiValues]] when the full range is requested.
651654
*/
652655
def getValuesInRange(
653656
key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
654657
val reusableGetValuesResult = new GetValuesResult()
655658

656659
new NextIterator[GetValuesResult] {
657-
private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
660+
private val iter = if (minTs == Long.MinValue && maxTs == Long.MaxValue) {
661+
stateStore.prefixScanWithMultiValues(key, colFamilyName)
662+
} else {
663+
val startKeyRow = createKeyRow(key, minTs)
664+
val endKeyRow = createKeyRow(key, maxTs + 1)
665+
stateStore.scanWithMultiValues(Some(startKeyRow), endKeyRow, colFamilyName)
666+
}
658667

659668
private var currentTs = -1L
660-
private var pastUpperBound = false
661669
private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()
662670

663671
private def flushAccumulated(): GetValuesResult = {
@@ -675,18 +683,13 @@ class SymmetricHashJoinStateManagerV4(
675683

676684
@tailrec
677685
override protected def getNext(): GetValuesResult = {
678-
if (pastUpperBound || !iter.hasNext) {
686+
if (!iter.hasNext) {
679687
flushAccumulated()
680688
} else {
681689
val unsafeRowPair = iter.next()
682690
val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
683691

684-
if (ts > maxTs) {
685-
pastUpperBound = true
686-
getNext()
687-
} else if (ts < minTs) {
688-
getNext()
689-
} else if (currentTs == -1L || currentTs == ts) {
692+
if (currentTs == -1L || currentTs == ts) {
690693
currentTs = ts
691694
valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
692695
getNext()
@@ -770,11 +773,17 @@ class SymmetricHashJoinStateManagerV4(
770773
stateStore.remove(createKeyRow(key, timestamp), colFamilyName)
771774
}
772775

776+
private lazy val dummyKeyRow: UnsafeRow = {
777+
val projection = UnsafeProjection.create(keySchema)
778+
projection(new GenericInternalRow(keySchema.length)).copy()
779+
}
780+
773781
case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int)
774782

775783
// NOTE: This assumes we consume the whole iterator to trigger completion.
776784
def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
777-
val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
785+
val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1)
786+
val evictIterator = stateStore.scanWithMultiValues(None, endKeyRow, colFamilyName)
778787
new NextIterator[EvictedKeysResult]() {
779788
var currentKeyRow: UnsafeRow = null
780789
var currentEventTime: Long = -1L
@@ -789,25 +798,19 @@ class SymmetricHashJoinStateManagerV4(
789798
val ts = TimestampKeyStateEncoder.extractTimestamp(kv.key)
790799

791800
if (keyRow == currentKeyRow && ts == currentEventTime) {
792-
// new value with same (key, ts)
793801
count += 1
794802
} else if (ts > endTimestamp) {
795-
// we found the timestamp beyond the range - we shouldn't continue further
803+
// Safety check for boundary edge case: a small number of entries at exactly
804+
// endTimestamp + 1 may leak through the upper bound because the encoded end key
805+
// includes a join key suffix.
796806
isBeyondUpperBound = true
797-
798-
// We don't need to construct the last (key, ts) into EvictedKeysResult - the code
799-
// after loop will handle that if there is leftover. That said, we do not reset the
800-
// current (key, ts) info here.
801807
} else if (currentKeyRow == null && currentEventTime == -1L) {
802-
// first value to process
803808
currentKeyRow = keyRow.copy()
804809
currentEventTime = ts
805810
count = 1
806811
} else {
807-
// construct the last (key, ts) into EvictedKeysResult
808812
ret = EvictedKeysResult(currentKeyRow, currentEventTime, count)
809813

810-
// register the next (key, ts) to process
811814
currentKeyRow = keyRow.copy()
812815
currentEventTime = ts
813816
count = 1
@@ -817,10 +820,8 @@ class SymmetricHashJoinStateManagerV4(
817820
if (ret != null) {
818821
ret
819822
} else if (count > 0) {
820-
// there is a final leftover (key, ts) to return
821823
ret = EvictedKeysResult(currentKeyRow, currentEventTime, count)
822824

823-
// we shouldn't continue further
824825
currentKeyRow = null
825826
currentEventTime = -1L
826827
count = 0

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,6 +1662,86 @@ class RocksDB(
16621662
}
16631663
}
16641664

1665+
/**
1666+
* Scan key-value pairs in the range [startKey, endKey).
1667+
*
1668+
* @param startKey None to seek to the beginning of the column family,
1669+
* or Some(key) to seek to the given start position (inclusive).
1670+
* @param endKey The exclusive upper bound for the scan (encoded key bytes).
1671+
* @param cfName The column family name.
1672+
* @return An iterator of ByteArrayPairs in the given range.
1673+
*/
1674+
def scan(
1675+
startKey: Option[Array[Byte]],
1676+
endKey: Array[Byte],
1677+
cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = {
1678+
updateMemoryUsageIfNeeded()
1679+
1680+
val updatedEndKey = if (useColumnFamilies) {
1681+
encodeStateRowWithPrefix(endKey, cfName)
1682+
} else {
1683+
endKey
1684+
}
1685+
1686+
val seekTarget = startKey match {
1687+
case Some(key) =>
1688+
if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key
1689+
case None =>
1690+
if (useColumnFamilies) encodeStateRowWithPrefix(Array.emptyByteArray, cfName)
1691+
else null
1692+
}
1693+
1694+
val upperBoundSlice = new Slice(updatedEndKey)
1695+
val scanReadOptions = new ReadOptions()
1696+
scanReadOptions.setIterateUpperBound(upperBoundSlice)
1697+
1698+
val iter = db.newIterator(scanReadOptions)
1699+
if (seekTarget != null) {
1700+
iter.seek(seekTarget)
1701+
} else {
1702+
iter.seekToFirst()
1703+
}
1704+
1705+
def closeResources(): Unit = {
1706+
iter.close()
1707+
scanReadOptions.close()
1708+
upperBoundSlice.close()
1709+
}
1710+
1711+
Option(TaskContext.get()).foreach { tc =>
1712+
tc.addTaskCompletionListener[Unit] { _ => closeResources() }
1713+
}
1714+
1715+
new NextIterator[ByteArrayPair] {
1716+
override protected def getNext(): ByteArrayPair = {
1717+
if (iter.isValid) {
1718+
val key = if (useColumnFamilies) {
1719+
decodeStateRowWithPrefix(iter.key)._1
1720+
} else {
1721+
iter.key
1722+
}
1723+
1724+
val value = if (conf.rowChecksumEnabled) {
1725+
KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
1726+
readVerifier, iter.key, iter.value, delimiterSize)
1727+
} else {
1728+
iter.value
1729+
}
1730+
1731+
byteArrayPair.set(key, value)
1732+
iter.next()
1733+
byteArrayPair
1734+
} else {
1735+
finished = true
1736+
closeResources()
1737+
null
1738+
}
1739+
}
1740+
1741+
override protected def close(): Unit = closeResources()
1742+
}
1743+
}
1744+
16651745
def release(): Unit = {}
16661746

16671747
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,63 @@ private[sql] class RocksDBStateStoreProvider
549549
new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
550550
}
551551

552+
override def scan(
553+
startKey: Option[UnsafeRow],
554+
endKey: UnsafeRow,
555+
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
556+
validateAndTransitionState(UPDATE)
557+
verifyColFamilyOperations("scan", colFamilyName)
558+
559+
val kvEncoder = keyValueEncoderMap.get(colFamilyName)
560+
val encodedStartKey = startKey.map(kvEncoder._1.encodeKey)
561+
val encodedEndKey = kvEncoder._1.encodeKey(endKey)
562+
563+
val rowPair = new UnsafeRowPair()
564+
val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName)
565+
val iter = rocksDbIter.map { kv =>
566+
rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
567+
kvEncoder._2.decodeValue(kv.value))
568+
rowPair
569+
}
570+
571+
new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
572+
}
573+
574+
override def scanWithMultiValues(
575+
startKey: Option[UnsafeRow],
576+
endKey: UnsafeRow,
577+
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
578+
validateAndTransitionState(UPDATE)
579+
verifyColFamilyOperations("scanWithMultiValues", colFamilyName)
580+
581+
val kvEncoder = keyValueEncoderMap.get(colFamilyName)
582+
verify(
583+
kvEncoder._2.supportsMultipleValuesPerKey,
584+
"Multi-value iterator operation requires an encoder" +
585+
" which supports multiple values for a single key")
586+
587+
val encodedStartKey = startKey.map(kvEncoder._1.encodeKey)
588+
val encodedEndKey = kvEncoder._1.encodeKey(endKey)
589+
val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName)
590+
591+
val rowPair = new UnsafeRowPair()
592+
val iter = rocksDbIter.flatMap { kv =>
593+
val keyRow = kvEncoder._1.decodeKey(kv.key)
594+
val valueRows = kvEncoder._2.decodeValues(kv.value)
595+
valueRows.iterator.map { valueRow =>
596+
rowPair.withRows(keyRow, valueRow)
597+
if (!isValidated && rowPair.value != null && !useColumnFamilies) {
598+
StateStoreProvider.validateStateRowFormat(
599+
rowPair.key, keySchema, rowPair.value, valueSchema, stateStoreId, storeConf)
600+
isValidated = true
601+
}
602+
rowPair
603+
}
604+
}
605+
606+
new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
607+
}
608+
552609
var checkpointInfo: Option[StateStoreCheckpointInfo] = None
553610
private var storedMetrics: Option[RocksDBMetrics] = None
554611

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,49 @@ trait ReadStateStore {
183183
prefixKey: UnsafeRow,
184184
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair]
185185

186+
/**
187+
* Scan key-value pairs in the range [startKey, endKey).
188+
*
189+
* @param startKey None to scan from the beginning of the column family,
190+
* or Some(key) to seek to the given start position (inclusive).
191+
* @param endKey The exclusive upper bound for the scan.
192+
* @param colFamilyName The column family name.
193+
*
194+
* Callers must ensure the column family's key encoder produces lexicographically ordered
195+
* bytes for the scan range to be meaningful (e.g., timestamp-based encoders or
196+
* RangeKeyScanStateEncoder).
197+
*/
198+
def scan(
199+
startKey: Option[UnsafeRow],
200+
endKey: UnsafeRow,
201+
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
202+
: StateStoreIterator[UnsafeRowPair] = {
203+
throw StateStoreErrors.unsupportedOperationException("scan", "")
204+
}
205+
206+
/**
207+
* Scan key-value pairs in the range [startKey, endKey), expanding multi-valued entries.
208+
*
209+
* @param startKey None to scan from the beginning of the column family,
210+
* or Some(key) to seek to the given start position (inclusive).
211+
* @param endKey The exclusive upper bound for the scan.
212+
* @param colFamilyName The column family name.
213+
*
214+
* Callers must ensure the column family's key encoder produces lexicographically ordered
215+
* bytes for the scan range to be meaningful (e.g., timestamp-based encoders or
216+
* RangeKeyScanStateEncoder).
217+
*
218+
* It is expected to throw exception if Spark calls this method without setting
219+
* multipleValuesPerKey as true for the column family.
220+
*/
221+
def scanWithMultiValues(
222+
startKey: Option[UnsafeRow],
223+
endKey: UnsafeRow,
224+
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
225+
: StateStoreIterator[UnsafeRowPair] = {
226+
throw StateStoreErrors.unsupportedOperationException("scanWithMultiValues", "")
227+
}
228+
186229
/** Return an iterator containing all the key-value pairs in the StateStore. */
187230
def iterator(
188231
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair]
@@ -411,6 +454,20 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
411454
store.prefixScanWithMultiValues(prefixKey, colFamilyName)
412455
}
413456

457+
override def scan(
458+
startKey: Option[UnsafeRow],
459+
endKey: UnsafeRow,
460+
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
461+
store.scan(startKey, endKey, colFamilyName)
462+
}
463+
464+
override def scanWithMultiValues(
465+
startKey: Option[UnsafeRow],
466+
endKey: UnsafeRow,
467+
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
468+
store.scanWithMultiValues(startKey, endKey, colFamilyName)
469+
}
470+
414471
override def iteratorWithMultiValues(
415472
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
416473
store.iteratorWithMultiValues(colFamilyName)

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
172172
innerStore.prefixScanWithMultiValues(prefixKey, colFamilyName)
173173
}
174174

175+
override def scan(
176+
startKey: Option[UnsafeRow],
177+
endKey: UnsafeRow,
178+
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
179+
innerStore.scan(startKey, endKey, colFamilyName)
180+
}
181+
182+
override def scanWithMultiValues(
183+
startKey: Option[UnsafeRow],
184+
endKey: UnsafeRow,
185+
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
186+
innerStore.scanWithMultiValues(startKey, endKey, colFamilyName)
187+
}
188+
175189
override def iteratorWithMultiValues(
176190
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
177191
innerStore.iteratorWithMultiValues(colFamilyName)

0 commit comments

Comments
 (0)