Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -6870,6 +6870,11 @@
"joinSide=<joinSide>, storeVersion=<storeVersion>, partitionId=<partitionId>.",
"Enable <configKey> as a workaround to skip null values."
]
},
"RANGE_SCAN_TIMESTAMP_OUT_OF_RANGE" : {
"message" : [
"Range scan returned a row with timestamp <timestamp> outside the expected range [<minTimestamp>, <maxTimestamp>]."
]
}
},
"sqlState" : "XXKST"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ case class StreamingSymmetricHashJoinExec(
private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes)

private[this] val stateKeyWatermarkPredicateFunc = stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(expr, _)) =>
case Some(JoinStateKeyWatermarkPredicate(expr, _, _)) =>
// inputSchema can be empty as expr should only have BoundReferences and does not require
// the schema to generated predicate. See [[StreamingSymmetricHashJoinHelper]].
Predicate.create(expr, Seq.empty).eval _
Expand All @@ -684,7 +684,7 @@ case class StreamingSymmetricHashJoinExec(
}

private[this] val stateValueWatermarkPredicateFunc = stateWatermarkPredicate match {
case Some(JoinStateValueWatermarkPredicate(expr, _)) =>
case Some(JoinStateValueWatermarkPredicate(expr, _, _)) =>
Predicate.create(expr, inputAttributes).eval _
case _ =>
Predicate.create(Literal(false), Seq.empty).eval _ // false = do not remove if no predicate
Expand Down Expand Up @@ -905,21 +905,25 @@ case class StreamingSymmetricHashJoinExec(
*/
def removeOldState(): Long = {
stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictByValueCondition(stateValueWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case _ => 0L
}
Expand All @@ -937,21 +941,25 @@ case class StreamingSymmetricHashJoinExec(
*/
def removeAndReturnOldState(): Iterator[KeyToValuePair] = {
stateWatermarkPredicate match {
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateKeyWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictAndReturnByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark, prevStateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
s.evictAndReturnByTimestamp(
watermarkMsToStateTimestamp(stateWatermark),
prevStateWatermark.map(watermarkMsToStateTimestamp))
}
case _ => Iterator.empty
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,37 @@ object StreamingSymmetricHashJoinHelper extends Logging {
def desc: String
override def toString: String = s"$desc: $expr"
}
/** Predicate for watermark on state keys */
case class JoinStateKeyWatermarkPredicate(expr: Expression, stateWatermark: Long)
/**
* Predicate for watermark on state keys.
*
* @param stateWatermark Current batch's eviction watermark. Entries with timestamp
* at or below this value are eligible for eviction in this batch.
* @param prevStateWatermark Previous batch's eviction watermark, i.e. the watermark
* used for filtering late events in the current batch.
* Entries with timestamp at or below this value were already
* evicted in prior batches, so the effective range of entries
* to evict in this batch is `(prevStateWatermark, stateWatermark]`.
* State manager implementations can leverage this lower bound
* to optimize eviction (e.g. narrowing the scan range to skip
* already-evicted entries). `None` means we do not have a known
* lower bound (e.g. the first batch after restart), in which
* case eviction must consider all entries up to `stateWatermark`.
*/
case class JoinStateKeyWatermarkPredicate(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a high level comment to explain why the prevStateWatermark is passed here

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion! Done.

expr: Expression,
stateWatermark: Long,
prevStateWatermark: Option[Long] = None)
extends JoinStateWatermarkPredicate {
def desc: String = "key predicate"
}
/** Predicate for watermark on state values */
case class JoinStateValueWatermarkPredicate(expr: Expression, stateWatermark: Long)
/**
 * Predicate for watermark on state values. See [[JoinStateKeyWatermarkPredicate]] for
 * the semantics of `stateWatermark` and `prevStateWatermark`.
 *
 * @param expr Predicate expression evaluated against state value rows; rows for which
 *             it returns true are eligible for eviction.
 * @param stateWatermark Current batch's eviction watermark (inclusive upper bound on
 *                       the value timestamps eligible for eviction in this batch).
 * @param prevStateWatermark Previous batch's eviction watermark, i.e. an exclusive
 *                           lower bound below which entries were already evicted;
 *                           `None` when no lower bound is known (e.g. first batch
 *                           after restart).
 */
case class JoinStateValueWatermarkPredicate(
expr: Expression,
stateWatermark: Long,
prevStateWatermark: Option[Long] = None)
extends JoinStateWatermarkPredicate {
def desc: String = "value predicate"
}
Expand Down Expand Up @@ -185,6 +209,7 @@ object StreamingSymmetricHashJoinHelper extends Logging {
rightKeys: Seq[Expression],
condition: Option[Expression],
eventTimeWatermarkForEviction: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
useFirstEventTimeColumn: Boolean): JoinStateWatermarkPredicates = {

// Perform assertions against multiple event time columns in the same DataFrame. This method
Expand Down Expand Up @@ -215,20 +240,30 @@ object StreamingSymmetricHashJoinHelper extends Logging {
expr.map { e =>
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
// is defined
JoinStateKeyWatermarkPredicate(e, eventTimeWatermarkForEviction.get)
JoinStateKeyWatermarkPredicate(
e,
eventTimeWatermarkForEviction.get,
eventTimeWatermarkForLateEvents)
}
} else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark(
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
condition,
eventTimeWatermarkForEviction)
val prevStateValueWatermark = eventTimeWatermarkForLateEvents.flatMap { _ =>
StreamingJoinHelper.getStateValueWatermark(
attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes),
attributesWithEventWatermark = AttributeSet(otherSideInputAttributes),
condition,
eventTimeWatermarkForLateEvents)
}
val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey))
val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark)
expr.map { e =>
// watermarkExpression only provides the expression when eventTimeWatermarkForEviction
// is defined
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get)
JoinStateValueWatermarkPredicate(e, stateValueWatermark.get, prevStateValueWatermark)
}
} else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, RangeScanBoundaryUtils, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
import org.apache.spark.util.NextIterator
Expand Down Expand Up @@ -184,15 +184,28 @@ trait SupportsEvictByCondition { self: SymmetricHashJoinStateManager =>
trait SupportsEvictByTimestamp { self: SymmetricHashJoinStateManager =>
import SymmetricHashJoinStateManager._

/** Evict the state by timestamp. Returns the number of values evicted. */
def evictByTimestamp(endTimestamp: Long): Long
/**
* Evict the state by timestamp. Returns the number of values evicted.
*
* @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
* assumed to have been evicted already (e.g. from the previous batch). When provided,
* the scan starts from startTimestamp + 1.
*/
def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long

/**
* Evict the state by timestamp and return the evicted key-value pairs.
*
* It is caller's responsibility to consume the whole iterator.
*
* @param endTimestamp Inclusive upper bound: evicts entries with timestamp <= endTimestamp.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp are
* assumed to have been evicted already (e.g. from the previous batch). When provided,
* the scan starts from startTimestamp + 1.
*/
def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair]
def evictAndReturnByTimestamp(
endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair]
}

