@@ -34,7 +34,8 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper
3434import org .apache .spark .sql .execution .streaming .operators .stateful .join .StreamingSymmetricHashJoinHelper ._
3535import org .apache .spark .sql .execution .streaming .state .{DropLastNFieldsStatePartitionKeyExtractor , KeyStateEncoderSpec , NoopStatePartitionKeyExtractor , NoPrefixKeyStateEncoderSpec , StatePartitionKeyExtractor , StateSchemaBroadcast , StateStore , StateStoreCheckpointInfo , StateStoreColFamilySchema , StateStoreConf , StateStoreErrors , StateStoreId , StateStoreMetrics , StateStoreProvider , StateStoreProviderId , SupportsFineGrainedReplay , TimestampAsPostfixKeyStateEncoderSpec , TimestampAsPrefixKeyStateEncoderSpec , TimestampKeyStateEncoder }
3636import org .apache .spark .sql .internal .SQLConf
37- import org .apache .spark .sql .types .{BooleanType , DataType , LongType , NullType , StructField , StructType }
37+ import org .apache .spark .sql .types .{BinaryType , BooleanType , ByteType , DataType , DateType , DoubleType , FloatType , IntegerType , LongType , NullType , ShortType , StringType , StructField , StructType , TimestampNTZType , TimestampType }
38+ import org .apache .spark .unsafe .types .UTF8String
3839import org .apache .spark .util .NextIterator
3940
4041/**
@@ -184,15 +185,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
184185trait SupportsEvictByTimestamp { self : SymmetricHashJoinStateManager =>
185186 import SymmetricHashJoinStateManager ._
186187
187- /** Evict the state by timestamp. Returns the number of values evicted. */
188- def evictByTimestamp (endTimestamp : Long ): Long
188+ /**
189+ * Evict the state by timestamp. Returns the number of values evicted.
190+ *
191+ * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
192+ * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
193+ * assumed to have been evicted already (e.g. from the previous batch). When provided,
194+ * the scan starts from startTimestamp + 1.
195+ */
196+ def evictByTimestamp (endTimestamp : Long , startTimestamp : Option [Long ] = None ): Long
189197
190198 /**
191199 * Evict the state by timestamp and return the evicted key-value pairs.
192200 *
193201 * It is caller's responsibility to consume the whole iterator.
202+ *
203+ * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
204+ * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
205+ * assumed to have been evicted already (e.g. from the previous batch). When provided,
206+ * the scan starts from startTimestamp + 1.
194207 */
195- def evictAndReturnByTimestamp (endTimestamp : Long ): Iterator [KeyToValuePair ]
208+ def evictAndReturnByTimestamp (
209+ endTimestamp : Long , startTimestamp : Option [Long ] = None ): Iterator [KeyToValuePair ]
196210}
197211
198212/**
@@ -519,11 +533,11 @@ class SymmetricHashJoinStateManagerV4(
519533 }
520534 }
521535
522- override def evictByTimestamp (endTimestamp : Long ): Long = {
536+ override def evictByTimestamp (endTimestamp : Long , startTimestamp : Option [ Long ] = None ): Long = {
523537 require(hasEventTime,
524538 " evictByTimestamp requires event time; secondary index was not populated" )
525539 var removed = 0L
526- tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
540+ tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp ).foreach { evicted =>
527541 val key = evicted.key
528542 val timestamp = evicted.timestamp
529543 val numValues = evicted.numValues
@@ -537,12 +551,13 @@ class SymmetricHashJoinStateManagerV4(
537551 removed
538552 }
539553
540- override def evictAndReturnByTimestamp (endTimestamp : Long ): Iterator [KeyToValuePair ] = {
554+ override def evictAndReturnByTimestamp (
555+ endTimestamp : Long , startTimestamp : Option [Long ] = None ): Iterator [KeyToValuePair ] = {
541556 require(hasEventTime,
542557 " evictAndReturnByTimestamp requires event time; secondary index was not populated" )
543558 val reusableKeyToValuePair = KeyToValuePair ()
544559
545- tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
560+ tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp ).flatMap { evicted =>
546561 val key = evicted.key
547562 val timestamp = evicted.timestamp
548563 val values = keyWithTsToValues.get(key, timestamp)
@@ -663,17 +678,30 @@ class SymmetricHashJoinStateManagerV4(
663678
664679 /**
665680 * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
666- * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
681+ * When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient
682+ * range access; falls back to prefixScan otherwise to stay within the key's scope.
683+ *
684+ * When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are
685+ * filtered out so both code paths produce identical results.
667686 */
668687 def getValuesInRange (
669688 key : UnsafeRow , minTs : Long , maxTs : Long ): Iterator [GetValuesResult ] = {
670689 val reusableGetValuesResult = new GetValuesResult ()
690+ // Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires
691+ // an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue.
692+ val useRangeScan = maxTs < Long .MaxValue
671693
672694 new NextIterator [GetValuesResult ] {
673- private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
695+ private val iter = if (useRangeScan) {
696+ val startKey = createKeyRow(key, minTs).copy()
697+ // rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1
698+ val endKey = Some (createKeyRow(key, maxTs + 1 ))
699+ stateStore.rangeScanWithMultiValues(Some (startKey), endKey, colFamilyName)
700+ } else {
701+ stateStore.prefixScanWithMultiValues(key, colFamilyName)
702+ }
674703
675704 private var currentTs = - 1L
676- private var pastUpperBound = false
677705 private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer [ValueAndMatchPair ]()
678706
679707 private def flushAccumulated (): GetValuesResult = {
@@ -691,16 +719,16 @@ class SymmetricHashJoinStateManagerV4(
691719
692720 @ tailrec
693721 override protected def getNext (): GetValuesResult = {
694- if (pastUpperBound || ! iter.hasNext) {
722+ if (! iter.hasNext) {
695723 flushAccumulated()
696724 } else {
697725 val unsafeRowPair = iter.next()
698726 val ts = TimestampKeyStateEncoder .extractTimestamp(unsafeRowPair.key)
699727
700- if (ts > maxTs) {
701- pastUpperBound = true
702- getNext()
703- } else if (ts < minTs) {
728+ // Filter out entries outside [minTs, maxTs]. This is essential when using
729+ // prefixScan (which returns all timestamps for the key) and serves as a
730+ // safety guard for rangeScan as well.
731+ if (ts < minTs || ts > maxTs ) {
704732 getNext()
705733 } else if (currentTs == - 1L || currentTs == ts) {
706734 currentTs = ts
@@ -773,6 +801,8 @@ class SymmetricHashJoinStateManagerV4(
773801 isInternal = true
774802 )
775803
804+ // Returns an UnsafeRow backed by a reused projection buffer. Callers that need to
805+ // hold the row beyond the immediate state store call must invoke copy() on the result.
776806 private def createKeyRow (key : UnsafeRow , timestamp : Long ): UnsafeRow = {
777807 TimestampKeyStateEncoder .attachTimestamp(
778808 attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
@@ -788,9 +818,66 @@ class SymmetricHashJoinStateManagerV4(
788818
789819 case class EvictedKeysResult (key : UnsafeRow , timestamp : Long , numValues : Int )
790820
791- // NOTE: This assumes we consume the whole iterator to trigger completion.
792- def scanEvictedKeys (endTimestamp : Long ): Iterator [EvictedKeysResult ] = {
793- val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
821+ private def defaultInternalRow (schema : StructType ): InternalRow = {
822+ InternalRow .fromSeq(schema.map(f => defaultValueForType(f.dataType)))
823+ }
824+
825+ private def defaultValueForType (dt : DataType ): Any = dt match {
826+ case BooleanType => false
827+ case ByteType => 0 .toByte
828+ case ShortType => 0 .toShort
829+ case IntegerType | DateType => 0
830+ case LongType | TimestampType | TimestampNTZType => 0L
831+ case FloatType => 0.0f
832+ case DoubleType => 0.0
833+ case StringType => UTF8String .EMPTY_UTF8
834+ case BinaryType => Array .emptyByteArray
835+ case st : StructType => defaultInternalRow(st)
836+ case _ => null
837+ }
838+
839+ /**
840+ * Build a scan boundary row for rangeScan. The TsWithKeyTypeStore uses
841+ * TimestampAsPrefixKeyStateEncoder, which encodes the row as [timestamp][key_fields].
842+ * We need a full-schema row (not just the timestamp) because the encoder expects all
843+ * key columns to be present. Default values are used for the key fields since only the
844+ * timestamp matters for ordering in the prefix encoder.
845+ */
846+ private def createScanBoundaryRow (timestamp : Long ): UnsafeRow = {
847+ val defaultKey = UnsafeProjection .create(keySchema)
848+ .apply(defaultInternalRow(keySchema))
849+ createKeyRow(defaultKey, timestamp).copy()
850+ }
851+
852+ /**
853+ * Scan keys eligible for eviction within the timestamp range.
854+ *
855+ * This assumes we consume the whole iterator to trigger completion.
856+ *
857+ * @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are
858+ * eligible for eviction.
859+ * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp
860+ * are assumed to have been evicted already. The scan starts from startTimestamp + 1.
861+ */
862+ def scanEvictedKeys (
863+ endTimestamp : Long ,
864+ startTimestamp : Option [Long ] = None ): Iterator [EvictedKeysResult ] = {
865+ // rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive.
866+ // startTimestamp is exclusive (already evicted), so we seek from st + 1.
867+ val startKeyRow = startTimestamp.flatMap { st =>
868+ if (st < Long .MaxValue ) Some (createScanBoundaryRow(st + 1 ))
869+ else None
870+ }
871+ // endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound.
872+ // When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is
873+ // safe because rangeScanWithMultiValues with no end key uses the column-family prefix
874+ // as the upper bound, naturally scoping the scan within this column family.
875+ val endKeyRow = if (endTimestamp < Long .MaxValue ) {
876+ Some (createScanBoundaryRow(endTimestamp + 1 ))
877+ } else {
878+ None
879+ }
880+ val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName)
794881 new NextIterator [EvictedKeysResult ]() {
795882 var currentKeyRow : UnsafeRow = null
796883 var currentEventTime : Long = - 1L
0 commit comments