Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,93 @@ class RocksDB(
}
}

/**
 * Scan key-value pairs in the range [startKey, endKey).
 *
 * When column families are in use, both bounds are encoded with
 * `encodeStateRowWithPrefix` for `cfName`. If no end key is given in that mode, the
 * upper bound is whatever `RocksDB.prefixUpperBound` returns for the column family prefix.
 *
 * The native iterator, its `ReadOptions` and the upper-bound `Slice` are closed together
 * when the scan is exhausted, when the returned iterator is closed, or from a task
 * completion listener if a `TaskContext` is present. If creating or positioning the
 * native iterator fails, everything allocated so far is closed before the error is
 * rethrown.
 *
 * @param startKey None to seek to the beginning of the column family,
 *                 or Some(key) to seek to the given start position (inclusive).
 * @param endKey None to scan to the end of the column family,
 *               or Some(key) as the exclusive upper bound for the scan (encoded key bytes).
 * @param cfName The column family name.
 * @return An iterator of ByteArrayPairs in the given range. The same `byteArrayPair`
 *         instance is populated and returned for every entry.
 */
def rangeScan(
    startKey: Option[Array[Byte]],
    endKey: Option[Array[Byte]],
    cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = {
  updateMemoryUsageIfNeeded()

  // Exclusive upper bound handed to RocksDB; None means the scan is unbounded.
  val upperBoundBytes: Option[Array[Byte]] = endKey match {
    case Some(key) =>
      Some(if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key)
    case None if useColumnFamilies =>
      RocksDB.prefixUpperBound(encodeStateRowWithPrefix(Array.emptyByteArray, cfName))
    case None =>
      None
  }

  // Position to seek to; None means "start from the very first key".
  val seekTarget: Option[Array[Byte]] = startKey match {
    case Some(key) =>
      Some(if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key)
    case None if useColumnFamilies =>
      Some(encodeStateRowWithPrefix(Array.emptyByteArray, cfName))
    case None =>
      None
  }

  val upperBoundSlice = upperBoundBytes.map(new Slice(_))
  val scanReadOptions = new ReadOptions()

  def closeBounds(): Unit = {
    scanReadOptions.close()
    upperBoundSlice.foreach(_.close())
  }

  // Do not leak the ReadOptions / Slice if the native iterator cannot be created.
  val iter = try {
    upperBoundSlice.foreach(scanReadOptions.setIterateUpperBound)
    db.newIterator(scanReadOptions)
  } catch {
    case scala.util.control.NonFatal(e) =>
      closeBounds()
      throw e
  }

  // The iterator is closed first, then the options and the bound it was created with.
  def closeResources(): Unit = {
    iter.close()
    closeBounds()
  }

  // Likewise release everything if the initial positioning fails.
  try {
    seekTarget match {
      case Some(target) => iter.seek(target)
      case None => iter.seekToFirst()
    }
  } catch {
    case scala.util.control.NonFatal(e) =>
      closeResources()
      throw e
  }

  // Also release the resources on task completion, in case the caller never drains
  // or closes the returned iterator.
  Option(TaskContext.get()).foreach { tc =>
    tc.addTaskCompletionListener[Unit] { _ => closeResources() }
  }

  new NextIterator[ByteArrayPair] {
    override protected def getNext(): ByteArrayPair = {
      if (iter.isValid) {
        // Read the stored key once and reuse it for decoding and checksum verification.
        val rawKey = iter.key
        val key = if (useColumnFamilies) decodeStateRowWithPrefix(rawKey)._1 else rawKey

        val value = if (conf.rowChecksumEnabled) {
          KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
            readVerifier, rawKey, iter.value, delimiterSize)
        } else {
          iter.value
        }

        byteArrayPair.set(key, value)
        iter.next()
        byteArrayPair
      } else {
        finished = true
        closeResources()
        null
      }
    }

    override protected def close(): Unit = closeResources()
  }
}

