Skip to content

Commit f151e03

Browse files
committed
[SPARK-56400][SS] Apply rangeScan API in transformWithState Timer/TTL
### What changes were proposed in this pull request? This PR proposes to apply rangeScan API in transformWithState Timer/TTL, which will give an improvement of scanning on expired timers and entries with configured TTL. The main idea is to perform scanning of expired timers and TTL entries from [the end timestamp of previous scan + 1, new end timestamp], which was [None, new end timestamp]. Previously it had to go through tombstones that prior batches made in prior evictions (until compaction happens), and with this change we will be able to skip those tombstones. ### Why are the changes needed? This change will give a hint to RocksDB about the exact range to scan, greatly reducing the chance of reading tombstones. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New UTs, and existing UTs. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Claude 4.6 Opus Closes #55265 from HeartSaVioR/SPARK-56400-on-top-of-SPARK-56369. Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com> Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
1 parent acfae33 commit f151e03

File tree

15 files changed

+562
-27
lines changed

15 files changed

+562
-27
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
787787
outputAttr,
788788
stateInfo = None,
789789
batchTimestampMs = None,
790+
prevBatchTimestampMs = None,
790791
eventTimeWatermarkForLateEvents = None,
791792
eventTimeWatermarkForEviction = None,
792793
planLater(child),
@@ -815,6 +816,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
815816
func, t.leftAttributes, outputAttrs, outputMode, timeMode,
816817
stateInfo = None,
817818
batchTimestampMs = None,
819+
prevBatchTimestampMs = None,
818820
eventTimeWatermarkForLateEvents = None,
819821
eventTimeWatermarkForEviction = None,
820822
userFacingDataType,

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ case class TransformWithStateInPySparkExec(
7474
timeMode: TimeMode,
7575
stateInfo: Option[StatefulOperatorStateInfo],
7676
batchTimestampMs: Option[Long],
77+
prevBatchTimestampMs: Option[Long] = None,
7778
eventTimeWatermarkForLateEvents: Option[Long],
7879
eventTimeWatermarkForEviction: Option[Long],
7980
userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value,
@@ -314,7 +315,8 @@ case class TransformWithStateInPySparkExec(
314315
val data = groupAndProject(filteredIter, groupingAttributes, child.output, dedupAttributes)
315316

316317
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
317-
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
318+
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs,
319+
prevBatchTimestampMs, metrics)
318320

319321
val evalType = {
320322
if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
@@ -442,6 +444,7 @@ object TransformWithStateInPySparkExec {
442444
Some(System.currentTimeMillis),
443445
None,
444446
None,
447+
None,
445448
userFacingDataType,
446449
child,
447450
isStreaming = false,

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ case class TransformWithStateExec(
6767
outputObjAttr: Attribute,
6868
stateInfo: Option[StatefulOperatorStateInfo],
6969
batchTimestampMs: Option[Long],
70+
prevBatchTimestampMs: Option[Long] = None,
7071
eventTimeWatermarkForLateEvents: Option[Long],
7172
eventTimeWatermarkForEviction: Option[Long],
7273
child: SparkPlan,
@@ -251,7 +252,7 @@ case class TransformWithStateExec(
251252
case ProcessingTime =>
252253
assert(batchTimestampMs.isDefined)
253254
val batchTimestamp = batchTimestampMs.get
254-
processorHandle.getExpiredTimers(batchTimestamp)
255+
processorHandle.getExpiredTimers(batchTimestamp, prevBatchTimestampMs)
255256
.flatMap { case (keyObj, expiryTimestampMs) =>
256257
numExpiredTimers += 1
257258
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
@@ -260,7 +261,26 @@ case class TransformWithStateExec(
260261
case EventTime =>
261262
assert(eventTimeWatermarkForEviction.isDefined)
262263
val watermark = eventTimeWatermarkForEviction.get
263-
processorHandle.getExpiredTimers(watermark)
264+
// Use the late-events watermark as the scan lower bound only when we can prove that
265+
// it equals the previous batch's eviction watermark for this operator. That holds
266+
// when STATEFUL_OPERATOR_ALLOW_MULTIPLE is true.
267+
//
268+
// When STATEFUL_OPERATOR_ALLOW_MULTIPLE is false (legacy mode), lateEvents == eviction
269+
// for the same batch, so using it would collapse the eviction range to (wm, wm] = empty
270+
// and silently stop firing timers. Fall back to None (full scan) in that mode, as we
271+
// don't expect the legacy mode to be used by many users.
272+
//
273+
// The prevBatchTimestampMs.isDefined check guards against the first batch, where
274+
// watermark propagation yields Some(0) for late events even though no timers have been
275+
// processed yet, which would incorrectly skip timers registered at timestamp 0.
276+
val prevWatermark =
277+
if (prevBatchTimestampMs.isDefined &&
278+
conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) {
279+
eventTimeWatermarkForLateEvents
280+
} else {
281+
None
282+
}
283+
processorHandle.getExpiredTimers(watermark, prevWatermark)
264284
.flatMap { case (keyObj, expiryTimestampMs) =>
265285
numExpiredTimers += 1
266286
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
@@ -493,7 +513,7 @@ case class TransformWithStateExec(
493513
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
494514
val processorHandle = new StatefulProcessorHandleImpl(
495515
store, getStateInfo.queryRunId, keyEncoder, timeMode,
496-
isStreaming, batchTimestampMs, metrics)
516+
isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics)
497517
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
498518
statefulProcessor.setHandle(processorHandle)
499519
withStatefulProcessorErrorHandling("init") {
@@ -509,7 +529,7 @@ case class TransformWithStateExec(
509529
initStateIterator: Iterator[InternalRow]):
510530
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
511531
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
512-
keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
532+
keyEncoder, timeMode, isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics)
513533
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
514534
statefulProcessor.setHandle(processorHandle)
515535
withStatefulProcessorErrorHandling("init") {
@@ -581,6 +601,7 @@ object TransformWithStateExec {
581601
Some(System.currentTimeMillis),
582602
None,
583603
None,
604+
None,
584605
child,
585606
isStreaming = false,
586607
hasInitialState,

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statefulprocessor/StatefulProcessorHandleImpl.scala

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class StatefulProcessorHandleImpl(
114114
timeMode: TimeMode,
115115
isStreaming: Boolean = true,
116116
batchTimestampMs: Option[Long] = None,
117+
prevBatchTimestampMs: Option[Long] = None,
117118
metrics: Map[String, SQLMetric] = Map.empty)
118119
extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging {
119120
import StatefulProcessorHandleState._
@@ -171,13 +172,19 @@ class StatefulProcessorHandleImpl(
171172

172173
/**
173174
* Function to retrieve all expired registered timers for all grouping keys
174-
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function
175-
* will return all timers that have timestamp less than passed threshold
175+
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive),
176+
* this function will return all timers that have timestamp
177+
* less than or equal to the passed threshold.
178+
* @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range.
179+
* Timers at or below this timestamp are assumed to have been
180+
* already processed in the previous batch and will be skipped.
176181
* @return - iterator of registered timers for all grouping keys
177182
*/
178-
def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
183+
def getExpiredTimers(
184+
expiryTimestampMs: Long,
185+
prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = {
179186
verifyTimerOperations("get_expired_timers")
180-
timerState.getExpiredTimers(expiryTimestampMs)
187+
timerState.getExpiredTimers(expiryTimestampMs, prevExpiryTimestampMs)
181188
}
182189

183190
/**
@@ -237,7 +244,8 @@ class StatefulProcessorHandleImpl(
237244
validateTTLConfig(ttlConfig, stateName)
238245
assert(batchTimestampMs.isDefined)
239246
val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
240-
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
247+
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get,
248+
prevBatchTimestampMs, metrics)
241249
ttlStates.add(valueStateWithTTL)
242250
TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars")
243251
valueStateWithTTL
@@ -286,7 +294,8 @@ class StatefulProcessorHandleImpl(
286294
validateTTLConfig(ttlConfig, stateName)
287295
assert(batchTimestampMs.isDefined)
288296
val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
289-
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
297+
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get,
298+
prevBatchTimestampMs, metrics)
290299
TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars")
291300
ttlStates.add(listStateWithTTL)
292301
listStateWithTTL
@@ -324,7 +333,8 @@ class StatefulProcessorHandleImpl(
324333
validateTTLConfig(ttlConfig, stateName)
325334
assert(batchTimestampMs.isDefined)
326335
val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc,
327-
valEncoder, ttlConfig, batchTimestampMs.get, metrics)
336+
valEncoder, ttlConfig, batchTimestampMs.get,
337+
prevBatchTimestampMs, metrics)
328338
TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars")
329339
ttlStates.add(mapStateWithTTL)
330340
mapStateWithTTL

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,36 @@ class TimerStateImpl(
112112
schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)),
113113
useMultipleValuesPerKey = false, isInternal = true)
114114

115+
private val secIndexProjection = UnsafeProjection.create(keySchemaForSecIndex)
116+
117+
// Placeholder grouping-key struct used in range-scan boundary rows; see
118+
// [[RangeScanBoundaryUtils]] for rationale. Correctness relies on real stored
119+
// entries never having a null grouping-key struct, which is preserved by
120+
// registerTimer going through the user's expression encoder. Preserve this
121+
// invariant if you change how entries are written.
122+
private val defaultGroupingKey: InternalRow = RangeScanBoundaryUtils.defaultInternalRow(
123+
keySchemaForSecIndex(1).dataType.asInstanceOf[StructType])
124+
125+
/**
126+
* Encodes a timestamp into an UnsafeRow key for the secondary index.
127+
* The timestamp is incremented by 1 so that the encoded key serves as an exclusive
128+
* lower / upper bound in range scans. Returns None if tsMs is Long.MaxValue
129+
* (overflow guard).
130+
*
131+
* The returned UnsafeRow is always a fresh copy, safe to hold alongside other
132+
* rows produced by the same projection.
133+
*/
134+
private def encodeTimestampAsKey(tsMs: Long): Option[UnsafeRow] = {
135+
if (tsMs < Long.MaxValue) {
136+
val row = new GenericInternalRow(keySchemaForSecIndex.length)
137+
row.setLong(0, tsMs + 1)
138+
row.update(1, defaultGroupingKey)
139+
Some(secIndexProjection.apply(row).copy())
140+
} else {
141+
None
142+
}
143+
}
144+
115145
private def getGroupingKey(cfName: String): Any = {
116146
val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
117147
if (keyOption.isEmpty) {
@@ -189,15 +219,22 @@ class TimerStateImpl(
189219

190220
/**
191221
* Function to get all the expired registered timers for all grouping keys.
192-
* Perform a range scan on timestamp and will stop iterating once the key row timestamp equals or
222+
* Perform a range scan on timestamp and will stop iterating once the key row timestamp
193223
* exceeds the limit (as timestamp key is increasingly sorted).
194-
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function
195-
* will return all timers that have timestamp less than passed threshold.
224+
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive),
225+
* this function will return all timers that have timestamp
226+
* less than or equal to the passed threshold.
227+
* @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range.
228+
* Timers at or below this timestamp are assumed to have been
229+
* already processed in the previous batch and will be skipped.
196230
* @return - iterator of all the registered timers for all grouping keys
197231
*/
198-
def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
199-
// this iter is increasingly sorted on timestamp
200-
val iter = store.iterator(tsToKeyCFName)
232+
def getExpiredTimers(
233+
expiryTimestampMs: Long,
234+
prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = {
235+
val startKey = prevExpiryTimestampMs.flatMap(encodeTimestampAsKey)
236+
val endKey = encodeTimestampAsKey(expiryTimestampMs)
237+
val iter = store.rangeScan(startKey, endKey, tsToKeyCFName)
201238

202239
new NextIterator[(Any, Long)] {
203240
override protected def getNext(): (Any, Long) = {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/ListStateImplWithTTL.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ import org.apache.spark.util.NextIterator
3535
* @param valEncoder - Spark SQL encoder for value
3636
* @param ttlConfig - TTL configuration for values stored in this state
3737
* @param batchTimestampMs - current batch processing timestamp.
38+
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
39+
* Entries with expiration at or below this timestamp are assumed
40+
* to have been already cleaned up and will be skipped during
41+
* TTL eviction scans.
3842
* @param metrics - metrics to be updated as part of stateful processing
3943
* @tparam S - data type of object that will be stored
4044
*/
@@ -45,9 +49,11 @@ class ListStateImplWithTTL[S](
4549
valEncoder: ExpressionEncoder[Any],
4650
ttlConfig: TTLConfig,
4751
batchTimestampMs: Long,
52+
prevBatchTimestampMs: Option[Long] = None,
4853
metrics: Map[String, SQLMetric])
4954
extends OneToManyTTLState(
50-
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ListState[S] {
55+
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs,
56+
prevBatchTimestampMs, metrics) with ListState[S] {
5157

5258
private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder,
5359
stateName, hasTtl = true)

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator
3636
* @param valEncoder - SQL encoder for state variable
3737
* @param ttlConfig - the ttl configuration (time to live duration etc.)
3838
* @param batchTimestampMs - current batch processing timestamp.
39+
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
40+
* Entries with expiration at or below this timestamp are assumed
41+
* to have been already cleaned up and will be skipped during
42+
* TTL eviction scans.
3943
* @param metrics - metrics to be updated as part of stateful processing
4044
* @tparam K - type of key for map state variable
4145
* @tparam V - type of value for map state variable
@@ -49,10 +53,11 @@ class MapStateImplWithTTL[K, V](
4953
valEncoder: ExpressionEncoder[Any],
5054
ttlConfig: TTLConfig,
5155
batchTimestampMs: Long,
52-
metrics: Map[String, SQLMetric])
56+
prevBatchTimestampMs: Option[Long] = None,
57+
metrics: Map[String, SQLMetric])
5358
extends OneToOneTTLState(
5459
stateName, store, getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema), ttlConfig,
55-
batchTimestampMs, metrics) with MapState[K, V] with Logging {
60+
batchTimestampMs, prevBatchTimestampMs, metrics) with MapState[K, V] with Logging {
5661

5762
private val stateTypesEncoder = new CompositeKeyStateEncoder(
5863
keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = true)

0 commit comments

Comments
 (0)