Skip to content

Commit ef84c49

Browse files
committed
change the scan's endKey to option
1 parent 07e241d commit ef84c49

6 files changed

Lines changed: 39 additions & 30 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ class SymmetricHashJoinStateManagerV4(
662662
} else {
663663
val startKeyRow = createKeyRow(key, minTs).copy()
664664
val endKeyRow = createKeyRow(key, maxTs + 1)
665-
stateStore.scanWithMultiValues(Some(startKeyRow), endKeyRow, colFamilyName)
665+
stateStore.scanWithMultiValues(Some(startKeyRow), Some(endKeyRow), colFamilyName)
666666
}
667667

668668
private var currentTs = -1L
@@ -784,7 +784,7 @@ class SymmetricHashJoinStateManagerV4(
784784
// NOTE: This assumes we consume the whole iterator to trigger completion.
785785
def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
786786
val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1)
787-
val evictIterator = stateStore.scanWithMultiValues(None, endKeyRow, colFamilyName)
787+
val evictIterator = stateStore.scanWithMultiValues(None, Some(endKeyRow), colFamilyName)
788788
new NextIterator[EvictedKeysResult]() {
789789
var currentKeyRow: UnsafeRow = null
790790
var currentEventTime: Long = -1L

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,20 +1667,27 @@ class RocksDB(
16671667
*
16681668
* @param startKey None to seek to the beginning of the column family,
16691669
* or Some(key) to seek to the given start position (inclusive).
1670-
* @param endKey The exclusive upper bound for the scan (encoded key bytes).
1670+
* @param endKey None to scan to the end of the column family,
1671+
* or Some(key) as the exclusive upper bound for the scan (encoded key bytes).
16711672
* @param cfName The column family name.
16721673
* @return An iterator of ByteArrayPairs in the given range.
16731674
*/
16741675
def scan(
16751676
startKey: Option[Array[Byte]],
1676-
endKey: Array[Byte],
1677+
endKey: Option[Array[Byte]],
16771678
cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = {
16781679
updateMemoryUsageIfNeeded()
16791680

1680-
val updatedEndKey = if (useColumnFamilies) {
1681-
encodeStateRowWithPrefix(endKey, cfName)
1682-
} else {
1683-
endKey
1681+
val upperBoundBytes: Option[Array[Byte]] = endKey match {
1682+
case Some(key) =>
1683+
Some(if (useColumnFamilies) encodeStateRowWithPrefix(key, cfName) else key)
1684+
case None =>
1685+
if (useColumnFamilies) {
1686+
val cfPrefix = encodeStateRowWithPrefix(Array.emptyByteArray, cfName)
1687+
RocksDB.prefixUpperBound(cfPrefix)
1688+
} else {
1689+
None
1690+
}
16841691
}
16851692

16861693
val seekTarget = startKey match {
@@ -1691,9 +1698,9 @@ class RocksDB(
16911698
else null
16921699
}
16931700

1694-
val upperBoundSlice = new Slice(updatedEndKey)
1701+
val upperBoundSlice = upperBoundBytes.map(new Slice(_))
16951702
val scanReadOptions = new ReadOptions()
1696-
scanReadOptions.setIterateUpperBound(upperBoundSlice)
1703+
upperBoundSlice.foreach(scanReadOptions.setIterateUpperBound)
16971704

16981705
val iter = db.newIterator(scanReadOptions)
16991706
if (seekTarget != null) {
@@ -1705,7 +1712,7 @@ class RocksDB(
17051712
def closeResources(): Unit = {
17061713
iter.close()
17071714
scanReadOptions.close()
1708-
upperBoundSlice.close()
1715+
upperBoundSlice.foreach(_.close())
17091716
}
17101717

17111718
Option(TaskContext.get()).foreach { tc =>

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -551,14 +551,14 @@ private[sql] class RocksDBStateStoreProvider
551551

552552
override def scan(
553553
startKey: Option[UnsafeRow],
554-
endKey: UnsafeRow,
554+
endKey: Option[UnsafeRow],
555555
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
556556
validateAndTransitionState(UPDATE)
557557
verifyColFamilyOperations("scan", colFamilyName)
558558

559559
val kvEncoder = keyValueEncoderMap.get(colFamilyName)
560560
val encodedStartKey = startKey.map(kvEncoder._1.encodeKey)
561-
val encodedEndKey = kvEncoder._1.encodeKey(endKey)
561+
val encodedEndKey = endKey.map(kvEncoder._1.encodeKey)
562562

563563
val rowPair = new UnsafeRowPair()
564564
val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName)
@@ -573,7 +573,7 @@ private[sql] class RocksDBStateStoreProvider
573573

574574
override def scanWithMultiValues(
575575
startKey: Option[UnsafeRow],
576-
endKey: UnsafeRow,
576+
endKey: Option[UnsafeRow],
577577
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
578578
validateAndTransitionState(UPDATE)
579579
verifyColFamilyOperations("scanWithMultiValues", colFamilyName)
@@ -585,7 +585,7 @@ private[sql] class RocksDBStateStoreProvider
585585
" which supports multiple values for a single key")
586586

587587
val encodedStartKey = startKey.map(kvEncoder._1.encodeKey)
588-
val encodedEndKey = kvEncoder._1.encodeKey(endKey)
588+
val encodedEndKey = endKey.map(kvEncoder._1.encodeKey)
589589
val rocksDbIter = rocksDB.scan(encodedStartKey, encodedEndKey, colFamilyName)
590590

591591
val rowPair = new UnsafeRowPair()

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ trait ReadStateStore {
188188
*
189189
* @param startKey None to scan from the beginning of the column family,
190190
* or Some(key) to seek to the given start position (inclusive).
191-
* @param endKey The exclusive upper bound for the scan.
191+
* @param endKey None to scan to the end of the column family,
192+
* or Some(key) as the exclusive upper bound for the scan.
192193
* @param colFamilyName The column family name.
193194
*
194195
* Callers must ensure the column family's key encoder produces lexicographically ordered
@@ -197,7 +198,7 @@ trait ReadStateStore {
197198
*/
198199
def scan(
199200
startKey: Option[UnsafeRow],
200-
endKey: UnsafeRow,
201+
endKey: Option[UnsafeRow],
201202
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
202203
: StateStoreIterator[UnsafeRowPair] = {
203204
throw StateStoreErrors.unsupportedOperationException("scan", "")
@@ -208,7 +209,8 @@ trait ReadStateStore {
208209
*
209210
* @param startKey None to scan from the beginning of the column family,
210211
* or Some(key) to seek to the given start position (inclusive).
211-
* @param endKey The exclusive upper bound for the scan.
212+
* @param endKey None to scan to the end of the column family,
213+
* or Some(key) as the exclusive upper bound for the scan.
212214
* @param colFamilyName The column family name.
213215
*
214216
* Callers must ensure the column family's key encoder produces lexicographically ordered
@@ -220,7 +222,7 @@ trait ReadStateStore {
220222
*/
221223
def scanWithMultiValues(
222224
startKey: Option[UnsafeRow],
223-
endKey: UnsafeRow,
225+
endKey: Option[UnsafeRow],
224226
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
225227
: StateStoreIterator[UnsafeRowPair] = {
226228
throw StateStoreErrors.unsupportedOperationException("scanWithMultiValues", "")
@@ -456,14 +458,14 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
456458

457459
override def scan(
458460
startKey: Option[UnsafeRow],
459-
endKey: UnsafeRow,
461+
endKey: Option[UnsafeRow],
460462
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
461463
store.scan(startKey, endKey, colFamilyName)
462464
}
463465

464466
override def scanWithMultiValues(
465467
startKey: Option[UnsafeRow],
466-
endKey: UnsafeRow,
468+
endKey: Option[UnsafeRow],
467469
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
468470
store.scanWithMultiValues(startKey, endKey, colFamilyName)
469471
}

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,14 +174,14 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta
174174

175175
override def scan(
176176
startKey: Option[UnsafeRow],
177-
endKey: UnsafeRow,
177+
endKey: Option[UnsafeRow],
178178
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
179179
innerStore.scan(startKey, endKey, colFamilyName)
180180
}
181181

182182
override def scanWithMultiValues(
183183
startKey: Option[UnsafeRow],
184-
endKey: UnsafeRow,
184+
endKey: Option[UnsafeRow],
185185
colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
186186
innerStore.scanWithMultiValues(startKey, endKey, colFamilyName)
187187
}

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBTimestampEncoderOperationsSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession
591591
val prefixKey = keyToRow("key1", 1)
592592
val startKey = keyAndTimestampToRow("key1", 1, 0L)
593593
val endKey = keyAndTimestampToRow("key1", 1, 1001L)
594-
val iter = store.scanWithMultiValues(Some(startKey), endKey)
594+
val iter = store.scanWithMultiValues(Some(startKey), Some(endKey))
595595

596596
val results = iter.map { pair =>
597597
(pair.key.getLong(2), pair.value.getInt(0))
@@ -628,7 +628,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession
628628
// Scan with Some(startKey) covering full range
629629
val startKey = keyAndTimestampToRow("key1", 1, Long.MinValue)
630630
val endKey = keyAndTimestampToRow("key1", 1, Long.MaxValue)
631-
val iter = store.scanWithMultiValues(Some(startKey), endKey)
631+
val iter = store.scanWithMultiValues(Some(startKey), Some(endKey))
632632
val results = iter.map(_.key.getLong(2)).toList
633633
iter.close()
634634

@@ -657,7 +657,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession
657657
// Scan range that contains no entries
658658
val startKey = keyAndTimestampToRow("key1", 1, 1500L)
659659
val endKey = keyAndTimestampToRow("key1", 1, 1600L)
660-
val iter = store.scanWithMultiValues(Some(startKey), endKey)
660+
val iter = store.scanWithMultiValues(Some(startKey), Some(endKey))
661661
assert(!iter.hasNext)
662662
iter.close()
663663
} finally {
@@ -685,7 +685,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession
685685

686686
// Scan from beginning (None) up to 301 (exclusive), covering [..300]
687687
val endKey = keyAndTimestampToRow("key1", 1, 301L)
688-
val iter = store.scanWithMultiValues(None, endKey)
688+
val iter = store.scanWithMultiValues(None, Some(endKey))
689689
val results = iter.map(_.key.getLong(2)).toList
690690
iter.close()
691691

@@ -715,7 +715,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession
715715
// Scan with endKey at timestamp 201 with dummyKey - should include
716716
// everything up to timestamp 200 regardless of join key
717717
val endKey = keyAndTimestampToRow("key1", 1, 201L)
718-
val iter = store.scanWithMultiValues(None, endKey)
718+
val iter = store.scanWithMultiValues(None, Some(endKey))
719719
val results = iter.map { pair =>
720720
(pair.key.getString(0), pair.key.getLong(2))
721721
}.toList
@@ -747,7 +747,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession
747747

748748
val startKey = keyAndTimestampToRow("key1", 1, 200L)
749749
val endKey = keyAndTimestampToRow("key1", 1, 401L)
750-
val iter = store.scan(Some(startKey), endKey)
750+
val iter = store.scan(Some(startKey), Some(endKey))
751751
val results = iter.map { pair =>
752752
(pair.key.getLong(2), pair.value.getInt(0))
753753
}.toList
@@ -778,7 +778,7 @@ class RocksDBTimestampEncoderOperationsSuite extends SharedSparkSession
778778
// Scan negative range only: [-300, 0)
779779
val startKey = keyAndTimestampToRow("key1", 1, -300L)
780780
val endKey = keyAndTimestampToRow("key1", 1, 0L)
781-
val iter = store.scan(Some(startKey), endKey)
781+
val iter = store.scan(Some(startKey), Some(endKey))
782782
val results = iter.map(_.key.getLong(2)).toList
783783
iter.close()
784784

0 commit comments

Comments
 (0)