// Intentionally does nothing.
def release(): Unit = ()

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import org.apache.spark.unsafe.Platform
sealed trait RocksDBKeyStateEncoder {
def supportPrefixKeyScan: Boolean
def supportsDeleteRange: Boolean
def supportsRangeScan: Boolean
def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
def encodeKey(row: UnsafeRow): Array[Byte]
def decodeKey(keyBytes: Array[Byte]): UnsafeRow
Expand Down Expand Up @@ -1500,6 +1501,8 @@ class PrefixKeyScanStateEncoder(
// Capability flag: this encoder reports that prefix key scan is supported.
override def supportPrefixKeyScan: Boolean = true

// Capability flag: delete-range is reported as unsupported.
override def supportsDeleteRange: Boolean = false

// Capability flag: range scan is reported as unsupported.
override def supportsRangeScan: Boolean = false
}

/**
Expand Down Expand Up @@ -1699,6 +1702,8 @@ class RangeKeyScanStateEncoder(
// Capability flag: this encoder reports that prefix key scan is supported.
override def supportPrefixKeyScan: Boolean = true

// Capability flag: delete-range is reported as supported.
override def supportsDeleteRange: Boolean = true

// Capability flag: range scan is reported as supported.
override def supportsRangeScan: Boolean = true
}

/**
Expand Down Expand Up @@ -1731,6 +1736,8 @@ class NoPrefixKeyStateEncoder(

// Capability flag: delete-range is reported as unsupported.
override def supportsDeleteRange: Boolean = false

// Capability flag: range scan is reported as unsupported.
override def supportsRangeScan: Boolean = false

// Prefix keys are not supported by this encoder: every call fails with
// an IllegalStateException, regardless of the argument.
override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] =
  throw new IllegalStateException("This encoder doesn't support prefix key!")
Expand Down Expand Up @@ -1884,6 +1891,8 @@ class TimestampAsPrefixKeyStateEncoder(

// Capability flag: delete-range is currently reported as unsupported.
// TODO: [SPARK-55491] Revisit this to support delete range if needed.
override def supportsDeleteRange: Boolean = false

// Capability flag: range scan is reported as supported.
override def supportsRangeScan: Boolean = true
}

/**
Expand Down Expand Up @@ -1932,6 +1941,8 @@ class TimestampAsPostfixKeyStateEncoder(
}

// Capability flag: delete-range is reported as unsupported.
override def supportsDeleteRange: Boolean = false

// Capability flag: range scan is reported as supported.
override def supportsRangeScan: Boolean = true
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,68 @@ private[sql] class RocksDBStateStoreProvider
new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
}

/**
 * Scan the entries of `colFamilyName` whose keys fall in [startKey, endKey).
 *
 * The bounds are encoded with the column family's key encoder and handed to
 * `rocksDB.rangeScan`; each returned entry is decoded back into rows. A single
 * `UnsafeRowPair` instance is reused for every element of the returned iterator.
 *
 * The first non-null value seen is checked with `StateStoreProvider.validateStateRowFormat`
 * when column families are not in use and no validation has happened yet, matching what
 * `rangeScanWithMultiValues` does.
 *
 * @param startKey None to start from the beginning, or Some(key) as the inclusive start.
 * @param endKey None to scan to the end, or Some(key) as the exclusive upper bound.
 * @param colFamilyName The column family name.
 * @throws IllegalArgumentException (via `require`) if the column family's key encoder
 *                                  reports `supportsRangeScan == false`.
 */
override def rangeScan(
    startKey: Option[UnsafeRow],
    endKey: Option[UnsafeRow],
    colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
  validateAndTransitionState(UPDATE)
  verifyColFamilyOperations("rangeScan", colFamilyName)

  val kvEncoder = keyValueEncoderMap.get(colFamilyName)
  require(kvEncoder._1.supportsRangeScan,
    "Range scan requires an encoder that supports range scanning!")

  val encodedStartKey = startKey.map(kvEncoder._1.encodeKey)
  val encodedEndKey = endKey.map(kvEncoder._1.encodeKey)

  val rowPair = new UnsafeRowPair()
  val rocksDbIter = rocksDB.rangeScan(encodedStartKey, encodedEndKey, colFamilyName)
  val iter = rocksDbIter.map { kv =>
    rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
      kvEncoder._2.decodeValue(kv.value))
    // Validate the row format once, the same way rangeScanWithMultiValues does.
    if (!isValidated && rowPair.value != null && !useColumnFamilies) {
      StateStoreProvider.validateStateRowFormat(
        rowPair.key, keySchema, rowPair.value, valueSchema, stateStoreId, storeConf)
      isValidated = true
    }
    rowPair
  }