/**
Expand Down Expand Up @@ -519,11 +532,11 @@ class SymmetricHashJoinStateManagerV4(
}
}

override def evictByTimestamp(endTimestamp: Long): Long = {
override def evictByTimestamp(endTimestamp: Long, startTimestamp: Option[Long] = None): Long = {
require(hasEventTime,
"evictByTimestamp requires event time; secondary index was not populated")
var removed = 0L
tsWithKey.scanEvictedKeys(endTimestamp).foreach { evicted =>
tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).foreach { evicted =>
val key = evicted.key
val timestamp = evicted.timestamp
val numValues = evicted.numValues
Expand All @@ -537,12 +550,13 @@ class SymmetricHashJoinStateManagerV4(
removed
}

override def evictAndReturnByTimestamp(endTimestamp: Long): Iterator[KeyToValuePair] = {
override def evictAndReturnByTimestamp(
endTimestamp: Long, startTimestamp: Option[Long] = None): Iterator[KeyToValuePair] = {
require(hasEventTime,
"evictAndReturnByTimestamp requires event time; secondary index was not populated")
val reusableKeyToValuePair = KeyToValuePair()

tsWithKey.scanEvictedKeys(endTimestamp).flatMap { evicted =>
tsWithKey.scanEvictedKeys(endTimestamp, startTimestamp).flatMap { evicted =>
val key = evicted.key
val timestamp = evicted.timestamp
val values = keyWithTsToValues.get(key, timestamp)
Expand Down Expand Up @@ -663,14 +677,33 @@ class SymmetricHashJoinStateManagerV4(

/**
* Returns entries where minTs <= timestamp <= maxTs (both inclusive), grouped by timestamp.
* Skips entries before minTs and stops iterating past maxTs (timestamps are sorted).
* When maxTs is bounded (< Long.MaxValue), uses rangeScanWithMultiValues for efficient
* range access; falls back to prefixScan otherwise to stay within the key's scope.
*
* When prefixScan is used (maxTs == Long.MaxValue), entries outside [minTs, maxTs] are
* filtered out so both code paths produce identical results.
*/
def getValuesInRange(
key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
val reusableGetValuesResult = new GetValuesResult()
// Only use rangeScan when maxTs < Long.MaxValue, since rangeScan requires
// an exclusive end key (maxTs + 1) which would overflow at Long.MaxValue.
val useRangeScan = maxTs < Long.MaxValue

new NextIterator[GetValuesResult] {
private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)
private val iter = if (useRangeScan) {
// startKey must be copied because the second createKeyRow call below reuses
// the same projection buffer and would otherwise overwrite its contents.
// endKey does not need a copy: rangeScanWithMultiValues encodes both bounds
// to independent byte arrays eagerly at call time, and the scope of endKey
// ends with the call of rangeScanWithMultiValues.
val startKey = createKeyRow(key, minTs).copy()
// rangeScanWithMultiValues endKey is exclusive, so use maxTs + 1
val endKey = Some(createKeyRow(key, maxTs + 1))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to copy it like startKey?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to copy endKey since startKey and endKey should co-exist at the same time, but once we call the rangeScanWithMultiValues, both startKey and endKey aren't used.

I'm leaving code comment instead.

stateStore.rangeScanWithMultiValues(Some(startKey), endKey, colFamilyName)
} else {
stateStore.prefixScanWithMultiValues(key, colFamilyName)
}

private var currentTs = -1L
private var pastUpperBound = false
Expand All @@ -697,6 +730,11 @@ class SymmetricHashJoinStateManagerV4(
val unsafeRowPair = iter.next()
val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)

if (useRangeScan && (ts < minTs || ts > maxTs)) {
throw StateStoreErrors.streamStreamJoinRangeScanTimestampOutOfRange(
ts, minTs, maxTs)
}

if (ts > maxTs) {
pastUpperBound = true
getNext()
Expand Down Expand Up @@ -773,6 +811,8 @@ class SymmetricHashJoinStateManagerV4(
isInternal = true
)

// Returns an UnsafeRow backed by a reused projection buffer. Callers that need to
// hold the row beyond the immediate state store call must invoke copy() on the result.
private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
TimestampKeyStateEncoder.attachTimestamp(
attachTimestampProjection, keySchemaWithTimestamp, key, timestamp)
Expand All @@ -788,9 +828,60 @@ class SymmetricHashJoinStateManagerV4(

case class EvictedKeysResult(key: UnsafeRow, timestamp: Long, numValues: Int)

// NOTE: This assumes we consume the whole iterator to trigger completion.
def scanEvictedKeys(endTimestamp: Long): Iterator[EvictedKeysResult] = {
val evictIterator = stateStore.iteratorWithMultiValues(colFamilyName)
// Reusable default key row for scan boundary construction; see
// [[RangeScanBoundaryUtils]] for rationale. Safe to reuse because createKeyRow
// only reads this row (via BoundReference evaluations) and writes to the
// projection's own internal buffer. Correctness relies on real stored entries
// never having internally-null key fields, which is preserved by join-key
// expressions being evaluated via the user's expression encoder. Preserve this
// invariant if you change how entries are written.
// Declared lazy so the row is only built if a timestamp-range scan actually
// needs a boundary row (see createScanBoundaryRow below).
private lazy val defaultKey: UnsafeRow = RangeScanBoundaryUtils.defaultUnsafeRow(keySchema)

/**
 * Construct a full-schema boundary row for rangeScan at the given timestamp.
 *
 * The TsWithKeyTypeStore uses TimestampAsPrefixKeyStateEncoder, which lays the
 * encoded row out as [timestamp][key_fields]; the encoder therefore requires
 * every key column to be present even though only the timestamp participates
 * in ordering, so the key fields are filled with default values.
 */
private def createScanBoundaryRow(timestamp: Long): UnsafeRow = {
  // createKeyRow hands back a row backed by a shared projection buffer, so
  // copy() before the boundary row escapes this method.
  val boundaryRow = createKeyRow(defaultKey, timestamp)
  boundaryRow.copy()
}

/**
* Scan keys eligible for eviction within the timestamp range.
*
* This assumes we consume the whole iterator to trigger completion.
*
* @param endTimestamp Inclusive upper bound: entries with timestamp <= endTimestamp are
* eligible for eviction.
* @param startTimestamp Exclusive lower bound: entries with timestamp <= startTimestamp
* are assumed to have been evicted already. The scan starts from startTimestamp + 1.
*/
def scanEvictedKeys(
endTimestamp: Long,
startTimestamp: Option[Long] = None): Iterator[EvictedKeysResult] = {
// If startTimestamp == Long.MaxValue, everything has already been evicted;
// nothing can match, so return immediately.
if (startTimestamp.contains(Long.MaxValue)) {
return Iterator.empty
}

// rangeScanWithMultiValues: startKey is inclusive, endKey is exclusive.
// startTimestamp is exclusive (already evicted), so we seek from st + 1.
val startKeyRow = startTimestamp.map { st =>
createScanBoundaryRow(st + 1)
}
// endTimestamp is inclusive, so we use endTimestamp + 1 as the exclusive upper bound.
// When endTimestamp == Long.MaxValue we cannot add 1, so endKeyRow is None. This is
// safe because rangeScanWithMultiValues with no end key uses the column-family prefix
// as the upper bound, naturally scoping the scan within this column family.
val endKeyRow = if (endTimestamp < Long.MaxValue) {
Some(createScanBoundaryRow(endTimestamp + 1))
} else {
None
}
val evictIterator = stateStore.rangeScanWithMultiValues(startKeyRow, endKeyRow, colFamilyName)
new NextIterator[EvictedKeysResult]() {
var currentKeyRow: UnsafeRow = null
var currentEventTime: Long = -1L
Expand Down
Loading