@@ -27,7 +27,7 @@ import org.apache.spark.TaskContext
2727import org .apache .spark .internal .Logging
2828import org .apache .spark .internal .LogKeys .{END_INDEX , START_INDEX , STATE_STORE_ID }
2929import org .apache .spark .sql .catalyst .InternalRow
30- import org .apache .spark .sql .catalyst .expressions .{Attribute , AttributeReference , Expression , GenericInternalRow , JoinedRow , Literal , SafeProjection , SpecificInternalRow , UnsafeProjection , UnsafeRow }
30+ import org .apache .spark .sql .catalyst .expressions .{Attribute , AttributeReference , Expression , JoinedRow , Literal , SafeProjection , SpecificInternalRow , UnsafeProjection , UnsafeRow }
3131import org .apache .spark .sql .catalyst .types .DataTypeUtils .toAttributes
3232import org .apache .spark .sql .execution .metric .SQLMetric
3333import org .apache .spark .sql .execution .streaming .operators .stateful .{StatefulOperatorStateInfo , StatefulOpStateStoreCheckpointInfo , WatermarkSupport }
@@ -647,25 +647,17 @@ class SymmetricHashJoinStateManagerV4(
647647
648648 /**
649649 * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
650- *
651- * When a bounded range is provided, leverages RocksDB's native seek and upper bound via
652- * [[StateStore.scanWithMultiValues ]] to avoid reading entries outside the range.
653- * Falls back to [[StateStore.prefixScanWithMultiValues ]] when the full range is requested.
650+ * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
654651 */
655652 def getValuesInRange (
656653 key : UnsafeRow , minTs : Long , maxTs : Long ): Iterator [GetValuesResult ] = {
657654 val reusableGetValuesResult = new GetValuesResult ()
658655
659656 new NextIterator [GetValuesResult ] {
660- private val iter = if (minTs == Long .MinValue && maxTs == Long .MaxValue ) {
661- stateStore.prefixScanWithMultiValues(key, colFamilyName)
662- } else {
663- val startKeyRow = createKeyRow(key, minTs).copy()
664- val endKeyRow = createKeyRow(key, maxTs + 1 )
665- stateStore.scanWithMultiValues(Some (startKeyRow), Some (endKeyRow), colFamilyName)
666- }
657+ private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
667658
668659 private var currentTs = - 1L
660+ private var pastUpperBound = false
669661 private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer [ValueAndMatchPair ]()
670662
671663 private def flushAccumulated (): GetValuesResult = {
@@ -683,13 +675,18 @@ class SymmetricHashJoinStateManagerV4(
683675
684676 @ tailrec
685677 override protected def getNext (): GetValuesResult = {
686- if (! iter.hasNext) {
678+ if (pastUpperBound || ! iter.hasNext) {
687679 flushAccumulated()
688680 } else {
689681 val unsafeRowPair = iter.next()
690682 val ts = TimestampKeyStateEncoder .extractTimestamp(unsafeRowPair.key)
691683
692- if (currentTs == - 1L || currentTs == ts) {
684+ if (ts > maxTs) {
685+ pastUpperBound = true
686+ getNext()
687+ } else if (ts < minTs) {
688+ getNext()
689+ } else if (currentTs == - 1L || currentTs == ts) {
693690 currentTs = ts
694691 valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
695692 getNext()
@@ -773,18 +770,11 @@ class SymmetricHashJoinStateManagerV4(
773770 stateStore.remove(createKeyRow(key, timestamp), colFamilyName)
774771 }
775772
776- private lazy val dummyKeyRow : UnsafeRow = {
777- val defaultValues = keySchema.fields.map(f => Literal .default(f.dataType).eval())
778- val projection = UnsafeProjection .create(keySchema)
779- projection(new GenericInternalRow (defaultValues)).copy()
780- }
781-
782773 case class EvictedKeysResult (key : UnsafeRow , timestamp : Long , numValues : Int )
783774
784775 // NOTE: This assumes we consume the whole iterator to trigger completion.
785776 def scanEvictedKeys (endTimestamp : Long ): Iterator [EvictedKeysResult ] = {
786- val endKeyRow = createKeyRow(dummyKeyRow, endTimestamp + 1 )
787- val evictIterator = stateStore.scanWithMultiValues(None , Some (endKeyRow), colFamilyName)
777+ val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
788778 new NextIterator [EvictedKeysResult ]() {
789779 var currentKeyRow : UnsafeRow = null
790780 var currentEventTime : Long = - 1L
@@ -799,19 +789,25 @@ class SymmetricHashJoinStateManagerV4(
799789 val ts = TimestampKeyStateEncoder .extractTimestamp(kv.key)
800790
801791 if (keyRow == currentKeyRow && ts == currentEventTime) {
792+ // new value with same (key, ts)
802793 count += 1
803794 } else if (ts > endTimestamp) {
804- // Safety check for boundary edge case: a small number of entries at exactly
805- // endTimestamp + 1 may leak through the upper bound because the encoded end key
806- // includes a join key suffix.
795+ // we found the timestamp beyond the range - we shouldn't continue further
807796 isBeyondUpperBound = true
797+
798+ // We don't need to construct the last (key, ts) into EvictedKeysResult here - the code
799+ // after the loop will handle that if there is a leftover. For that reason, we do not
800+ // reset the current (key, ts) info here.
808801 } else if (currentKeyRow == null && currentEventTime == - 1L ) {
802+ // first value to process
809803 currentKeyRow = keyRow.copy()
810804 currentEventTime = ts
811805 count = 1
812806 } else {
807+ // emit the previously accumulated (key, ts) as an EvictedKeysResult
813808 ret = EvictedKeysResult (currentKeyRow, currentEventTime, count)
814809
810+ // register the next (key, ts) to process
815811 currentKeyRow = keyRow.copy()
816812 currentEventTime = ts
817813 count = 1
@@ -821,8 +817,10 @@ class SymmetricHashJoinStateManagerV4(
821817 if (ret != null ) {
822818 ret
823819 } else if (count > 0 ) {
820+ // there is a final leftover (key, ts) to return
824821 ret = EvictedKeysResult (currentKeyRow, currentEventTime, count)
825822
823+ // we shouldn't continue further
826824 currentKeyRow = null
827825 currentEventTime = - 1L
828826 count = 0
0 commit comments