  // Closing the returned iterator closes the underlying RocksDB iterator if still open.
  new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
}

/**
 * Scan the entries of `colFamilyName` whose keys fall in [startKey, endKey), emitting
 * one key-value pair per decoded value of each entry.
 *
 * The bounds are encoded with the column family's key encoder and handed to
 * `rocksDB.rangeScan`. A single `UnsafeRowPair` instance is reused for every element of
 * the returned iterator.
 *
 * @param startKey None to start from the beginning, or Some(key) as the inclusive start.
 * @param endKey None to scan to the end, or Some(key) as the exclusive upper bound.
 * @param colFamilyName The column family name.
 * @throws IllegalArgumentException (via `require`) if the column family's key encoder
 *                                  reports `supportsRangeScan == false`.
 */
override def rangeScanWithMultiValues(
    startKey: Option[UnsafeRow],
    endKey: Option[UnsafeRow],
    colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
  validateAndTransitionState(UPDATE)
  verifyColFamilyOperations("rangeScanWithMultiValues", colFamilyName)

  val kvEncoder = keyValueEncoderMap.get(colFamilyName)
  require(kvEncoder._1.supportsRangeScan,
    "Range scan with multiple values requires an encoder that supports range scanning!")
  // The value encoder must also be able to hold several values under one key.
  verify(
    kvEncoder._2.supportsMultipleValuesPerKey,
    "Multi-value iterator operation requires an encoder" +
      " which supports multiple values for a single key")

  val encodedStartKey = startKey.map(kvEncoder._1.encodeKey)
  val encodedEndKey = endKey.map(kvEncoder._1.encodeKey)
  val rocksDbIter = rocksDB.rangeScan(encodedStartKey, encodedEndKey, colFamilyName)

  val rowPair = new UnsafeRowPair()
  val iter = rocksDbIter.flatMap { kv =>
    // Decode the key once per entry, then pair it with each decoded value.
    val keyRow = kvEncoder._1.decodeKey(kv.key)
    val valueRows = kvEncoder._2.decodeValues(kv.value)
    valueRows.iterator.map { valueRow =>
      rowPair.withRows(keyRow, valueRow)
      // Validate the row format once: only on the first non-null value, and only
      // when column families are not in use.
      if (!isValidated && rowPair.value != null && !useColumnFamilies) {
        StateStoreProvider.validateStateRowFormat(
          rowPair.key, keySchema, rowPair.value, valueSchema, stateStoreId, storeConf)
        isValidated = true
      }
      rowPair
    }
  }

  // Closing the returned iterator closes the underlying RocksDB iterator if still open.
  new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
}

// Optional checkpoint info for this store; starts out as None.
// NOTE(review): where this gets populated is not visible here — confirm before relying on it.
var checkpointInfo: Option[StateStoreCheckpointInfo] = None
// Optional stored RocksDB metrics; starts out as None.
private var storedMetrics: Option[RocksDBMetrics] = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,51 @@ trait ReadStateStore {
prefixKey: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair]

