Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
outputAttr,
stateInfo = None,
batchTimestampMs = None,
prevBatchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
planLater(child),
Expand Down Expand Up @@ -815,6 +816,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
func, t.leftAttributes, outputAttrs, outputMode, timeMode,
stateInfo = None,
batchTimestampMs = None,
prevBatchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
userFacingDataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ case class TransformWithStateInPySparkExec(
timeMode: TimeMode,
stateInfo: Option[StatefulOperatorStateInfo],
batchTimestampMs: Option[Long],
prevBatchTimestampMs: Option[Long] = None,
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
userFacingDataType: TransformWithStateInPySpark.UserFacingDataType.Value,
Expand Down Expand Up @@ -314,7 +315,8 @@ case class TransformWithStateInPySparkExec(
val data = groupAndProject(filteredIter, groupingAttributes, child.output, dedupAttributes)

val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs,
prevBatchTimestampMs, metrics)

val evalType = {
if (userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS) {
Expand Down Expand Up @@ -442,6 +444,7 @@ object TransformWithStateInPySparkExec {
Some(System.currentTimeMillis),
None,
None,
None,
userFacingDataType,
child,
isStreaming = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ case class TransformWithStateExec(
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
batchTimestampMs: Option[Long],
prevBatchTimestampMs: Option[Long] = None,
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
child: SparkPlan,
Expand Down Expand Up @@ -251,7 +252,7 @@ case class TransformWithStateExec(
case ProcessingTime =>
assert(batchTimestampMs.isDefined)
val batchTimestamp = batchTimestampMs.get
processorHandle.getExpiredTimers(batchTimestamp)
processorHandle.getExpiredTimers(batchTimestamp, prevBatchTimestampMs)
.flatMap { case (keyObj, expiryTimestampMs) =>
numExpiredTimers += 1
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
Expand All @@ -260,7 +261,26 @@ case class TransformWithStateExec(
case EventTime =>
assert(eventTimeWatermarkForEviction.isDefined)
val watermark = eventTimeWatermarkForEviction.get
processorHandle.getExpiredTimers(watermark)
// Use the late-events watermark as the scan lower bound only when we can prove that
// it equals the previous batch's eviction watermark for this operator. That holds
// when STATEFUL_OPERATOR_ALLOW_MULTIPLE is true.
//
// When STATEFUL_OPERATOR_ALLOW_MULTIPLE is false (legacy mode), lateEvents == eviction
// for the same batch, so using it would collapse the eviction range to (wm, wm] = empty
// and silently stop firing timers. Fall back to None (full scan) in that mode, as we
// don't expect the legacy mode to be used by many users.
//
// The prevBatchTimestampMs.isDefined check guards against the first batch, where
// watermark propagation yields Some(0) for late events even though no timers have been
// processed yet, which would incorrectly skip timers registered at timestamp 0.
val prevWatermark =
if (prevBatchTimestampMs.isDefined &&
conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) {
eventTimeWatermarkForLateEvents
} else {
None
}
processorHandle.getExpiredTimers(watermark, prevWatermark)
.flatMap { case (keyObj, expiryTimestampMs) =>
numExpiredTimers += 1
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
Expand Down Expand Up @@ -493,7 +513,7 @@ case class TransformWithStateExec(
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(
store, getStateInfo.queryRunId, keyEncoder, timeMode,
isStreaming, batchTimestampMs, metrics)
isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
withStatefulProcessorErrorHandling("init") {
Expand All @@ -509,7 +529,7 @@ case class TransformWithStateExec(
initStateIterator: Iterator[InternalRow]):
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
keyEncoder, timeMode, isStreaming, batchTimestampMs, prevBatchTimestampMs, metrics)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
withStatefulProcessorErrorHandling("init") {
Expand Down Expand Up @@ -581,6 +601,7 @@ object TransformWithStateExec {
Some(System.currentTimeMillis),
None,
None,
None,
child,
isStreaming = false,
hasInitialState,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class StatefulProcessorHandleImpl(
timeMode: TimeMode,
isStreaming: Boolean = true,
batchTimestampMs: Option[Long] = None,
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric] = Map.empty)
extends StatefulProcessorHandleImplBase(timeMode, keyEncoder) with Logging {
import StatefulProcessorHandleState._
Expand Down Expand Up @@ -171,13 +172,19 @@ class StatefulProcessorHandleImpl(

/**
 * Function to retrieve all expired registered timers for all grouping keys
 * @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive),
 *                          this function will return all timers that have timestamp
 *                          less than or equal to the passed threshold.
 * @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range.
 *                              Timers at or below this timestamp are assumed to have been
 *                              already processed in the previous batch and will be skipped.
 * @return - iterator of registered timers for all grouping keys
 */
def getExpiredTimers(
    expiryTimestampMs: Long,
    prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = {
  // Timer operations are only valid in certain handle states; fail fast otherwise.
  verifyTimerOperations("get_expired_timers")
  // Delegate the (prev, expiry] range lookup to the timer state's secondary index.
  timerState.getExpiredTimers(expiryTimestampMs, prevExpiryTimestampMs)
}

/**
Expand Down Expand Up @@ -237,7 +244,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
ttlStates.add(valueStateWithTTL)
TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars")
valueStateWithTTL
Expand Down Expand Up @@ -286,7 +294,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars")
ttlStates.add(listStateWithTTL)
listStateWithTTL
Expand Down Expand Up @@ -324,7 +333,8 @@ class StatefulProcessorHandleImpl(
validateTTLConfig(ttlConfig, stateName)
assert(batchTimestampMs.isDefined)
val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc,
valEncoder, ttlConfig, batchTimestampMs.get, metrics)
valEncoder, ttlConfig, batchTimestampMs.get,
prevBatchTimestampMs, metrics)
TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars")
ttlStates.add(mapStateWithTTL)
mapStateWithTTL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,36 @@ class TimerStateImpl(
schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, Seq(0)),
useMultipleValuesPerKey = false, isInternal = true)

private val secIndexProjection = UnsafeProjection.create(keySchemaForSecIndex)

// Placeholder grouping-key struct used in range-scan boundary rows; see
// [[RangeScanBoundaryUtils]] for rationale. Correctness relies on real stored
// entries never having a null grouping-key struct, which is preserved by
// registerTimer going through the user's expression encoder. Preserve this
// invariant if you change how entries are written.
private val defaultGroupingKey: InternalRow = RangeScanBoundaryUtils.defaultInternalRow(
  keySchemaForSecIndex(1).dataType.asInstanceOf[StructType])

/**
 * Encodes a timestamp into an UnsafeRow key for the secondary index.
 * The timestamp is incremented by 1 so that the encoded key serves as an exclusive
 * lower / upper bound in range scans. Returns None if tsMs is Long.MaxValue
 * (overflow guard: tsMs + 1 would wrap).
 *
 * The returned UnsafeRow is always a fresh copy, safe to hold alongside other
 * rows produced by the same projection.
 *
 * @param tsMs timestamp in milliseconds to encode as a range-scan boundary
 * @return Some(encoded boundary key), or None when tsMs is Long.MaxValue
 */
private def encodeTimestampAsKey(tsMs: Long): Option[UnsafeRow] = {
  if (tsMs < Long.MaxValue) {
    val row = new GenericInternalRow(keySchemaForSecIndex.length)
    row.setLong(0, tsMs + 1)
    // We cannot leave the grouping-key field unset: a null struct compares as
    // "greater" than non-null values in the UnsafeRow format, which would break
    // the range-scan bound. Fill it with a non-null default row instead.
    row.update(1, defaultGroupingKey)
    // copy() because the projection reuses its output buffer across calls.
    Some(secIndexProjection.apply(row).copy())
  } else {
    None
  }
}

private def getGroupingKey(cfName: String): Any = {
val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
if (keyOption.isEmpty) {
Expand Down Expand Up @@ -189,15 +219,22 @@ class TimerStateImpl(

/**
* Function to get all the expired registered timers for all grouping keys.
* Performs a range scan on the timestamp index and stops iterating once the key row
* timestamp exceeds the limit (timestamp keys are sorted in increasing order).
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds, this function
* will return all timers that have timestamp less than passed threshold.
* @param expiryTimestampMs Threshold for expired timestamp in milliseconds (inclusive),
* this function will return all timers that have timestamp
* less than or equal to the passed threshold.
* @param prevExpiryTimestampMs If provided, the lower bound (exclusive) of the scan range.
* Timers at or below this timestamp are assumed to have been
* already processed in the previous batch and will be skipped.
* @return - iterator of all the registered timers for all grouping keys
*/
def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
// this iter is increasingly sorted on timestamp
val iter = store.iterator(tsToKeyCFName)
def getExpiredTimers(
expiryTimestampMs: Long,
prevExpiryTimestampMs: Option[Long] = None): Iterator[(Any, Long)] = {
val startKey = prevExpiryTimestampMs.flatMap(encodeTimestampAsKey)
val endKey = encodeTimestampAsKey(expiryTimestampMs)
val iter = store.rangeScan(startKey, endKey, tsToKeyCFName)

new NextIterator[(Any, Long)] {
override protected def getNext(): (Any, Long) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ import org.apache.spark.util.NextIterator
* @param valEncoder - Spark SQL encoder for value
* @param ttlConfig - TTL configuration for values stored in this state
* @param batchTimestampMs - current batch processing timestamp.
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
* Entries with expiration at or below this timestamp are assumed
* to have been already cleaned up and will be skipped during
* TTL eviction scans.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam S - data type of object that will be stored
*/
Expand All @@ -45,9 +49,11 @@ class ListStateImplWithTTL[S](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric])
extends OneToManyTTLState(
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs, metrics) with ListState[S] {
stateName, store, keyExprEnc.schema, ttlConfig, batchTimestampMs,
prevBatchTimestampMs, metrics) with ListState[S] {

private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder,
stateName, hasTtl = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import org.apache.spark.util.NextIterator
* @param valEncoder - SQL encoder for state variable
* @param ttlConfig - the ttl configuration (time to live duration etc.)
* @param batchTimestampMs - current batch processing timestamp.
* @param prevBatchTimestampMs - batch timestamp from the previous micro-batch (exclusive).
* Entries with expiration at or below this timestamp are assumed
* to have been already cleaned up and will be skipped during
* TTL eviction scans.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam K - type of key for map state variable
* @tparam V - type of value for map state variable
Expand All @@ -49,10 +53,11 @@ class MapStateImplWithTTL[K, V](
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric])
prevBatchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric])
extends OneToOneTTLState(
stateName, store, getCompositeKeySchema(keyExprEnc.schema, userKeyEnc.schema), ttlConfig,
batchTimestampMs, metrics) with MapState[K, V] with Logging {
batchTimestampMs, prevBatchTimestampMs, metrics) with MapState[K, V] with Logging {

private val stateTypesEncoder = new CompositeKeyStateEncoder(
keyExprEnc, userKeyEnc, valEncoder, stateName, hasTtl = true)
Expand Down
Loading