diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 8c032d31cff61..f62c22b196483 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -6870,6 +6870,11 @@
         "joinSide=<joinSide>, storeVersion=<storeVersion>, partitionId=<partitionId>.",
         "Enable <configKey> as a workaround to skip null values."
       ]
+    },
+    "RANGE_SCAN_TIMESTAMP_OUT_OF_RANGE" : {
+      "message" : [
+        "Range scan returned a row with timestamp <timestamp> outside the expected range [<minTimestamp>, <maxTimestamp>]."
+      ]
     }
   },
   "sqlState" : "XXKST"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
index 4346f1096a15f..71a7d4cf56e13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
@@ -675,7 +675,7 @@ case class StreamingSymmetricHashJoinExec(
     private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)
 
     private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
-      case Some(JoinStateKeyWatermarkPredicate(expr, _)) =>
+      case Some(JoinStateKeyWatermarkPredicate(expr, _, _)) =>
        // inputSchema can be empty as expr should only have BoundReferences and does not require
        // the schema to generate the predicate. See [[StreamingSymmetricHashJoinHelper]].
        Predicate.create(expr, Seq.empty).eval _
@@ -684,7 +684,7 @@ case class StreamingSymmetricHashJoinExec(
     private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
-      case Some(JoinStateValueWatermarkPredicate(expr, _)) =>
+      case Some(JoinStateValueWatermarkPredicate(expr, _, _)) =>
         Predicate.create(expr, inputAttributes).eval _
       case _ =>
         Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
@@ -905,21 +905,25 @@ case class StreamingSymmetricHashJoinExec(
      */
     def removeOldState(): Long = {
       stateWatermarkPredicate match {
-        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
+        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
           joinStateManager match {
             case s: SupportsEvictByCondition =>
               s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)
             case s: SupportsEvictByTimestamp =>
-              s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
+              s.evictByTimestamp(
+                watermarkMsToStateTimestamp(stateWatermark),
+                prevStateWatermark.map(watermarkMsToStateTimestamp))
           }
-        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
+        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
           joinStateManager match {
             case s: SupportsEvictByCondition =>
               s.evictByValueCondition(stateValueWatermarkPredicateFunc)
             case s: SupportsEvictByTimestamp =>
-              s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
+              s.evictByTimestamp(
+                watermarkMsToStateTimestamp(stateWatermark),
+                prevStateWatermark.map(watermarkMsToStateTimestamp))
           }
         case _ => 0L
       }
@@ -937,21 +941,25 @@ case class StreamingSymmetricHashJoinExec(
      */
     def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
       stateWatermarkPredicate match {
-        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
+        case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
           joinStateManager match {
             case s: SupportsEvictByCondition =>
               s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)
             case s: SupportsEvictByTimestamp =>
-              s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
+              s.evictAndReturnByTimestamp(
+                watermarkMsToStateTimestamp(stateWatermark),
+                prevStateWatermark.map(watermarkMsToStateTimestamp))
           }
-        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
+        case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
           joinStateManager match {
             case s: SupportsEvictByCondition =>
               s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)
             case s: SupportsEvictByTimestamp =>
-              s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
+              s.evictAndReturnByTimestamp(
+                watermarkMsToStateTimestamp(stateWatermark),
+                prevStateWatermark.map(watermarkMsToStateTimestamp))
           }
         case _ => Iterator.empty
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
index cea6398f4e501..e96c9dbfd7e1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinHelper.scala
@@ -45,13 +45,37 @@ object StreamingSymmetricHashJoinHelper extends Logging {
     def desc: String
     override def toString: String = s"$desc: $expr"
   }
 
-  /** Predicate for watermark on state keys */
-  case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long)
+  /**
+   * Predicate for watermark on state keys.
+   *
+   * @param stateWatermark Current batch's eviction watermark. Entries with timestamp
+   *   at or below this value are eligible for eviction in this batch.
+   * @param prevStateWatermark Previous batch's eviction watermark, i.e. the watermark
+   *   used for filtering late events in the current batch. Entries with timestamp at or
+   *   below this value were already evicted in prior batches, so the effective range of
+   *   entries to evict in this batch is `(prevStateWatermark, stateWatermark]`. State
+   *   manager implementations can leverage this lower bound to optimize eviction
+   *   (e.g. narrowing the scan range to skip already-evicted entries). `None` means we
+   *   do not have a known lower bound (e.g. the first batch after restart), in which
+   *   case eviction must consider all entries up to `stateWatermark`.
+   */
+  case class JoinStateKeyWatermarkPredicate(
+      expr: Expression,
+      stateWatermark: Long,
+      prevStateWatermark: Option[Long] = None)
     extends JoinStateWatermarkPredicate {
     def desc: String = "key predicate"
   }
 
-  /** Predicate for watermark on state values */
-  case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long)
+  /**
+   * Predicate for watermark on state values. See [[JoinStateKeyWatermarkPredicate]] for
+   * the semantics of `stateWatermark` and `prevStateWatermark`.
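+   *
+   * For example (illustrative values): with `prevStateWatermark = Some(100)` and
+   * `stateWatermark = 250`, only entries with timestamps in `(100, 250]` need to be
+   * evicted in this batch; anything at or below 100 was already removed by earlier batches.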
+   */
+  case class JoinStateValueWatermarkPredicate(
+      expr: Expression,
+      stateWatermark: Long,
+      prevStateWatermark: Option[Long] = None)
     extends JoinStateWatermarkPredicate {
     def desc: String = "value predicate"
   }
@@ -185,6 +209,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
       rightKeys: Seq[Expression],
       condition: Option[Expression],
       eventTimeWatermarkForEviction: Option[Long],
+      eventTimeWatermarkForLateEvents: Option[Long],
       useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = {
 
     // Perform assertions against multiple event time columns in the same DataFrame. This method
@@ -215,7 +240,10 @@ object StreamingSymmetricHashJoinHelper extends Logging {
         expr.map { e =>
           // watermarkExpression only provides the expression when eventTimeWatermarkForEviction
           // is defined
-          JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get)
+          JoinStateKeyWatermarkPredicate(
+            e,
+            eventTimeWatermarkForEviction.get,
+            eventTimeWatermarkForLateEvents)
         }
       } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
         val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
@@ -223,12 +251,19 @@ object StreamingSymmetricHashJoinHelper extends Logging {
           attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
           condition,
           eventTimeWatermarkForEviction)
+        val prevStateValueWatermark = eventTimeWatermarkForLateEvents.flatMap { _ =>
+          StreamingJoinHelper.getStateValueWatermark(
+            attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
+            attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
+            condition,
+            eventTimeWatermarkForLateEvents)
+        }
         val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey))
         val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark)
         expr.map { e =>
           // watermarkExpression only provides the expression when eventTimeWatermarkForEviction
           // is defined
-          JoinStateValueWatermarkPredicate(e, stateValueWatermark.get)
+          JoinStateValueWatermarkPredicate(e, stateValueWatermark.get, prevStateValueWatermark)
         }
       } else {
         None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
index 1a9f50365e2c7..02c9ef11df89c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/SymmetricHashJoinStateManager.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
-import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
+import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, RangeScanBoundaryUtils, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
 import org.apache.spark.util.NextIterator
@@ -184,15 +184,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
 trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
   import SymmetricHashJoinStateManager._
 
-  /** Evict the state by timestamp. Returns the number of values evicted. */
-  def evictByTimestamp(endTimestamp: Long): Long
+  /**
+   * Evict the state by timestamp. Returns the number of values evicted.
+   *
+   * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
+   * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
+   *   assumed to have been evicted already (e.g. from the previous batch). When provided,
+   *   the scan starts from startTimestamp + 1.
+   */
+  def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long
 
   /**
    * Evict the state by timestamp and return the evicted key-value pairs.
    *
    * It is the caller's responsibility to consume the whole iterator.
+   *
+   * @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
+   * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
+   *   assumed to have been evicted already (e.g. from the previous batch). When provided,
+   *   the scan starts from startTimestamp + 1.
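+   *
+   * For example (illustrative values): `evictAndReturnByTimestamp(200, Some(100))` scans
+   * only timestamps 101 through 200 inclusive, whereas
+   * `evictAndReturnByTimestamp(200, None)` must consider every entry up to 200.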
    */
-  def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
+  def evictAndReturnByTimestamp(
+      endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair]
 }
 
 /**
@@ -519,11 +532,11 @@ class SymmetricHashJoinStateManagerV4(
     }
   }
 
-  override def evictByTimestamp(endTimestamp: Long): Long = {
+  override def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long = {
     require(hasEventTime,
       "evictByTimestamp requires event time; secondary index was not populated")
     var removed = 0L
-    tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
+    tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).foreach { evicted =>
       val key = evicted.key
       val timestamp = evicted.timestamp
       val numValues = evicted.numValues
@@ -537,12 +550,13 @@ class SymmetricHashJoinStateManagerV4(
     removed
   }
 
-  override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = {
+  override def evictAndReturnByTimestamp(
+      endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] = {
     require(hasEventTime,
       "evictAndReturnByTimestamp requires event time; secondary index was not populated")
 
     val reusableKeyToValuePair = KeyToValuePair()
-    tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
+    tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).flatMap { evicted =>
       val key = evicted.key
       val timestamp = evicted.timestamp
       val values = keyWithTsToValues.get(key, timestamp)
@@ -663,14 +677,33 @@ class SymmetricHashJoinStateManagerV4(
   /**
    * Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
-   * Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
+   * When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient
+   * range access; falls back to prefixScan otherwise to stay within the key's scope.
+   *
+   * When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are
+   * filtered out so both code paths produce identical results.
    */
   def getValuesInRange(
       key: UnsafeRow,
       minTs: Long,
       maxTs: Long): Iterator[GetValuesResult] = {
     val reusableGetValuesResult = new GetValuesResult()
+    // Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires
+    // an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue.
+    val useRangeScan = maxTs < Long.MaxValue
     new NextIterator[GetValuesResult] {
-      private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
+      private val iter = if (useRangeScan) {
+        // startKey must be copied because the second createKeyRow call below reuses
+        // the same projection buffer and would otherwise overwrite its contents.
+        // endKey does not need a copy: rangeScanWithMultiValues encodes both bounds
+        // to independent byte arrays eagerly at call time, and the scope of endKey
+        // ends with the call of rangeScanWithMultiValues.
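+        // (Concretely: without the copy(), the createKeyRow(key, maxTs + 1) call below
+        // would overwrite the shared buffer, both bounds would encode maxTs + 1, and the
+        // inclusive-start/exclusive-end range would come back empty.)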
+        val startKey = createKeyRow(key, minTs).copy()
+        // rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1
+        val endKey = Some(createKeyRow(key, maxTs + 1))
+        stateStore.rangeScanWithMultiValues(Some(startKey), endKey, colFamilyName)
+      } else {
+        stateStore.prefixScanWithMultiValues(key, colFamilyName)
+      }
       private var currentTs = -1L
       private var pastUpperBound = false
@@ -697,6 +730,11 @@
         val unsafeRowPair = iter.next()
         val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
 
+        if (useRangeScan && (ts < minTs || ts > maxTs)) {
+          throw StateStoreErrors.streamStreamJoinRangeScanTimestampOutOfRange(
+            ts, minTs, maxTs)
+        }
+
         if (ts > maxTs) {
           pastUpperBound = true
           getNext()
@@ -773,6 +811,8 @@
       isInternal = true
     )
 
+    // Returns an UnsafeRow backed by a reused projection buffer. Callers that need to
+    // hold the row beyond the immediate state store call must invoke copy() on the result.
     private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
       TimestampKeyStateEncoder.attachTimestamp(
         attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
@@ -788,9 +828,60 @@
     case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int)
 
-    // NOTE: This assumes we consume the whole iterator to trigger completion.
-    def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
-      val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
+    // Reusable default key row for scan boundary construction; see
+    // [[RangeScanBoundaryUtils]] for rationale. Safe to reuse because createKeyRow
+    // only reads this row (via BoundReference evaluations) and writes to the
+    // projection's own internal buffer. Correctness relies on real stored entries
+    // never having internally-null key fields, which is preserved by join-key
+    // expressions being evaluated via the user's expression encoder. Preserve this
+    // invariant if you change how entries are written.
+    private lazy val defaultKey: UnsafeRow = RangeScanBoundaryUtils.defaultUnsafeRow(keySchema)
+
+    /**
+     * Build a scan boundary row for rangeScan. The TsWithKeyTypeStore uses
+     * TimestampAsPrefixKeyStateEncoder, which encodes the row as [timestamp][key_fields].
+     * We need a full-schema row (not just the timestamp) because the encoder expects all
+     * key columns to be present. Default values are used for the key fields since only the
+     * timestamp matters for ordering in the prefix encoder.
+     */
+    private def createScanBoundaryRow(timestamp: Long): UnsafeRow = {
+      createKeyRow(defaultKey, timestamp).copy()
+    }
+
+    /**
+     * Scan keys eligible for eviction within the timestamp range.
+     *
+     * This assumes we consume the whole iterator to trigger completion.
+     *
+     * @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are
+     *   eligible for eviction.
+     * @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp
+     *   are assumed to have been evicted already. The scan starts from startTimestamp + 1.
+     */
+    def scanEvictedKeys(
+        endTimestamp: Long,
+        startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = {
+      // If startTimestamp == Long.MaxValue, everything has already been evicted;
+      // nothing can match, so return immediately.
+      if (startTimestamp.contains(Long.MaxValue)) {
+        return Iterator.empty
+      }
+
+      // rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive.
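+      // (The early return above also means st < Long.MaxValue here, so the st + 1
+      // below cannot overflow.)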
+      // startTimestamp is exclusive (already evicted), so we seek from st + 1.
+      val startKeyRow = startTimestamp.map { st =>
+        createScanBoundaryRow(st + 1)
+      }
+      // endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound.
+      // When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is
+      // safe because rangeScanWithMultiValues with no end key uses the column-family prefix
+      // as the upper bound, naturally scoping the scan within this column family.
+      val endKeyRow = if (endTimestamp < Long.MaxValue) {
+        Some(createScanBoundaryRow(endTimestamp + 1))
+      } else {
+        None
+      }
+      val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName)
       new NextIterator[EvictedKeysResult]() {
         var currentKeyRow: UnsafeRow = null
         var currentEventTime: Long = -1L
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
index ac89079e6d515..1587fd4786a35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
@@ -531,13 +531,31 @@ class IncrementalExecution(
         case j: StreamingSymmetricHashJoinExec =>
           val iwLateEvents = inputWatermarkForLateEvents(j.stateInfo.get)
           val iwEviction = inputWatermarkForEviction(j.stateInfo.get)
+          // Use the late-events watermark as the scan lower bound only when we can prove that
+          // it equals the previous batch's eviction watermark for this operator. That holds
+          // when STATEFUL_OPERATOR_ALLOW_MULTIPLE is true.
+          //
+          // When STATEFUL_OPERATOR_ALLOW_MULTIPLE is false (legacy mode), lateEvents == eviction
+          // for the same batch, so using it would collapse the eviction range to (wm, wm] = empty
+          // and silently stop state eviction. Fall back to None (full scan) in that mode, as we
+          // do not expect many users to rely on the legacy mode.
+          //
+          // The prevOffsetSeqMetadata.isDefined check guards against the first batch, where
+          // watermark propagation yields Some(0) for late events even though no state has been
+          // evicted yet, which would incorrectly skip entries at timestamp 0.
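+          // Example (illustrative watermark values): in legacy mode with
+          // lateEvents == eviction == 5000, the candidate range (5000, 5000] contains
+          // no timestamps, so threading the late-events watermark through would evict
+          // nothing at all.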
+          val prevBatchLateEventsWm =
+            if (prevOffsetSeqMetadata.isDefined && allowMultipleStatefulOperators) {
+              iwLateEvents
+            } else {
+              None
+            }
           j.copy(
             eventTimeWatermarkForLateEvents = iwLateEvents,
             eventTimeWatermarkForEviction = iwEviction,
             stateWatermarkPredicates =
               StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
                 j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
-                iwEviction, !allowMultipleStatefulOperators)
+                iwEviction, prevBatchLateEventsWm, !allowMultipleStatefulOperators)
           )
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index a00c6a8bc73c3..a301ecbf22dbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -317,6 +317,14 @@ object StateStoreErrors {
     new StreamStreamJoinInconsistentStateNullValue(
       valueIndex, numValues, joinSide, storeVersion, partitionId, configKey)
   }
+
+  def streamStreamJoinRangeScanTimestampOutOfRange(
+      timestamp: Long,
+      minTimestamp: Long,
+      maxTimestamp: Long): StreamStreamJoinInconsistentStateRangeScanTimestampOutOfRange = {
+    new StreamStreamJoinInconsistentStateRangeScanTimestampOutOfRange(
+      timestamp, minTimestamp, maxTimestamp)
+  }
 }
 
 trait ConvertableToCannotLoadStoreError {
@@ -702,3 +710,14 @@ class StreamStreamJoinInconsistentStateNullValue(
     "storeVersion" -> storeVersion.toString,
     "partitionId" -> partitionId.toString,
     "configKey" -> configKey))
+
+class StreamStreamJoinInconsistentStateRangeScanTimestampOutOfRange(
+    timestamp: Long,
+    minTimestamp: Long,
+    maxTimestamp: Long)
+  extends SparkRuntimeException(
+    errorClass = "STREAM_STREAM_JOIN_INCONSISTENT_STATE.RANGE_SCAN_TIMESTAMP_OUT_OF_RANGE",
+    messageParameters = Map(
+      "timestamp" -> timestamp.toString,
+      "minTimestamp" -> minTimestamp.toString,
+      "maxTimestamp" -> maxTimestamp.toString))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index 4717d2fb40544..0460a41f4cc5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -1100,6 +1100,90 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite
     }
   }
 
+  test("StreamingJoinStateManager V4 - getValuesInRange boundary edge cases") {
+    withJoinStateManager(
+      inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager =>
+      implicit val mgr = manager
+
+      Seq(10, 20, 30, 40, 50).foreach(append(40, _))
+
+      // Exact boundary matches (both inclusive)
+      assert(getJoinedRowTimestamps(40, Some((10L, 10L))) === Seq(10))
+      assert(getJoinedRowTimestamps(40, Some((50L, 50L))) === Seq(50))
+
+      // Range with Long.MinValue / Long.MaxValue
+      assert(getJoinedRowTimestamps(40, Some((Long.MinValue, 30L))) === Seq(10, 20, 30))
+      assert(getJoinedRowTimestamps(40, Some((30L, Long.MaxValue))) === Seq(30, 40, 50))
+      assert(getJoinedRowTimestamps(40, Some((Long.MinValue, Long.MaxValue))) ===
+        Seq(10, 20, 30, 40, 50))
+
+      // Empty range (minTs > maxTs)
+      assert(getJoinedRowTimestamps(40, Some((50L, 10L))) === Seq.empty)
+
+      // Range entirely outside stored timestamps
+      assert(getJoinedRowTimestamps(40, Some((100L, 200L))) === Seq.empty)
+      assert(getJoinedRowTimestamps(40, Some((1L, 5L))) === Seq.empty)
+
+      // Full range via None (all entries)
+      assert(getJoinedRowTimestamps(40, None) === Seq(10, 20, 30, 40, 50))
+    }
+  }
+
+  test("StreamingJoinStateManager V4 - evictByTimestamp boundary edge cases") {
+    withJoinStateManager(
+      inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager =>
+      implicit val mgr = manager
+      val evictByTs = manager.asInstanceOf[SupportsEvictByTimestamp]
+
+      // --- Range eviction with startTimestamp (exclusive) and endTimestamp (inclusive) ---
+      Seq(10, 20, 30, 40, 50).foreach(append(40, _))
+      // startTimestamp=20 is exclusive, endTimestamp=40 is inclusive: evicts timestamps 30, 40
+      assert(evictByTs.evictByTimestamp(40, Some(20)) === 2)
+      assert(get(40) === Seq(10, 20, 50))
+
+      // --- evictAndReturnByTimestamp returns evicted values ---
+      Seq(30, 40).foreach(append(40, _)) // restore evicted entries
+      val evictedValues = evictByTs.evictAndReturnByTimestamp(30, Some(10))
+        .map(p => toValueInt(p.value)).toSeq.sorted
+      // startTimestamp=10 is exclusive, endTimestamp=30 is inclusive: timestamps 20 and 30
+      assert(evictedValues === Seq(20, 30))
+      assert(get(40) === Seq(10, 40, 50))
+
+      // --- start equals end: empty range (exclusive start = inclusive end) ---
+      // startTimestamp=40 (exclusive) and endTimestamp=40 (inclusive): range is empty
+      assert(evictByTs.evictByTimestamp(40, Some(40)) === 0)
+      assert(get(40) === Seq(10, 40, 50))
+
+      // --- start just below entry: evicts exactly that entry ---
+      // startTimestamp=39 (exclusive) means entries >= 40 are scanned; endTimestamp=40 inclusive
+      assert(evictByTs.evictByTimestamp(40, Some(39)) === 1)
+      assert(get(40) === Seq(10, 50))
+
+      // --- overflow boundary: endTimestamp = Long.MaxValue ---
+      // Rebuild a five-entry state: {10, 50} remain, so add back 20, 30, 40.
+      Seq(20, 30, 40).foreach(append(40, _))
+      // endTimestamp=Long.MaxValue with no startTimestamp: evicts all entries
+      assert(evictByTs.evictByTimestamp(Long.MaxValue) === 5)
+      assert(get(40) === Seq.empty)
+
+      // --- overflow boundary: startTimestamp = Some(Long.MinValue) ---
+      Seq(10, 20, 30).foreach(append(40, _))
+      // startTimestamp=Long.MinValue (exclusive), endTimestamp=20 (inclusive):
+      // Long.MinValue is excluded per the contract (already evicted), so the scan
+      // starts from Long.MinValue + 1. Since no real entry has timestamp Long.MinValue,
+      // this effectively scans all entries up to endTimestamp.
+      assert(evictByTs.evictByTimestamp(20, Some(Long.MinValue)) === 2)
+      assert(get(40) === Seq(30))
+
+      // --- overflow boundary: startTimestamp = Some(Long.MaxValue) ---
+      Seq(10, 20).foreach(append(40, _))
+      // startTimestamp=Long.MaxValue (exclusive) means everything <= Long.MaxValue was already
+      // evicted. Nothing can remain, so the scan returns an empty iterator immediately.
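+      // (This exercises the scanEvictedKeys early-return path: with an exclusive lower
+      // bound of Long.MaxValue, no store scan is performed at all.)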
+      assert(evictByTs.evictByTimestamp(50, Some(Long.MaxValue)) === 0)
+      assert(get(40) === Seq(10, 20, 30))
+    }
+  }
+
   // V1 excluded: V1 converter does not persist matched flags (SPARK-26154)
   versionsInTest.filter(_ >= 2).foreach { ver =>
     test(s"StreamingJoinStateManager V$ver - skipUpdatingMatchedFlag skips matched flag update") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
index 06f25039147ad..e58af3b2bf651 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
@@ -23,7 +23,8 @@ import org.scalatest.Tag
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
 import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
 import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinExec
-import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.{JoinStateKeyWatermarkPredicate, JoinStateValueWatermarkPredicate}
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExecution}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -329,6 +330,136 @@ class StreamingInnerJoinV4Suite
       )
     }
   }
+
+  test("prevStateWatermark must be None in the first batch") {
+    // Regression test for the IncrementalExecution guard: in the first batch,
+    // prevOffsetSeqMetadata is None, so eventTimeWatermarkForLateEvents must NOT
+    // be passed to getStateWatermarkPredicates. Without the guard, the watermark
+    // propagation framework yields Some(0) even in batch 0, which would cause
+    // scanEvictedKeys to skip state entries at timestamp 0.
+    val input1 = MemoryStream[(Int, Int)]
+    val input2 = MemoryStream[(Int, Int)]
+
+    val df1 = input1.toDF().toDF("key", "time")
+      .select($"key", timestamp_seconds($"time") as "leftTime",
+        ($"key" * 2) as "leftValue")
+      .withWatermark("leftTime", "10 seconds")
+    val df2 = input2.toDF().toDF("key", "time")
+      .select($"key", timestamp_seconds($"time") as "rightTime",
+        ($"key" * 3) as "rightValue")
+      .withWatermark("rightTime", "10 seconds")
+
+    val joined = df1.join(df2,
+      df1("key") === df2("key") &&
+        expr("leftTime >= rightTime - interval 5 seconds " +
+          "AND leftTime <= rightTime + interval 5 seconds"),
+      "inner")
+      .select(df1("key"), $"leftTime".cast("long"), $"leftValue", $"rightValue")
+
+    def extractPrevWatermarks(q: StreamExecution): (Option[Long], Option[Long]) = {
+      val joinExec = q.lastExecution.executedPlan.collect {
+        case j: StreamingSymmetricHashJoinExec => j
+      }.head
+      val leftPrev = joinExec.stateWatermarkPredicates.left.flatMap {
+        case p: JoinStateKeyWatermarkPredicate => p.prevStateWatermark
+        case p: JoinStateValueWatermarkPredicate => p.prevStateWatermark
+      }
+      val rightPrev = joinExec.stateWatermarkPredicates.right.flatMap {
+        case p: JoinStateKeyWatermarkPredicate => p.prevStateWatermark
+        case p: JoinStateValueWatermarkPredicate => p.prevStateWatermark
+      }
+      (leftPrev, rightPrev)
+    }
+
+    testStream(joined)(
+      // First batch: prevStateWatermark must be None on both sides.
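+      // (Illustrative timing: max event time 5s minus the 10s watermark delay keeps
+      // the batch-0 watermark at 0, consistent with "no state evicted yet".)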
+      MultiAddData(input1, (1, 5))(input2, (1, 5)),
+      CheckNewAnswer((1, 5, 2, 3)),
+      Execute { q =>
+        val (leftPrev, rightPrev) = extractPrevWatermarks(q)
+        assert(leftPrev.isEmpty,
+          s"Left prevStateWatermark should be None in the first batch, got $leftPrev")
+        assert(rightPrev.isEmpty,
+          s"Right prevStateWatermark should be None in the first batch, got $rightPrev")
+      },
+
+      // Second batch: after watermark advances, prevStateWatermark should be set.
+      MultiAddData(input1, (2, 30))(input2, (2, 30)),
+      CheckNewAnswer((2, 30, 4, 6)),
+      Execute { q =>
+        val (leftPrev, rightPrev) = extractPrevWatermarks(q)
+        assert(leftPrev.isDefined,
+          "Left prevStateWatermark should be defined after the first batch")
+        assert(rightPrev.isDefined,
+          "Right prevStateWatermark should be defined after the first batch")
+      },
+      StopStream
+    )
+  }
+
+  test("SPARK-56402: prevStateWatermark must be None under legacy single-watermark propagator " +
+    "(STATEFUL_OPERATOR_ALLOW_MULTIPLE = false)") {
+    // Guards against the propagator-type bug: in legacy single-watermark mode
+    // (STATEFUL_OPERATOR_ALLOW_MULTIPLE = false), lateEvents == eviction for the
+    // same batch. If we naively thread `eventTimeWatermarkForLateEvents` as
+    // `prevStateWatermark`, the eviction scan range collapses to (wm, wm] = empty
+    // from batch 2 onward, silently skipping every eligible eviction.
+    // IncrementalExecution must fall back to None in legacy mode.
+    withSQLConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE.key -> "false") {
+      val input1 = MemoryStream[(Int, Int)]
+      val input2 = MemoryStream[(Int, Int)]
+
+      val df1 = input1.toDF().toDF("key", "time")
+        .select($"key", timestamp_seconds($"time") as "leftTime",
+          ($"key" * 2) as "leftValue")
+        .withWatermark("leftTime", "10 seconds")
+      val df2 = input2.toDF().toDF("key", "time")
+        .select($"key", timestamp_seconds($"time") as "rightTime",
+          ($"key" * 3) as "rightValue")
+        .withWatermark("rightTime", "10 seconds")
+
+      val joined = df1.join(df2,
+        df1("key") === df2("key") &&
+          expr("leftTime >= rightTime - interval 5 seconds " +
+            "AND leftTime <= rightTime + interval 5 seconds"),
+        "inner")
+        .select(df1("key"), $"leftTime".cast("long"), $"leftValue", $"rightValue")
+
+      def extractPrevWatermarks(q: StreamExecution): (Option[Long], Option[Long]) = {
+        val joinExec = q.lastExecution.executedPlan.collect {
+          case j: StreamingSymmetricHashJoinExec => j
+        }.head
+        val leftPrev = joinExec.stateWatermarkPredicates.left.flatMap {
+          case p: JoinStateKeyWatermarkPredicate => p.prevStateWatermark
+          case p: JoinStateValueWatermarkPredicate => p.prevStateWatermark
+        }
+        val rightPrev = joinExec.stateWatermarkPredicates.right.flatMap {
+          case p: JoinStateKeyWatermarkPredicate => p.prevStateWatermark
+          case p: JoinStateValueWatermarkPredicate => p.prevStateWatermark
+        }
+        (leftPrev, rightPrev)
+      }
+
+      testStream(joined)(
+        MultiAddData(input1, (1, 5))(input2, (1, 5)),
+        CheckNewAnswer((1, 5, 2, 3)),
+
+        // Batch 2+: even though prevOffsetSeqMetadata is now defined, legacy mode
+        // must keep prevStateWatermark = None to avoid collapsing the eviction scan
+        // range.
+        MultiAddData(input1, (2, 30))(input2, (2, 30)),
+        CheckNewAnswer((2, 30, 4, 6)),
+        Execute { q =>
+          val (leftPrev, rightPrev) = extractPrevWatermarks(q)
+          assert(leftPrev.isEmpty,
+            s"Left prevStateWatermark must be None under legacy propagator, got $leftPrev")
+          assert(rightPrev.isEmpty,
+            s"Right prevStateWatermark must be None under legacy propagator, got $rightPrev")
+        },
+        StopStream
+      )
+    }
+  }
 }
 
 @SlowSQLTest