/**
 * Iterate over the key-value pairs whose keys lie in the half-open range
 * [startKey, endKey).
 *
 * This default implementation does not perform a scan: it always throws the error built
 * by `StateStoreErrors.unsupportedOperationException` for the operation "rangeScan".
 *
 * Callers must ensure the column family's key encoder produces lexicographically ordered
 * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or
 * RangeKeyScanStateEncoder).
 *
 * @param startKey None to scan from the beginning of the column family,
 *                 or Some(key) to seek to the given start position (inclusive).
 * @param endKey None to scan to the end of the column family,
 *               or Some(key) as the exclusive upper bound for the scan.
 * @param colFamilyName The column family name.
 */
def rangeScan(
    startKey: Option[UnsafeRow],
    endKey: Option[UnsafeRow],
    colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
  : StateStoreIterator[UnsafeRowPair] =
  throw StateStoreErrors.unsupportedOperationException("rangeScan", "")

/**
 * Iterate over the key-value pairs whose keys lie in the half-open range
 * [startKey, endKey), expanding multi-valued entries.
 *
 * This default implementation does not perform a scan: it always throws the error built
 * by `StateStoreErrors.unsupportedOperationException` for the operation
 * "rangeScanWithMultiValues".
 *
 * Callers must ensure the column family's key encoder produces lexicographically ordered
 * bytes for the scan range to be meaningful (e.g., timestamp-based encoders or
 * RangeKeyScanStateEncoder).
 *
 * It is expected to throw exception if Spark calls this method without setting
 * multipleValuesPerKey as true for the column family.
 *
 * @param startKey None to scan from the beginning of the column family,
 *                 or Some(key) to seek to the given start position (inclusive).
 * @param endKey None to scan to the end of the column family,
 *               or Some(key) as the exclusive upper bound for the scan.
 * @param colFamilyName The column family name.
 */
def rangeScanWithMultiValues(
    startKey: Option[UnsafeRow],
    endKey: Option[UnsafeRow],
    colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
  : StateStoreIterator[UnsafeRowPair] =
  throw StateStoreErrors.unsupportedOperationException("rangeScanWithMultiValues", "")

/**
 * Return an iterator containing all the key-value pairs in the StateStore.
 *
 * @param colFamilyName The column family name; defaults to
 *                      `StateStore.DEFAULT_COL_FAMILY_NAME`.
 */
def iterator(
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair]
Expand Down Expand Up @@ -411,6 +456,20 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
store.prefixScanWithMultiValues(prefixKey, colFamilyName)
}

// Pure delegation: forwards the arguments unchanged to the wrapped store's rangeScan.
override def rangeScan(
    startKey: Option[UnsafeRow],
    endKey: Option[UnsafeRow],
    colFamilyName: String): StateStoreIterator[UnsafeRowPair] =
  store.rangeScan(startKey, endKey, colFamilyName)

// Pure delegation: forwards the arguments unchanged to the wrapped store's
// rangeScanWithMultiValues.
override def rangeScanWithMultiValues(
    startKey: Option[UnsafeRow],
    endKey: Option[UnsafeRow],
    colFamilyName: String): StateStoreIterator[UnsafeRowPair] =
  store.rangeScanWithMultiValues(startKey, endKey, colFamilyName)

override def iteratorWithMultiValues(
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
store.iteratorWithMultiValues(colFamilyName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
innerStore.prefixScanWithMultiValues(prefixKey, colFamilyName)
}

// Pure delegation: forwards the arguments unchanged to innerStore.rangeScan.
override def rangeScan(
    startKey: Option[UnsafeRow],
    endKey: Option[UnsafeRow],
    colFamilyName: String): StateStoreIterator[UnsafeRowPair] =
  innerStore.rangeScan(startKey, endKey, colFamilyName)

// Pure delegation: forwards the arguments unchanged to
// innerStore.rangeScanWithMultiValues.
override def rangeScanWithMultiValues(
    startKey: Option[UnsafeRow],
    endKey: Option[UnsafeRow],
    colFamilyName: String): StateStoreIterator[UnsafeRowPair] =
  innerStore.rangeScanWithMultiValues(startKey, endKey, colFamilyName)

override def iteratorWithMultiValues(
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
innerStore.iteratorWithMultiValues(colFamilyName)
Expand Down
Loading