Commit e2a204b

[SPARK-XXXXX][CORE] Add ConcurrentStageDAGScheduler for low-latency streaming
Ports the ConcurrentStageDAGScheduler from the Databricks runtime so that streaming queries can opt in to a "real-time" execution mode that runs all stages of a job concurrently rather than sequentially.

When enabled via spark.scheduler.dagSchedulerType=ConcurrentStageDAGScheduler and the per-job streaming.concurrent.stages.enabled property, the scheduler:
- Marks all ancestor stages of the final stage as concurrent on job submission and validates that the cluster has enough free slots (CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT), gated by spark.scheduler.realtimeModeSlotsCheck.disabled.
- Submits child stages while parents are still running, delays task completion events for a child whose parent is still running, and replays the delayed events when the parent finishes.
- Rejects speculative execution.

DAGScheduler changes (no-op for the default scheduler):
- New protected onFinalStageCreated hook, invoked from handleJobSubmitted / handleMapStageSubmitted right after final stage creation.
- New protected submitConcurrentStage and postSchedulerEvent helpers.
- New package-visible isRunningStage and getStage accessors.
- submitStage and markStageAsFinished relaxed from private to protected so subclasses can override them.

DAGSchedulerSuite refactor:
- Renames the concrete suite to abstract DAGSchedulerSuiteBase and adds an empty class DAGSchedulerSuite extends DAGSchedulerSuiteBase to preserve the existing entry point.
- Extracts a TestDAGScheduler trait carrying the scheduleShuffleMergeFinalize and handleTaskCompletion overrides; MyDAGScheduler mixes the trait in.
- Adds a protected createInitialScheduler hook used by init().
- Loosens submit, completeShuffleMapStageSuccessfully, completeNextResultStageWithSuccess, and assertDataStructuresEmpty to protected so subclass suites can use them.

Integration:
- SparkContext picks the scheduler implementation based on spark.scheduler.dagSchedulerType.
- TaskSchedulerImpl uses maxFailures=1 for concurrent-stage TaskSets so a failure restarts the streaming query instead of being silently retried.
- TaskSetManager counts ExecutorLostFailure toward task failures and skips the "executor lost is not the task's fault" exemption in concurrent mode.

Adds the supporting LogKeys (PARENT_STAGE, STREAMING_QUERY_ID) and the CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT error class.

Deviations from the runtime source kept to the minimum necessary to compile in OSS:
- Extends DAGScheduler directly (runtime extends CrossJobDepDAGScheduler, which gates micro-batch pipelining; not part of OSS).
- Hook is named onFinalStageCreated rather than the runtime's populateCrossJobDepInfo, since CrossJobDepDAGScheduler is not part of OSS.
- Micro-batch pipelining co-existence check (and its test) dropped, since MBP is not part of OSS.
- getStreamingBatchIdFromProperties and StreamingBatchId live in the companion object instead of CrossJobDepDAGScheduler.
- Slot check uses sc.schedulerBackend.defaultParallelism() in place of the runtime's TaskSchedulerStats helper.
- DatabricksEdgeConfigs.serverlessEnabled gating removed; the spark.scheduler.realtimeModeSlotsCheck.disabled config is the sole knob.
- isConcurrentStagesEnabled tolerates null Properties (OSS TaskSet allows null in tests).

Co-authored-by: Isaac
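
The "delay and replay" behaviour described in the commit message can be illustrated with a small standalone model. This is only a sketch under stated assumptions: `StageRef`, `TaskDone`, and `DelayedCompletionBuffer` are hypothetical names that do not exist in Spark, and the model buffers every event for a tracked child, whereas the scheduler in this commit delays only successful task completions.

```scala
import scala.collection.mutable

// Hypothetical stand-ins for Spark's Stage and CompletionEvent; not the real classes.
final case class StageRef(id: Int)
final case class TaskDone(stage: StageRef, taskId: Long)

final class DelayedCompletionBuffer {
  // child stage -> (parents still running, completion events held back for the child)
  private val pending =
    mutable.HashMap.empty[StageRef, (mutable.HashSet[StageRef], mutable.ListBuffer[TaskDone])]

  /** Records that `child` was submitted while `parent` was still running. */
  def trackRunningParent(child: StageRef, parent: StageRef): Unit = {
    val entry = pending.getOrElseUpdate(
      child, (mutable.HashSet.empty[StageRef], mutable.ListBuffer.empty[TaskDone]))
    entry._1 += parent
  }

  /** Returns true if the event can be handled now, false if it was buffered. */
  def offer(event: TaskDone): Boolean = pending.get(event.stage) match {
    case Some((_, delayed)) =>
      delayed += event
      false
    case None =>
      true
  }

  /** Called when `parent` finishes; returns the buffered events that may now be replayed. */
  def parentFinished(parent: StageRef): List[TaskDone] = {
    // Materialize the affected children first so the map can be mutated safely below.
    val children = pending.collect {
      case (child, (parents, _)) if parents.contains(parent) => child
    }.toList
    children.flatMap { child =>
      val (parents, delayed) = pending(child)
      parents -= parent
      if (parents.isEmpty) {
        pending -= child
        delayed.toList
      } else {
        Nil
      }
    }
  }
}
```

Events for a child are released only once its last running parent has finished, which mirrors the role of dependentStageMap and checkDependentStageTasks in the new scheduler.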
1 parent 706b6a3 commit e2a204b

12 files changed

Lines changed: 776 additions & 32 deletions

common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java

Lines changed: 2 additions & 0 deletions
@@ -577,6 +577,7 @@ public enum LogKeys implements LogKey {
   OUTPUT_BUFFER,
   OVERHEAD_MEMORY_SIZE,
   PAGE_SIZE,
+  PARENT_STAGE,
   PARENT_STAGES,
   PARSE_MODE,
   PARTITIONED_FILE_READER,
@@ -792,6 +793,7 @@ public enum LogKeys implements LogKey {
   STREAMING_DATA_SOURCE_NAME,
   STREAMING_OFFSETS_END,
   STREAMING_OFFSETS_START,
+  STREAMING_QUERY_ID,
   STREAMING_QUERY_PROGRESS,
   STREAMING_SOURCE,
   STREAMING_TABLE,

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
@@ -890,6 +890,12 @@
     ],
     "sqlState" : "0A000"
   },
+  "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT" : {
+    "message" : [
+      "The minimum number of free slots required in the cluster is <numTasks>, but the cluster has only <numSlots> slots free. The query will stall or fail. Increase the cluster size to proceed."
+    ],
+    "sqlState" : "53000"
+  },
   "CONCURRENT_STREAM_LOG_UPDATE" : {
     "message" : [
       "Concurrent update to the log. Multiple streaming jobs detected for <batchId>.",

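The arithmetic behind CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT can be sketched in isolation. The snippet below is a simplified model, not the scheduler's code: `StageDemand` and `SlotCheck` are hypothetical names, and the inputs that the real scheduler derives from stages, resource profiles, and the scheduler backend are passed in as plain numbers.

```scala
// Hypothetical, simplified description of a stage's CPU demand.
final case class StageDemand(numTasks: Int, cpusPerTask: Int) {
  def cores: Int = numTasks * cpusPerTask
}

object SlotCheck {
  /**
   * Returns Right(freeSlots) when all requested stages fit into the free slots,
   * or Left(message) when they do not. Free slots are the total slots minus the
   * cores claimed by stages that are already running.
   */
  def checkSlots(
      totalSlots: Int,
      running: Seq[StageDemand],
      requested: Seq[StageDemand]): Either[String, Int] = {
    val free = totalSlots - running.map(_.cores).sum
    val needed = requested.map(_.cores).sum
    if (free < needed) Left(s"need $needed free slots, only $free available")
    else Right(free)
  }
}
```

Because every stage of the job runs at once, the demand is summed over the final stage and all of its ancestors rather than taken per stage.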
core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 5 additions & 1 deletion
@@ -600,7 +600,11 @@ class SparkContext(config: SparkConf) extends Logging {
     val (sched, ts) = SparkContext.createTaskScheduler(this, master)
     _schedulerBackend = sched
     _taskScheduler = ts
-    _dagScheduler = new DAGScheduler(this)
+    _dagScheduler = conf.get(DAG_SCHEDULER_TYPE) match {
+      case "ConcurrentStageDAGScheduler" =>
+        new ConcurrentStageDAGScheduler(this)
+      case _ => new DAGScheduler(this)
+    }
     _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet)
 
     if (_conf.get(EXECUTOR_ALLOW_SYNC_LOG_LEVEL)) {

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 20 additions & 0 deletions
@@ -2396,6 +2396,26 @@ package object config {
       .booleanConf
       .createWithDefault(true)
 
+  private[spark] val STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED =
+    ConfigBuilder("spark.scheduler.realtimeModeSlotsCheck.disabled")
+      .internal()
+      .doc("For queries running in real-time mode, disables the check that the number of " +
+        "slots required by all concurrent stages is available before the query is submitted.")
+      .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE)
+      .version("4.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[spark] val DAG_SCHEDULER_TYPE =
+    ConfigBuilder("spark.scheduler.dagSchedulerType")
+      .internal()
+      .doc("The DAGScheduler implementation to use. Set to 'ConcurrentStageDAGScheduler' to " +
+        "enable real-time mode, which runs stages concurrently for low-latency streaming queries.")
+      .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE)
+      .version("4.2.0")
+      .stringConf
+      .createWithDefault("DAGScheduler")
+
   private[spark] val STREAMING_ID_AWARE_SCHEDULER_LOGGING_QUERY_ID_LENGTH =
     ConfigBuilder("spark.scheduler.streaming.idAwareLogging.queryIdLength")
       .doc("Maximum number of characters of the streaming query ID to include " +
Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv, SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED, STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ * A [[DAGScheduler]] that runs all the stages in a job without waiting for their parents to
+ * complete. This, combined with streaming shuffle between the stages, allows for low-latency
+ * execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+    sc: SparkContext,
+    taskScheduler: TaskScheduler,
+    listenerBus: LiveListenerBus,
+    mapOutputTracker: MapOutputTrackerMaster,
+    blockManagerMaster: BlockManagerMaster,
+    env: SparkEnv,
+    clock: Clock = new SystemClock())
+  extends DAGScheduler(
+    sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, clock) {
+
+  import ConcurrentStageDAGScheduler._
+
+  def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+    this(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+      sc.env.blockManager.master,
+      sc.env
+    )
+  }
+
+  def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+  // This contains all the concurrent stages that are yet to be scheduled across all the jobs.
+  private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+  private[scheduler] case class DependentStageInfo(
+      parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+      delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] = mutable.ListBuffer.empty)
+
+  // This map holds parents of concurrently scheduled stages. When tasks for such a stage complete,
+  // and if any of the parents are still running, we delay processing of such events until parent
+  // stages are complete. We save these events in this map until then.
+  private[spark] val dependentStageMap = new mutable.HashMap[Stage, DependentStageInfo]
+
+  private def totalNumCoreForStage(stage: Stage): Int = {
+    val numTask = stage match {
+      case r: ResultStage => r.partitions.length
+      case m: ShuffleMapStage => m.numPartitions
+    }
+    val resourceProfile = sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+    val taskCpus = ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+    taskCpus * numTask
+  }
+
+  /**
+   * Hook invoked after the final stage is created. Registers stages reachable from
+   * the final stage as concurrent so they can be submitted in parallel.
+   */
+  override def onFinalStageCreated(finalStage: Stage, properties: Properties): Unit = {
+
+    val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+    if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+      if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {
+        // Speculation is not supported with concurrent stages.
+        throw new SparkException(
+          "Speculative execution is not supported with concurrent stages " +
+          s"(streaming query: $queryBatchId). Please disable ${SPECULATION_ENABLED.key} config."
+        )
+      }
+
+      logInfo(log"Concurrent stages is enabled for [query ${MDC(LogKeys.STREAMING_QUERY_ID,
+        queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID, queryBatchId.get.batchId)}]")
+
+      // Mark current stage and all its ancestors as concurrent
+      var totalCoresNeeded = 0
+      def visit(stage: Stage): Unit = {
+        if (!concurrentStages.contains(stage)) {
+          logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent for [query ${MDC(
+            LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC(
+            LogKeys.BATCH_ID, queryBatchId.get.batchId)}]")
+          concurrentStages += stage
+          totalCoresNeeded += totalNumCoreForStage(stage)
+          stage.parents.foreach(visit)
+        }
+      }
+      visit(finalStage)
+
+      if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) {
+        try {
+          val totalSlots = sc.schedulerBackend.defaultParallelism()
+          val coresInUse = runningStages.toArray.map(totalNumCoreForStage(_)).sum
+          if (totalSlots - coresInUse < totalCoresNeeded) {
+            throw new SparkRuntimeException(
+              errorClass = "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT",
+              messageParameters = Map(
+                "numSlots" -> (totalSlots - coresInUse).toString,
+                "numTasks" -> totalCoresNeeded.toString))
+          }
+        } catch {
+          case e: UnsupportedOperationException =>
+            logWarning(log"${MDC(LogKeys.ERROR, e)}. Skipping slot check for RTM.")
+        }
+      }
+    } else {
+      super.onFinalStageCreated(finalStage, properties)
+    }
+  }
+
+  override def submitStage(stage: Stage): Unit = {
+    super.submitStage(stage)
+
+    if (!waitingStages.contains(stage) && concurrentStages.contains(stage)) {
+      // The current stage is not registered in waitingStages, which means it has
+      // no parents. In this case we should remove it from concurrentStages since it is already
+      // running.
+      assert(runningStages.contains(stage), "stage should be running if not in waitingStages")
+      logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from concurrentStages")
+      concurrentStages -= stage
+    }
+
+    // Find the stages that should be submitted concurrently with this stage.
+    waitingStages.intersect(concurrentStages).foreach { stage =>
+      logInfo(log"Submitting stage concurrently: ${MDC(LogKeys.STAGE, stage)}")
+      concurrentStages -= stage // Don't submit this stage concurrently for subsequent attempts.
+      stage.parents.foreach { parent =>
+        if (isRunningStage(parent)) {
+          logInfo(log"Updating dependent map for stage ${MDC(LogKeys.STAGE, stage)} with parent ${
+            MDC(LogKeys.PARENT_STAGE, parent)}")
+          dependentStageMap.getOrElseUpdate(stage, DependentStageInfo()).parents += parent
+        }
+      }
+      // Remove stage and its parents from concurrentStages
+      def removeFromConcurrentStages(stage: Stage): Unit = {
+        if (concurrentStages.contains(stage)) {
+          logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from concurrentStages")
+          concurrentStages -= stage
+        }
+        stage.parents.foreach { parent =>
+          assert(!waitingStages.contains(parent), "Parent stage should not still be waiting")
+          removeFromConcurrentStages(parent)
+        }
+      }
+      removeFromConcurrentStages(stage)
+      submitConcurrentStage(stage)
+    }
+  }
+
+  // This is overridden to check if the task completion event should be delayed because a parent
+  // stage still has running tasks. See comment for `dependentStageMap` for more details.
+  override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
+    val stageId = event.task.stageId
+    val taskId = event.taskInfo.taskId
+
+    getStage(stageId) match {
+      case Some(stage) if event.reason == Success && dependentStageMap.contains(stage) =>
+        val dependentStageInfo = dependentStageMap(stage)
+        logInfo(log"Delaying completion event for task ${MDC(LogKeys.TASK_ID, taskId)} in stage ${
+          MDC(LogKeys.STAGE, stage)}. Active parent(s): ${MDC(LogKeys.PARENT_STAGES,
+          dependentStageInfo.parents.mkString(", "))}")
+        dependentStageInfo.delayedTaskCompletionEvents += event
+
+      case _ => // Otherwise handle the event as usual.
+        super.handleTaskCompletion(event)
+    }
+  }
+
+  // This is overridden to handle any delayed task completion events for dependent stages.
+  override def markStageAsFinished(
+      stage: Stage,
+      errorMessage: Option[String] = None,
+      willRetry: Boolean = false): Unit = {
+
+    super.markStageAsFinished(stage, errorMessage, willRetry)
+
+    // If this is a parent of a stage in dependentStageMap, remove it from parents.
+    val dependentStages = dependentStageMap
+      .filter(_._2.parents.contains(stage))
+      .keys
+
+    dependentStages.foreach { dependent =>
+      if (errorMessage.isEmpty) {
+        assert(
+          isRunningStage(dependent),
+          s"Parent stage $stage's dependent stage $dependent should be running")
+      }
+      logInfo(log"Removing parent stage ${MDC(LogKeys.PARENT_STAGE, stage)} from dependent map " +
+        log"for stage ${MDC(LogKeys.STAGE, dependent)}")
+      dependentStageMap(dependent).parents -= stage
+      checkDependentStageTasks(dependent)
+    }
+  }
+
+  // Checks if the dependent stage's parents are all done. If all the parents are done,
+  // enqueues any saved task completion event (if any).
+  private def checkDependentStageTasks(stage: Stage): Unit = {
+    val dependentStageInfo = dependentStageMap.getOrElse(
+      stage, throw new RuntimeException(s"Stage $stage is not in dependentStageMap")
+    )
+
+    if (dependentStageInfo.parents.isEmpty) {
+      val delayedEvents = dependentStageInfo.delayedTaskCompletionEvents
+      logInfo(log"All the parents are done for ${MDC(LogKeys.STAGE, stage)}. Removing it from " +
+        log"the map. It has ${MDC(LogKeys.NUM_EVENTS, delayedEvents.size.toLong)} " +
+        log"task completion events")
+      dependentStageMap -= stage
+      delayedEvents.foreach { event =>
+        logInfo(log"Posting delayed task ${MDC(LogKeys.TASK_ID, event.taskInfo.taskId)} " +
+          log"completion event for stage ${MDC(LogKeys.STAGE, stage)}")
+        postSchedulerEvent(event)
+      }
+    }
+  }
+}
+
+object ConcurrentStageDAGScheduler {
+
+  val CONCURRENT_STAGES_ENABLED_PROPERTY: String = "streaming.concurrent.stages.enabled"
+
+  def isConcurrentStagesEnabled(properties: Properties): Boolean = {
+    properties != null &&
+      properties.getProperty(CONCURRENT_STAGES_ENABLED_PROPERTY) == "true"
+  }
+
+  /**
+   * Extracts the [[StreamingBatchId]] from the given properties if all three of the streaming
+   * query id, run id and batch id are present.
+   */
+  def getStreamingBatchIdFromProperties(properties: Properties): Option[StreamingBatchId] = {
+    if (properties == null) {
+      return None
+    }
+
+    val queryId = Option(properties.getProperty("sql.streaming.queryId"))
+    val runId = Option(properties.getProperty("sql.streaming.runId"))
+    val batchId = Option(properties.getProperty("streaming.sql.batchId"))
+    if (queryId.nonEmpty && runId.nonEmpty && batchId.nonEmpty) {
+      Some(StreamingBatchId(queryId.get, runId.get, batchId.get.toLong))
+    } else {
+      None
+    }
+  }
+}
+
+/**
+ * Case class to identify a batch in a streaming query.
+ *
+ * @param queryId - Streaming query id
+ * @param runId - Streaming query run id
+ * @param batchId - Batch id for a micro batch in a streaming query
+ */
+case class StreamingBatchId(queryId: String, runId: String, batchId: Long)
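
For readers who want to see which job properties make onFinalStageCreated take the concurrent path, the following standalone sketch builds and checks such a `Properties` object using only the JDK. The key strings are copied from the companion object above; `ConcurrentJobProps`, `forBatch`, and `optsIn` are hypothetical helper names, and how a streaming engine actually populates these properties is outside this commit.

```scala
import java.util.Properties

object ConcurrentJobProps {
  // Key names as they appear in the companion object in this commit.
  val Enabled = "streaming.concurrent.stages.enabled"
  val QueryId = "sql.streaming.queryId"
  val RunId = "sql.streaming.runId"
  val BatchId = "streaming.sql.batchId"

  /** Builds job properties carrying the opt-in flag and all three streaming ids. */
  def forBatch(queryId: String, runId: String, batchId: Long): Properties = {
    val p = new Properties()
    p.setProperty(Enabled, "true")
    p.setProperty(QueryId, queryId)
    p.setProperty(RunId, runId)
    p.setProperty(BatchId, batchId.toString)
    p
  }

  /** True when the flag is "true" and the query id, run id, and batch id are all present. */
  def optsIn(p: Properties): Boolean =
    p != null &&
      p.getProperty(Enabled) == "true" &&
      Seq(QueryId, RunId, BatchId).forall(k => p.getProperty(k) != null)
}
```

A job whose properties are null, lack the flag, or lack any of the three ids falls through to the default DAGScheduler behaviour.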
