From e2a204b7f68ef99848bb237130068250b9ba4566 Mon Sep 17 00:00:00 2001 From: Jerry Peng Date: Thu, 21 May 2026 05:25:56 +0000 Subject: [PATCH 1/5] [SPARK-XXXXX][CORE] Add ConcurrentStageDAGScheduler for low-latency streaming Ports the ConcurrentStageDAGScheduler from the Databricks runtime so that streaming queries can opt in to a "real-time" execution mode that runs all stages of a job concurrently rather than sequentially. When enabled via spark.scheduler.dagSchedulerType=ConcurrentStageDAGScheduler and the per-job streaming.concurrent.stages.enabled property, the scheduler: - Marks all ancestor stages of the final stage as concurrent on job submission and validates that the cluster has enough free slots (CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT), gated by spark.scheduler.realtimeModeSlotsCheck.disabled. - Submits child stages while parents are still running, delays task completion events for a child whose parent is still running, and replays the delayed events when the parent finishes. - Rejects speculative execution. DAGScheduler changes (no-op for the default scheduler): - New protected onFinalStageCreated hook, invoked from handleJobSubmitted / handleMapStageSubmitted right after final stage creation. - New protected submitConcurrentStage and postSchedulerEvent helpers. - New package-visible isRunningStage and getStage accessors. - submitStage and markStageAsFinished relaxed from private to protected so subclasses can override them. DAGSchedulerSuite refactor: - Renames the concrete suite to abstract DAGSchedulerSuiteBase and adds an empty class DAGSchedulerSuite extends DAGSchedulerSuiteBase to preserve the existing entry point. - Extracts a TestDAGScheduler trait carrying the scheduleShuffleMergeFinalize and handleTaskCompletion overrides; MyDAGScheduler mixes the trait in. - Adds a protected createInitialScheduler hook used by init(). 
- Loosens submit, completeShuffleMapStageSuccessfully, completeNextResultStageWithSuccess, and assertDataStructuresEmpty to protected so subclass suites can use them. Integration: - SparkContext picks the scheduler implementation based on spark.scheduler.dagSchedulerType. - TaskSchedulerImpl uses maxFailures=1 for concurrent-stage TaskSets so a failure restarts the streaming query instead of being silently retried. - TaskSetManager counts ExecutorLostFailure toward task failures and skips the "executor lost is not the task's fault" exemption in concurrent mode. Adds the supporting LogKeys (PARENT_STAGE, STREAMING_QUERY_ID) and the CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT error class. Deviations from the runtime source kept to the minimum necessary to compile in OSS: - Extends DAGScheduler directly (runtime extends CrossJobDepDAGScheduler, which gates micro-batch pipelining; not part of OSS). - Hook is named onFinalStageCreated rather than the runtime's populateCrossJobDepInfo, since CrossJobDepDAGScheduler is not part of OSS. - Micro-batch pipelining co-existence check (and its test) dropped, since MBP is not part of OSS. - getStreamingBatchIdFromProperties and StreamingBatchId live in the companion object instead of CrossJobDepDAGScheduler. - Slot check uses sc.schedulerBackend.defaultParallelism() in place of the runtime's TaskSchedulerStats helper. - DatabricksEdgeConfigs.serverlessEnabled gating removed; the spark.scheduler.realtimeModeSlotsCheck.disabled config is the sole knob. - isConcurrentStagesEnabled tolerates null Properties (OSS TaskSet allows null in tests). 
Co-authored-by: Isaac --- .../org/apache/spark/internal/LogKeys.java | 2 + .../resources/error/error-conditions.json | 6 + .../scala/org/apache/spark/SparkContext.scala | 6 +- .../spark/internal/config/package.scala | 20 ++ .../ConcurrentStageDAGScheduler.scala | 282 ++++++++++++++++++ .../apache/spark/scheduler/DAGScheduler.scala | 37 ++- .../spark/scheduler/TaskSchedulerImpl.scala | 9 +- .../spark/scheduler/TaskSetManager.scala | 22 +- .../ConcurrentStageDAGSchedulerSuite.scala | 280 +++++++++++++++++ .../spark/scheduler/DAGSchedulerSuite.scala | 64 ++-- .../scheduler/TaskSchedulerImplSuite.scala | 28 ++ .../spark/scheduler/TaskSetManagerSuite.scala | 52 ++++ 12 files changed, 776 insertions(+), 32 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index e92ef6f462a3f..fc3777d1a93fd 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -577,6 +577,7 @@ public enum LogKeys implements LogKey { OUTPUT_BUFFER, OVERHEAD_MEMORY_SIZE, PAGE_SIZE, + PARENT_STAGE, PARENT_STAGES, PARSE_MODE, PARTITIONED_FILE_READER, @@ -792,6 +793,7 @@ public enum LogKeys implements LogKey { STREAMING_DATA_SOURCE_NAME, STREAMING_OFFSETS_END, STREAMING_OFFSETS_START, + STREAMING_QUERY_ID, STREAMING_QUERY_PROGRESS, STREAMING_SOURCE, STREAMING_TABLE, diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index f1e162a6260f7..1f25a85266622 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -890,6 
+890,12 @@ ], "sqlState" : "0A000" }, + "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT" : { + "message" : [ + "The minimum number of free slots required in the cluster is <numTasks>, however, the cluster only has <numSlots> slots free. Query will stall or fail. Increase cluster size to proceed." + ], + "sqlState" : "53000" + }, "CONCURRENT_STREAM_LOG_UPDATE" : { "message" : [ "Concurrent update to the log. Multiple streaming jobs detected for <batchId>.", diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0262144490ce8..6cd9d3895e9ac 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -600,7 +600,11 @@ class SparkContext(config: SparkConf) extends Logging { val (sched, ts) = SparkContext.createTaskScheduler(this, master) _schedulerBackend = sched _taskScheduler = ts - _dagScheduler = new DAGScheduler(this) + _dagScheduler = conf.get(DAG_SCHEDULER_TYPE) match { + case "ConcurrentStageDAGScheduler" => + new ConcurrentStageDAGScheduler(this) + case _ => new DAGScheduler(this) + } _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) if (_conf.get(EXECUTOR_ALLOW_SYNC_LOG_LEVEL)) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 86e5422a85515..0ea0b1a46f691 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2396,6 +2396,26 @@ package object config { .booleanConf .createWithDefault(true) + private[spark] val STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED = + ConfigBuilder("spark.scheduler.realtimeModeSlotsCheck.disabled") + .internal() + .doc("For query running in real time mode, disable the check if the number of slots" + + " required by all concurrent stages is available before submit the query" ) + 
.withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .version("4.2.0") + .booleanConf + .createWithDefault(false) + + private[spark] val DAG_SCHEDULER_TYPE = + ConfigBuilder("spark.scheduler.dagSchedulerType") + .internal() + .doc("The DAGScheduler implementation to use. Set to 'ConcurrentStageDAGScheduler' to " + + "enable real-time mode, which runs stages concurrently for low-latency streaming queries.") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .version("4.2.0") + .stringConf + .createWithDefault("DAGScheduler") + private[spark] val STREAMING_ID_AWARE_SCHEDULER_LOGGING_QUERY_ID_LENGTH = ConfigBuilder("spark.scheduler.streaming.idAwareLogging.queryIdLength") .doc("Maximum number of characters of the streaming query ID to include " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala new file mode 100644 index 0000000000000..a2591821b39b5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.util.Properties + +import scala.collection.mutable + +import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv, SparkException, SparkRuntimeException, Success} +import org.apache.spark.internal.LogKeys +import org.apache.spark.internal.config.{SPECULATION_ENABLED, STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED} +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.storage.BlockManagerMaster +import org.apache.spark.util.Clock +import org.apache.spark.util.SystemClock + +/** + * A [[DAGScheduler]] that runs all the stages in a job without waiting for their parents to + * complete. This, combined with streaming shuffle between the stages, allows for low latency + * execution of streaming queries in real-time mode. + */ +class ConcurrentStageDAGScheduler( + sc: SparkContext, + taskScheduler: TaskScheduler, + listenerBus: LiveListenerBus, + mapOutputTracker: MapOutputTrackerMaster, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv, + clock: Clock = new SystemClock()) + extends DAGScheduler( + sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, clock) { + + import ConcurrentStageDAGScheduler._ + + def this(sc: SparkContext, taskScheduler: TaskScheduler) = { + this( + sc, + taskScheduler, + sc.listenerBus, + sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + sc.env.blockManager.master, + sc.env + ) + } + + def this(sc: SparkContext) = this(sc, sc.taskScheduler) + + // This contains all the concurrent stages that are yet to be scheduled across all the jobs. + private[spark] val concurrentStages = new mutable.HashSet[Stage] + + private[scheduler] case class DependentStageInfo( + parents: mutable.HashSet[Stage] = mutable.HashSet.empty, + delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] = mutable.ListBuffer.empty) + + // This map holds parents of concurrently scheduled stages. 
When tasks for such a stage complete, + // and if any of the parents are still running, we delay processing of such events until parent + // stages are complete. We save these events in this map until then. + private[spark] val dependentStageMap = new mutable.HashMap[Stage, DependentStageInfo] + + private def totalNumCoreForStage(stage: Stage): Int = { + val numTask = stage match { + case r: ResultStage => r.partitions.length + case m: ShuffleMapStage => m.numPartitions + } + val resourceProfile = sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId) + val taskCpus = ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf) + taskCpus * numTask + } + + /** + * Hook invoked after the final stage is created. Registers stages reachable from + * the final stage as concurrent so they can be submitted in parallel. + */ + override def onFinalStageCreated(finalStage: Stage, properties: Properties): Unit = { + + val queryBatchId = getStreamingBatchIdFromProperties(properties) + + if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) { + if (properties.getProperty(SPECULATION_ENABLED.key) == "true") { + // Speculation is not supported with concurrent stages. + throw new SparkException( + "Speculative execution is not supported with concurrent stages " + + s"(streaming query: $queryBatchId). Please disable ${SPECULATION_ENABLED.key} config." 
+ ) + } + + logInfo(log"Concurrent stages is enabled for [query ${MDC(LogKeys.STREAMING_QUERY_ID, + queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID, queryBatchId.get.batchId)}]") + + // Mark current stage and all its ancestors as concurrent + var totalCoresNeeded = 0 + def visit(stage: Stage): Unit = { + if (!concurrentStages.contains(stage)) { + logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent for [query ${MDC( + LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC( + LogKeys.BATCH_ID, queryBatchId.get.batchId)}]") + concurrentStages += stage + totalCoresNeeded += totalNumCoreForStage(stage) + stage.parents.foreach(visit) + } + } + visit(finalStage) + + if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) { + try { + val totalSlots = sc.schedulerBackend.defaultParallelism() + val coresInUse = runningStages.toArray.map(totalNumCoreForStage(_)).sum + if (totalSlots - coresInUse < totalCoresNeeded) { + throw new SparkRuntimeException( + errorClass = "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT", + messageParameters = Map( + "numSlots" -> (totalSlots - coresInUse).toString, + "numTasks" -> totalCoresNeeded.toString)) + } + } catch { + case e: UnsupportedOperationException => + logWarning(log"${MDC(LogKeys.ERROR, e)}. Skipping slot check for RTM.") + } + } + } else { + super.onFinalStageCreated(finalStage, properties) + } + } + + override def submitStage(stage: Stage): Unit = { + super.submitStage(stage) + + if (!waitingStages.contains(stage) && concurrentStages.contains(stage)) { + // The current stage is not registered in waitingStages, which means it has + // no parents. This case we should remove it from concurrentStages since it is already + // running. 
+ assert(runningStages.contains(stage), "stage should be running if not in waitingStages") + logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from concurrentStages") + concurrentStages -= stage + } + + // Find the stages that should be submitted concurrently with this stage. + waitingStages.intersect(concurrentStages).foreach { stage => + logInfo(log"Submitting stage concurrently: ${MDC(LogKeys.STAGE, stage)}") + concurrentStages -= stage // Don't submit this stage concurrently for subsequent attempts. + stage.parents.foreach { parent => + if (isRunningStage(parent)) { + logInfo(log"Updating dependent map for stage ${MDC(LogKeys.STAGE, stage)} with parent ${ + MDC(LogKeys.PARENT_STAGE, parent)}") + dependentStageMap.getOrElseUpdate(stage, DependentStageInfo()).parents += parent + } + } + // Remove stage and its parents from concurrentStages + def removeFromConcurrentStages(stage: Stage): Unit = { + if (concurrentStages.contains(stage)) { + logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from concurrentStages") + concurrentStages -= stage + } + stage.parents.foreach { parent => + assert(!waitingStages.contains(parent), "Parent stage should not still be waiting") + removeFromConcurrentStages(parent) + } + } + removeFromConcurrentStages(stage) + submitConcurrentStage(stage) + } + } + + // This is overridden to check if the task completion event should be delayed while a parent + // stage still has running tasks. See comment for `dependentStageMap` for more details. + override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = { + val stageId = event.task.stageId + val taskId = event.taskInfo.taskId + + getStage(stageId) match { + case Some(stage) if event.reason == Success && dependentStageMap.contains(stage) => + val dependentStageInfo = dependentStageMap(stage) + logInfo(log"Delaying completion event for task ${MDC(LogKeys.TASK_ID, taskId)} in stage ${ + MDC(LogKeys.STAGE, stage)}. 
Active parent(s): ${MDC(LogKeys.PARENT_STAGES, + dependentStageInfo.parents.mkString(", "))}") + dependentStageInfo.delayedTaskCompletionEvents += event + + case _ => // Otherwise handle the event as usual. + super.handleTaskCompletion(event) + } + } + + // This is overridden to handle any delayed task completion events for dependent stages. + override def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { + + super.markStageAsFinished(stage, errorMessage, willRetry) + + // If this is a parent of a stage in dependentStageMap, remove it from parents. + val dependentStages = dependentStageMap + .filter(_._2.parents.contains(stage)) + .keys + + dependentStages.foreach { dependent => + if (errorMessage.isEmpty) { + assert( + isRunningStage(dependent), + s"Parent stages $stage's dependent stage $dependent should be running") + } + logInfo(log"Removing parent stage ${MDC(LogKeys.PARENT_STAGE, stage)} from dependent map " + + log"for stage ${MDC(LogKeys.STAGE, dependent)}") + dependentStageMap(dependent).parents -= stage + checkDependentStageTasks(dependent) + } + } + + // Checks if the dependent stage's parents are all done. If all the parents are done, + // enqueues any saved task completion event (if any). + private def checkDependentStageTasks(stage: Stage): Unit = { + val dependentStageInfo = dependentStageMap.getOrElse( + stage, throw new RuntimeException(s"Stage $stage is not in dependentStageMap") + ) + + if (dependentStageInfo.parents.isEmpty) { + val delayedEvents = dependentStageInfo.delayedTaskCompletionEvents + logInfo(log"All the parents are done for ${MDC(LogKeys.STAGE, stage)}. Removing it from " + + log"the map. 
It has ${MDC(LogKeys.NUM_EVENTS, delayedEvents.size.toLong)} " + + log"task completion events") + dependentStageMap -= stage + delayedEvents.foreach { event => + logInfo(log"Posting delayed task ${MDC(LogKeys.TASK_ID, event.taskInfo.taskId)} " + + log"completion event for stage ${MDC(LogKeys.STAGE, stage)}") + postSchedulerEvent(event) + } + } + } +} + +object ConcurrentStageDAGScheduler { + + val CONCURRENT_STAGES_ENABLED_PROPERTY: String = "streaming.concurrent.stages.enabled" + + def isConcurrentStagesEnabled(properties: Properties): Boolean = { + properties != null && + properties.getProperty(CONCURRENT_STAGES_ENABLED_PROPERTY) == "true" + } + + /** + * Extracts the [[StreamingBatchId]] from the given properties if all three of the streaming + * query id, run id and batch id are present. + */ + def getStreamingBatchIdFromProperties(properties: Properties): Option[StreamingBatchId] = { + if (properties == null) { + return None + } + + val queryId = Option(properties.getProperty("sql.streaming.queryId")) + val runId = Option(properties.getProperty("sql.streaming.runId")) + val batchId = Option(properties.getProperty("streaming.sql.batchId")) + if (queryId.nonEmpty && runId.nonEmpty && batchId.nonEmpty) { + Some(StreamingBatchId(queryId.get, runId.get, batchId.get.toLong)) + } else { + None + } + } +} + +/** + * Case class to identify a batch in a streaming query. 
+ * + * @param queryId - Streaming query id + * @param runId - Streaming query run id + * @param batchId - Batch id for a micro batch in a streaming query + */ +case class StreamingBatchId(queryId: String, runId: String, batchId: Long) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 22720b98aafde..f9e0c58dc4919 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1272,6 +1272,14 @@ private[spark] class DAGScheduler( } } + private[scheduler] def isRunningStage(stage: Stage): Boolean = { + runningStages.contains(stage) + } + + private[scheduler] def getStage(stageId: Int): Option[Stage] = { + stageIdToStage.get(stageId) + } + /** Finds the earliest-created active job that needs the stage */ // TODO: Probably should actually find among the active jobs that need this // stage the one with the highest priority (highest-priority pool, earliest created). @@ -1385,6 +1393,8 @@ private[spark] class DAGScheduler( listenerBus.post(SparkListenerTaskGettingResult(taskInfo)) } + protected def onFinalStageCreated(finalStage: Stage, properties: Properties): Unit = {} + private def getQueryExecutionIdFromProperties(properties: Properties): Option[Long] = { try { Option(properties) @@ -1420,6 +1430,7 @@ private[spark] class DAGScheduler( // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite) + onFinalStageCreated(finalStage, properties) } catch { case e: BarrierJobSlotsNumberCheckFailed => // If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically. 
@@ -1497,6 +1508,7 @@ private[spark] class DAGScheduler( // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. finalStage = getOrCreateShuffleMapStage(dependency, jobId) + onFinalStageCreated(finalStage, properties) } catch { case e: Exception => logWarning(log"Creating new stage failed due to exception - job: ${MDC(JOB_ID, jobId)}", e) @@ -1537,7 +1549,7 @@ private[spark] class DAGScheduler( } /** Submits stage, but first recursively submits any missing parents. */ - private def submitStage(stage: Stage): Unit = { + protected def submitStage(stage: Stage): Unit = { val jobId = activeJobForStage(stage) if (jobId.isDefined) { logDebug(s"submitStage($stage (name=${stage.name};" + @@ -1569,6 +1581,27 @@ private[spark] class DAGScheduler( } } + /** + * An experimental API to submit child stages even while the parents are running. This is only + * used in [[ConcurrentStageDAGScheduler]]. It is defined here since it depends on two private + * APIs in this class (namely submitMissingTasks() and activeJobForStage()). + */ + protected def submitConcurrentStage(stage: Stage): Unit = { + assert(waitingStages.contains(stage)) + activeJobForStage(stage) match { + case Some(job) => + waitingStages -= stage + submitMissingTasks(stage, job) + case None => // Not expected; fail loudly instead of silently dropping the stage. + throw new IllegalStateException(s"No active job for stage $stage") + } + } + + protected def postSchedulerEvent(event: DAGSchedulerEvent): Unit = { + // Currently only used in [[ConcurrentStageDAGScheduler]]. + eventProcessLoop.post(event) + } + /** * `PythonRunner` needs to know what the pyspark memory and cores settings are for the profile * being run. Pass them in the local properties of the task if it's set for the stage profile. @@ -3253,7 +3286,7 @@ private[spark] class DAGScheduler( /** * Marks a stage as finished and removes it from the list of running stages. 
*/ - private def markStageAsFinished( + protected def markStageAsFinished( stage: Stage, errorMessage: Option[String] = None, willRetry: Boolean = false): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 618c8eb459026..6078a91f0f591 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -245,8 +245,15 @@ private[spark] class TaskSchedulerImpl( logInfo(log"Adding task set " + taskSet.logId + log" with ${MDC(LogKeys.NUM_TASKS, tasks.length)} tasks resource profile " + log"${MDC(LogKeys.RESOURCE_PROFILE_ID, taskSet.resourceProfileId)}") + val maxFailures = if (ConcurrentStageDAGScheduler + .isConcurrentStagesEnabled(taskSet.properties)) { + logInfo(s"Task retries are disabled for task set ${taskSet.id} with ${tasks.length} tasks") + 1 // Concurrent stage execution does not support task retries. + } else { + maxTaskFailures + } this.synchronized { - val manager = createTaskSetManager(taskSet, maxTaskFailures) + val manager = createTaskSetManager(taskSet, maxFailures) val stage = taskSet.stageId val stageTaskSets = taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 5c077a7a3bbb8..5847480979f59 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -1051,10 +1051,17 @@ private[spark] class TaskSetManager( logWarning(failureReason) None - case e: ExecutorLostFailure if !e.exitCausedByApp => - logInfo(log"${MDC(TASK_NAME, taskName(tid))} failed because while it was being computed," + - log" its executor exited for a reason unrelated to the task. 
" + - log"Not counting this failure towards the maximum number of failures for the task.") + case e: ExecutorLostFailure => + if (!e.exitCausedByApp + // if the query is running in real time mode, any failure should be considered + // a task failure + && !ConcurrentStageDAGScheduler.isConcurrentStagesEnabled(taskSet.properties)) { + logInfo(log"${MDC(TASK_NAME, taskName(tid))} failed because while it was being " + + log"computed, its executor exited for a reason unrelated to the task. " + + log"Not counting this failure towards the maximum number of failures for the task.") + } else { + logWarning(failureReason) + } None case _: TaskFailedReason => // TaskResultLost and others @@ -1069,7 +1076,12 @@ private[spark] class TaskSetManager( emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null, accumUpdates, metricPeaks) - if (!isZombie && reason.countTowardsTaskFailures) { + val countTowardsTaskFailures = reason.countTowardsTaskFailures || + // if the query is running in real time mode, any failures should contribute the task failures + // so that the query can restart. + ConcurrentStageDAGScheduler.isConcurrentStagesEnabled(taskSet.properties) + + if (!isZombie && countTowardsTaskFailures) { assert (null != failureReason) taskSetExcludelistHelperOpt.foreach(_.updateExcludedForFailedTask( info.host, info.executorId, index, failureReasonString)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala new file mode 100644 index 0000000000000..1efd73d86679a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.Properties + +import org.apache.spark.HashPartitioner +import org.apache.spark.ShuffleDependency +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.internal.config.SPECULATION_ENABLED +import org.apache.spark.internal.config.STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED + +class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { + + // The unit-test SparkContext runs in local[2] mode, but the concurrent pipelines exercised + // here often need more slots than that. Disable the slot check so the tests aren't gated by + // executor capacity. + override def conf: SparkConf = + super.conf.set(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED, true) + + class TestConcurrentStageDAGScheduler(sc: SparkContext) + extends ConcurrentStageDAGScheduler( + sc, + taskScheduler, + sc.listenerBus, + mapOutputTracker, + blockManagerMaster, + sc.env) + with TestDAGScheduler + + override def createInitialScheduler(sc: SparkContext): DAGScheduler = { + new TestConcurrentStageDAGScheduler(sc) + } + + // Catch the job failure exception with a listener. 
+ private class TestJobListener extends JobListener { + private var failureException: Option[Exception] = None + + override def jobFailed(exception: Exception): Unit = { + failureException = Some(exception) + } + + override def taskSucceeded(index: Int, result: Any): Unit = { } + + def expectFailure(): Exception = { + assert(failureException.nonEmpty, "Job was expected to fail with an exception, but didn't") + failureException.get + } + } + + + /** Default job properties with query settings and concurrent stages enabled. */ + private val testProperties: Properties = { + val properties = new Properties() + properties.setProperty("sql.streaming.queryId", "test_query_id") + properties.setProperty("sql.streaming.runId", "test_run_id") + properties.setProperty("streaming.sql.batchId", "5") + properties.setProperty(ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY, "true") + new Properties(properties) { + // Make it read-only. + override def setProperty(key: String, value: String): AnyRef = { + throw new UnsupportedOperationException("Default properties are read-only.") + } + } + } + + test("Simple job with two concurrent stages") { + // Run a simple job with two stages. Both stages should be running concurrently. + + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + // Shape: [stage_0, map stage, parent] <--- [stage_1, result stage] + + submit(resultStage, Array(0), properties = testProperties) + + assert(scheduler.waitingStages.isEmpty) // Both are submitted. + assert(scheduler.runningStages.map(_.id) === Set(0, 1)) // Both stages are running. + + // Verify concurrent scheduler specific state. 
+ val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + + assert(concurrentScheduler.concurrentStages.isEmpty) // All are already scheduled + + val depStageMap = concurrentScheduler.dependentStageMap + assert(depStageMap.keys.map(_.id) == Set(1)) // Result stage is the key. + assert(depStageMap.values.flatMap(_.parents.map(_.id)) == Seq(0)) // Map stage is the parent. + assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).isEmpty) // No completed tasks. + + // First complete the result stage. Its tasks will complete, but the actual stage would still + // be running since its parent (map stage) hasn't completed yet. + + completeNextResultStageWithSuccess(1, 0) + assert(scheduler.runningStages.map(_.id) === Set(0, 1)) // Both stages are still running. + // dependentStageMap should have the completed task from result stage enqueued. + assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).size === 1) + + // Now complete the map stage. This should complete the result stage as well. + completeShuffleMapStageSuccessfully(0, 0, 1) + + assert(scheduler.runningStages.map(_.id) === Set()) // Both stages are complete. + assert(depStageMap.isEmpty) // No more dependent stages. + + assertDataStructuresEmpty() + } + + test("Default scheduler using a simple job with concurrent stages disabled") { + // This is opposite of the previous test. Concurrent stages are disabled, so the stages should + // be submitted one after the other. + + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + // Shape: [stage_0, map stage, parent] <--- [stage_1, result stage] + + submit(resultStage, Array(0), properties = new Properties()) + + assert(scheduler.runningStages.map(_.id) == Set(0)) // Only the map stage is running. + assert(scheduler.waitingStages.map(_.id) == Set(1)) // Result stage is waiting. 
+ + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + assert(concurrentScheduler.concurrentStages.isEmpty) // No concurrent stages. + assert(concurrentScheduler.dependentStageMap.isEmpty) // No dependent stages. + + // Complete the map stage. This should submit the result stage. + completeShuffleMapStageSuccessfully(0, 0, 1) + + assert(scheduler.runningStages.map(_.id) == Set(1)) // Only the result stage is running. + assert(scheduler.waitingStages.map(_.id) == Set()) // No waiting stages + + completeNextResultStageWithSuccess(1, 0) + assertDataStructuresEmpty() + } + + test("Complex pipeline with many stages") { + // Run a complex pipeline with multiple stages with multiple branches. Such a pipeline is not + // common, but useful to ensure scheduler works as expected. + + // Shape: + // /<-------------------- stage_D + // stage_A <--- stage_B <--- stage_C <---\ ^ + // \ \<---------/ \ | + // \ <----------------/ \ | + // stage_E <--------------------------- stage_F + + // All of these should be running concurrently. + + val rddA = new MyRDD(sc, 2, Nil).setName("rddA") + val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1)) + + val rddB = new MyRDD(sc, 1, List(shuffleDepA)).setName("rddB") + val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1)) + + val rddC = new MyRDD(sc, 1, List(shuffleDepA, shuffleDepB)).setName("rddC") + val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(5)) + + val rddD = new MyRDD(sc, 4, List(shuffleDepB)).setName("rddD") + val shuffleDepD = new ShuffleDependency(rddD, new HashPartitioner(5)) + + val rddE = new MyRDD(sc, 3, Nil).setName("rddE") + val shuffleDepE = new ShuffleDependency(rddE, new HashPartitioner(5)) + + val rddF = new MyRDD(sc, 5, List(shuffleDepC, shuffleDepD, shuffleDepE)).setName("rddF") + + submit(rddF, Array(0, 1, 2, 3, 4), properties = testProperties) + + assert(scheduler.waitingStages.isEmpty) // All the stages are submitted.
+ assert(scheduler.runningStages.map(_.id) == Set(0, 1, 2, 3, 4, 5)) // All 6 are running. + + // Assign stage ids corresponding to the RDDs A, B, etc + def stageFor(rddName: String): Stage = scheduler.runningStages.find(_.rdd.name == rddName).get + + val sA = stageFor("rddA") + val sB = stageFor("rddB") + val sC = stageFor("rddC") + val sD = stageFor("rddD") + val sE = stageFor("rddE") + val sF = stageFor("rddF") + + // log stage id mapping for debugging: + for (name <- List("A", "B", "C", "D", "E", "F")) { + logInfo(s"Stage id for stage $name is ${stageFor("rdd" + name).id}") + } + + // Verify concurrent scheduler specific state. + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + + assert(concurrentScheduler.concurrentStages.isEmpty) // All are already scheduled + + val depStageMap = concurrentScheduler.dependentStageMap + assert(depStageMap.keys === Set(sB, sC, sD, sF)) // All non-root stages are keys. + assert(depStageMap.values.flatMap(_.parents).toSet === Set( + sA, sB, sC, sD, sE)) // All except the results stage. + assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).isEmpty) // No completed tasks. + + // Complete stages out-of-order and verify the state. + + // First complete C. Entry for C would be updated with the completed task. + assert(depStageMap(sC).delayedTaskCompletionEvents.size === 0) + completeShuffleMapStageSuccessfully(sC.id, 0, 5) // Complete stage C. + assert(depStageMap(sC).delayedTaskCompletionEvents.size === 1) + + // All the 6 stages are still 'running' since C's completion events are delayed. + assert(scheduler.runningStages.map(_.id) == Set(0, 1, 2, 3, 4, 5)) + + // Now complete stage D. This is similar to completing C. The tasks are enqueued. + assert(depStageMap(sD).delayedTaskCompletionEvents.size === 0) + completeShuffleMapStageSuccessfully(sD.id, 0, 5) // Complete stage D + assert(depStageMap(sD).delayedTaskCompletionEvents.size === 4) // 4 tasks in stage C.
+ + // Complete stage E. This is a root node and should complete normally. + assert(depStageMap(sF).parents.contains(sE)) // E is one of the stages that F waits for. + completeShuffleMapStageSuccessfully(sE.id, 0, 3) + assert(depStageMap(sF).parents.contains(sE) === false) // E is removed from F's parents. + + // One less running stage. + assert(scheduler.runningStages === Set(sA, sB, sC, sD, sF)) // E is complete. + + // Complete stage A. This is a root node and should complete normally. + assert(depStageMap(sC).parents.contains(sA)) // A is one of the stages that C waits for. + assert(depStageMap(sB).parents.contains(sA)) // Same for B + completeShuffleMapStageSuccessfully(sA.id, 0, 2) + assert(scheduler.runningStages === Set(sB, sC, sD, sF)) // A is complete. + assert(depStageMap(sC).parents.contains(sA) === false) // A is removed from C's parents. + // In the case of B, there is a bit more. Its only parent A is complete. So there is no need to + // track it in depStageMap. So it is removed from `dependentStageMap` entirely. + assert(depStageMap.contains(sB) === false) // B is removed from the depStageMap. + + // Complete result stage F. This will be enqueued, will complete later. + completeNextResultStageWithSuccess(sF.id, 0) + assert(scheduler.runningStages === Set(sB, sC, sD, sF)) // F is still running. + assert(depStageMap(sF).delayedTaskCompletionEvents.size === 5) + + // Complete stage B, it will complete normally. + completeShuffleMapStageSuccessfully(sB.id, 0, 1) + // This will trigger completion of C as well since both its parents A & B are done. + // This will also trigger completion of D as its parent B is done. + // Which finally completes F as well. + assert(scheduler.runningStages === Set()) // All stages are complete. + assert(depStageMap.isEmpty) // No more dependent stages. + + assertDataStructuresEmpty() + } + + test("Should fail if speculative execution is enabled") { + // Try to a run job with two stages with speculative execution.
It should fail the job with + // exception. + + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + val properties = new Properties(testProperties) + properties.setProperty(SPECULATION_ENABLED.key, "true") + + val jobListener = new TestJobListener() + submit(resultStage, Array(0), properties = properties, listener = jobListener) + + assert(jobListener.expectFailure().getMessage.contains( + "Speculative execution is not supported with concurrent stages")) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index e12348e1be2d7..d09bf617df396 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -180,7 +180,8 @@ class DummyScheduledFuture( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with TimeLimits { +abstract class DAGSchedulerSuiteBase extends SparkFunSuite with TempLocalSparkContext + with TimeLimits { import DAGSchedulerSuite._ @@ -368,18 +369,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } - class MyDAGScheduler( - sc: SparkContext, - taskScheduler: TaskScheduler, - listenerBus: LiveListenerBus, - mapOutputTracker: MapOutputTrackerMaster, - blockManagerMaster: BlockManagerMaster, - env: SparkEnv, - clock: Clock = new SystemClock(), - shuffleMergeFinalize: Boolean = true, - shuffleMergeRegister: Boolean = true - ) extends DAGScheduler( - sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, clock) { + trait TestDAGScheduler extends DAGScheduler { + + protected def getShuffleMergeFinalize: Boolean = true + protected def getShuffleMergeRegister: Boolean = true + 
/** * Schedules shuffle merge finalize. */ @@ -387,7 +381,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti shuffleMapStage: ShuffleMapStage, delay: Long, registerMergeResults: Boolean = true): Unit = { - if (shuffleMergeRegister && registerMergeResults) { + if (getShuffleMergeRegister && registerMergeResults) { for (part <- 0 until shuffleMapStage.shuffleDep.partitioner.numPartitions) { val mergeStatuses = Seq((part, makeMergeStatus("", shuffleMapStage.shuffleDep.shuffleMergeId))) @@ -403,7 +397,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti shuffleMapStage.shuffleDep.setFinalizeTask( new DummyScheduledFuture(delay, registerMergeResults)) - if (shuffleMergeFinalize) { + if (getShuffleMergeFinalize) { handleShuffleMergeFinalized(shuffleMapStage, shuffleMapStage.shuffleDep.shuffleMergeId) } } @@ -417,6 +411,24 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } + class MyDAGScheduler( + sc: SparkContext, + taskScheduler: TaskScheduler, + listenerBus: LiveListenerBus, + mapOutputTracker: MapOutputTrackerMaster, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv, + clock: Clock = new SystemClock(), + shuffleMergeFinalize: Boolean = true, + shuffleMergeRegister: Boolean = true + ) extends DAGScheduler( + sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, clock) + with TestDAGScheduler { + + override def getShuffleMergeFinalize: Boolean = shuffleMergeFinalize + override def getShuffleMergeRegister: Boolean = shuffleMergeRegister + } + override def beforeEach(): Unit = { super.beforeEach() firstInit = true @@ -446,15 +458,19 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti new MyMapOutputTrackerMaster(sc.getConf, broadcastManager)) blockManagerMaster = spy[MyBlockManagerMaster](new MyBlockManagerMaster(sc.getConf)) doNothing().when(blockManagerMaster).updateRDDBlockVisibility(any(), any()) - 
scheduler = spy[MyDAGScheduler](new MyDAGScheduler( + scheduler = spy(createInitialScheduler(sc)) + + dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) + } + + protected def createInitialScheduler(sc: SparkContext): DAGScheduler = { + new MyDAGScheduler( sc, taskScheduler, sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env)) - - dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) + sc.env) } override def afterEach(): Unit = { @@ -527,7 +543,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } /** Submits a job to the scheduler and returns the job id. */ - private def submit( + protected def submit( rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, @@ -1168,7 +1184,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti * @param hostNames - Host on which each task in the task set is executed. In case no hostNames * are provided, the tasks will progressively complete on hostA, hostB, etc. 
*/ - private def completeShuffleMapStageSuccessfully( + protected def completeShuffleMapStageSuccessfully( stageId: Int, attemptIdx: Int, numShufflePartitions: Int, @@ -1219,7 +1235,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti * @param stageId - The current stageId * @param attemptIdx - The current attempt count */ - private def completeNextResultStageWithSuccess( + protected def completeNextResultStageWithSuccess( stageId: Int, attemptIdx: Int, partitionToResult: Int => Int = _ => 42): Unit = { @@ -6107,7 +6123,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } - private def assertDataStructuresEmpty(): Unit = { + protected def assertDataStructuresEmpty(): Unit = { assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) @@ -6150,6 +6166,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } +class DAGSchedulerSuite extends DAGSchedulerSuiteBase + class DAGSchedulerAbortStageOffSuite extends DAGSchedulerSuite { override def conf: SparkConf = super.conf.set(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS, false) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index c9c9e529405ec..cb1f574e201cb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -2752,4 +2752,32 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext assert(!tsm.isInstanceOf[StructuredStreamingIdAwareSchedulerLogging]) } + test("maxTaskFailures is overridden to 1 when concurrent stages are enabled") { + // Concurrent stage execution does not support task retries — a failure must restart the + // streaming query rather than be silently retried against a still-running shuffle. 
The + // scheduler enforces this by capping the TaskSetManager's maxFailures at 1, regardless of + // the cluster-wide spark.task.maxFailures. + val clusterMaxFailures = 4 + val taskScheduler = setupScheduler(config.TASK_MAX_FAILURES.key -> clusterMaxFailures.toString) + val props = new Properties() + props.setProperty( + ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY, "true") + val base = FakeTask.createTaskSet(numTasks = 1) + val taskSet = new TaskSet(base.tasks, base.stageId, base.stageAttemptId, base.priority, + props, base.resourceProfileId, base.shuffleId) + + taskScheduler.submitTasks(taskSet) + val tsm = taskScheduler.taskSetManagerForAttempt( + taskSet.stageId, taskSet.stageAttemptId).get + + assert(tsm.maxTaskFailures === 1) + + // Regression guard: a regular TaskSet should still get the cluster's maxTaskFailures. + val regular = FakeTask.createTaskSet(numTasks = 1, stageId = 1, stageAttemptId = 0) + taskScheduler.submitTasks(regular) + val regularTsm = taskScheduler.taskSetManagerForAttempt( + regular.stageId, regular.stageAttemptId).get + assert(regularTsm.maxTaskFailures === clusterMaxFailures) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 4fd710712221b..d07e52f3fc382 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -2921,6 +2921,58 @@ class TaskSetManagerSuite s"\nCaptured logs:\n${logs.mkString("\n")}") } + /** + * Wraps an existing TaskSet with a copy that has the given properties. Used by the + * concurrent-stages tests below since `FakeTask.createTaskSet` always sets properties to null. 
+ */ + private def withProperties(taskSet: TaskSet, properties: Properties): TaskSet = { + new TaskSet(taskSet.tasks, taskSet.stageId, taskSet.stageAttemptId, taskSet.priority, + properties, taskSet.resourceProfileId, taskSet.shuffleId) + } + + private def concurrentStagesProperties: Properties = { + val props = new Properties() + props.setProperty( + ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY, "true") + props + } + + test("ExecutorLostFailure counts as task failure when concurrent stages are enabled") { + // Otherwise (the default), an executor exit that wasn't "caused by the app" is exempt + // from the maxTaskFailures count. For real-time mode that exemption is wrong: an executor + // loss is a query failure, and the query should be restarted. + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = withProperties(FakeTask.createTaskSet(1), concurrentStagesProperties) + // Concurrent stage execution does not support task retries, so maxFailures = 1. 
+ val manager = new TaskSetManager(sched, taskSet, maxTaskFailures = 1) + + val offer = manager.resourceOffer("exec1", "host1", TaskLocality.ANY)._1 + assert(offer.isDefined) + manager.handleFailedTask(offer.get.taskId, TaskState.FAILED, + ExecutorLostFailure("exec1", exitCausedByApp = false, reason = None)) + + assert(sched.taskSetsFailed.contains(taskSet.id), + "TaskSet should have been aborted after the single allowed failure") + assert(manager.isZombie) + } + + test("ExecutorLostFailure is not counted without concurrent stages (regression guard)") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(1) + val manager = new TaskSetManager(sched, taskSet, maxTaskFailures = 1) + + val offer = manager.resourceOffer("exec1", "host1", TaskLocality.ANY)._1 + assert(offer.isDefined) + manager.handleFailedTask(offer.get.taskId, TaskState.FAILED, + ExecutorLostFailure("exec1", exitCausedByApp = false, reason = None)) + + assert(!sched.taskSetsFailed.contains(taskSet.id), + "Executor loss not caused by app should not count toward task failures by default") + assert(!manager.isZombie) + } + } class FakeLongTasks(stageId: Int, partitionId: Int) extends FakeTask(stageId, partitionId) { From abfeb683732f67dedb0e8c7462320d1a3bd3d503 Mon Sep 17 00:00:00 2001 From: Jerry Peng Date: Thu, 28 May 2026 06:28:45 +0000 Subject: [PATCH 2/5] Address review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes: - DAGScheduler.submitConcurrentStage: change `new IllegalStateException(...)` to `throw new IllegalStateException(...)` so the unexpected-state branch actually fails instead of silently being a no-op. 
- ConcurrentStageDAGScheduler.onFinalStageCreated: walk the DAG into a local `visitedStages` set and only commit to `concurrentStages` after the slot check passes, so a slot-check failure can't leak stage references into the long-lived scheduler state. - ConcurrentStageDAGScheduler.markStageAsFinished: unconditionally drop the stage's own entry from `dependentStageMap` at the end. On the success path the entry has already been removed by `checkDependentStageTasks`, so this is a no-op; on failure/cancellation/abort it's the missing cleanup that previously required the parent stage to be marked finished (which doesn't always happen if the parent is shared with another job). - ConcurrentStageDAGScheduler.onFinalStageCreated: speculation check also reads `sc.conf.get(SPECULATION_ENABLED)`, matching how the rest of core reads the config; users with cluster-wide spark.speculation=true were previously bypassing this guard. API cleanup: - Move `submitConcurrentStage` into ConcurrentStageDAGScheduler as a private method. Remove `postSchedulerEvent` entirely (callers now use `eventProcessLoop.post(event)` directly since it's already `private[spark]`). Relax `submitMissingTasks` and `activeJobForStage` to `protected` so the subclass can call them. - Reuse `StructuredStreamingIdAwareSchedulerLogging.QUERY_ID_KEY` and `BATCH_ID_KEY` constants instead of hardcoded strings; drop the unused `runId` field from `StreamingBatchId` (CrossJobDepDAGScheduler — which consumes it — is not part of this PR). Test scaffolding: - Add `protected def extraEmptyChecks(): Unit = ()` hook to `assertDataStructuresEmpty` in DAGSchedulerSuiteBase; override in ConcurrentStageDAGSchedulerSuite to assert `concurrentStages` and `dependentStageMap` are empty. - Also call `extraEmptyChecks()` in `afterEach`, so every inherited test (and every locally-defined test) automatically validates that the new state hasn't leaked. 
Pattern-match on the scheduler type to skip the check when an inherited test replaces the scheduler with a plain MyDAGScheduler. - Relax `failed` and `cancel` to `protected` in DAGSchedulerSuiteBase so subclass suites can use them. New tests (in ConcurrentStageDAGSchedulerSuite): - `concurrentStages is empty after slot-check failure` — exercises the visited-set commit pattern. - `dependentStageMap entry is cleaned up when a dependent stage aborts and its parent stage is shared with another job` — sets up a shared shuffle stage between a batch and a concurrent job; fails the concurrent job's leaf and verifies the cleanup runs even though the parent isn't marked finished. - `concurrentStages and dependentStageMap are cleaned up after job cancellation` — covers the JobCancelled event path. - `concurrentStages and dependentStageMap are cleaned up after executor-loss induced abort` — covers the maxFailures=1-abort path. - Speculation test split into per-job-property and cluster-wide-SparkConf variants; both verified to fail the job. Typos and wording: - Comment "states" → "stages" in ConcurrentStageDAGScheduler. - "has only has" → "has only" in the CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT error message. - "contribute the task failures" → "count toward the task failures" in TaskSetManager. - Test comment "4 tasks in stage C" → "4 tasks in stage D" in the complex- pipeline test. 
Co-authored-by: Isaac --- .../resources/error/error-conditions.json | 2 +- .../ConcurrentStageDAGScheduler.scala | 63 +++++-- .../apache/spark/scheduler/DAGScheduler.scala | 25 +-- .../spark/scheduler/TaskSetManager.scala | 2 +- .../ConcurrentStageDAGSchedulerSuite.scala | 175 +++++++++++++++++- .../spark/scheduler/DAGSchedulerSuite.scala | 11 +- 6 files changed, 230 insertions(+), 48 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 1f25a85266622..ef0e540a7b430 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -892,7 +892,7 @@ }, "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT" : { "message" : [ - "The minimum number of free slots required in the cluster is , however, the cluster has only has slots free. Query will stall or fail. Increase cluster size to proceed." + "The minimum number of free slots required in the cluster is , however, the cluster has only slots free. Query will stall or fail. Increase cluster size to proceed." ], "sqlState" : "53000" }, diff --git a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala index a2591821b39b5..1bd884f29ec15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala @@ -60,7 +60,7 @@ class ConcurrentStageDAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) - // This contains all the concurrent states that are yet to be scheduled across all the jobs. + // This contains all the concurrent stages that are yet to be scheduled across all the jobs. 
private[spark] val concurrentStages = new mutable.HashSet[Stage] private[scheduler] case class DependentStageInfo( @@ -91,8 +91,11 @@ class ConcurrentStageDAGScheduler( val queryBatchId = getStreamingBatchIdFromProperties(properties) if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) { - if (properties.getProperty(SPECULATION_ENABLED.key) == "true") { - // Speculation is not supported with concurrent stages. + // Speculation is not supported with concurrent stages. Check both the per-job local + // property (for jobs that override the cluster default via setLocalProperty) and the + // SparkConf (the documented way to enable speculation cluster-wide). + if (properties.getProperty(SPECULATION_ENABLED.key) == "true" || + sc.conf.get(SPECULATION_ENABLED)) { throw new SparkException( "Speculative execution is not supported with concurrent stages " + s"(streaming query: $queryBatchId). Please disable ${SPECULATION_ENABLED.key} config." @@ -102,14 +105,17 @@ class ConcurrentStageDAGScheduler( logInfo(log"Concurrent stages is enabled for [query ${MDC(LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID, queryBatchId.get.batchId)}]") - // Mark current stage and all its ancestors as concurrent + // Mark current stage and all its ancestors as concurrent. + // Collect into a local set first so a slot-check failure below does not leak partial + // state into concurrentStages. 
+ val visitedStages = new mutable.HashSet[Stage] var totalCoresNeeded = 0 def visit(stage: Stage): Unit = { - if (!concurrentStages.contains(stage)) { + if (!visitedStages.contains(stage)) { logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent for [query ${MDC( LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC( LogKeys.BATCH_ID, queryBatchId.get.batchId)}]") - concurrentStages += stage + visitedStages += stage totalCoresNeeded += totalNumCoreForStage(stage) stage.parents.foreach(visit) } @@ -132,6 +138,9 @@ class ConcurrentStageDAGScheduler( logWarning(log"${MDC(LogKeys.ERROR, e)}. Skipping slot check for RTM.") } } + + // Slot check passed (or was disabled) — commit the visited stages. + concurrentStages ++= visitedStages } else { super.onFinalStageCreated(finalStage, properties) } @@ -176,6 +185,21 @@ class ConcurrentStageDAGScheduler( } } + /** + * Submits a child stage even while its parents are still running. Distinct from + * `submitStage` in that it bypasses the missing-parent check. + */ + private def submitConcurrentStage(stage: Stage): Unit = { + assert(waitingStages.contains(stage)) + activeJobForStage(stage) match { + case Some(job) => + waitingStages -= stage + submitMissingTasks(stage, job) + case None => // Not expected. + throw new IllegalStateException(s"No active job for stage $stage") + } + } + // This is overridden to check if the task completion event should be delayed a parent stage // till has running tasks. See comment for `dependentStageMap` for more details. override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = { @@ -219,6 +243,13 @@ class ConcurrentStageDAGScheduler( dependentStageMap(dependent).parents -= stage checkDependentStageTasks(dependent) } + + // Drop this stage's own entry from the map. On the success path + // `checkDependentStageTasks` (invoked when the stage's last parent finishes) has already + // removed the entry, so this is a no-op. 
On failure / cancellation / abort the entry — + // and any buffered completion events — would otherwise leak for the lifetime of the + // scheduler. + dependentStageMap.remove(stage) } // Checks if the dependent stage's parents are all done. If all the parents are done, @@ -237,7 +268,7 @@ class ConcurrentStageDAGScheduler( delayedEvents.foreach { event => logInfo(log"Posting delayed task ${MDC(LogKeys.TASK_ID, event.taskInfo.taskId)} " + log"completion event for stage ${MDC(LogKeys.STAGE, stage)}") - postSchedulerEvent(event) + eventProcessLoop.post(event) } } } @@ -253,19 +284,20 @@ object ConcurrentStageDAGScheduler { } /** - * Extracts the [[StreamingBatchId]] from the given properties if all three of the streaming - * query id, run id and batch id are present. + * Extracts the [[StreamingBatchId]] from the given properties if both the streaming + * query id and batch id are present. */ def getStreamingBatchIdFromProperties(properties: Properties): Option[StreamingBatchId] = { if (properties == null) { return None } - val queryId = Option(properties.getProperty("sql.streaming.queryId")) - val runId = Option(properties.getProperty("sql.streaming.runId")) - val batchId = Option(properties.getProperty("streaming.sql.batchId")) - if (queryId.nonEmpty && runId.nonEmpty && batchId.nonEmpty) { - Some(StreamingBatchId(queryId.get, runId.get, batchId.get.toLong)) + val queryId = Option(properties.getProperty( + StructuredStreamingIdAwareSchedulerLogging.QUERY_ID_KEY)) + val batchId = Option(properties.getProperty( + StructuredStreamingIdAwareSchedulerLogging.BATCH_ID_KEY)) + if (queryId.nonEmpty && batchId.nonEmpty) { + Some(StreamingBatchId(queryId.get, batchId.get.toLong)) } else { None } @@ -276,7 +308,6 @@ object ConcurrentStageDAGScheduler { * Case class to identify a batch in a streaming query. 
* * @param queryId - Streaming query id - * @param runId - Streaming query run id * @param batchId - Batch id for a micro batch in a streaming query */ -case class StreamingBatchId(queryId: String, runId: String, batchId: Long) +case class StreamingBatchId(queryId: String, batchId: Long) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f9e0c58dc4919..2973a7639be69 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1285,7 +1285,7 @@ private[spark] class DAGScheduler( // stage the one with the highest priority (highest-priority pool, earliest created). // That should take care of at least part of the priority inversion problem with // cross-job dependencies. - private def activeJobForStage(stage: Stage): Option[Int] = { + protected def activeJobForStage(stage: Stage): Option[Int] = { val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted jobsThatUseStage.find(jobIdToActiveJob.contains) } @@ -1581,27 +1581,6 @@ private[spark] class DAGScheduler( } } - /** - * An experimental API to submit child stages even while the parents are running. This is only - * used in [[ConcurrentStageDAGScheduler]]. It defined here since it depends two private APIs in - * this class (namely submitMissingTasks() and activeJobForStage()). - */ - protected def submitConcurrentStage(stage: Stage): Unit = { - assert(waitingStages.contains(stage)) - activeJobForStage(stage) match { - case Some(job) => - waitingStages -= stage - submitMissingTasks(stage, job) - case None => // Not expected. - new IllegalStateException(s"No active job for stage $stage") - } - } - - protected def postSchedulerEvent(event: DAGSchedulerEvent): Unit = { - // Currently only used in [[ConcurrentStageDAGScheduler]]. 
- eventProcessLoop.post(event) - } - /** * `PythonRunner` needs to know what the pyspark memory and cores settings are for the profile * being run. Pass them in the local properties of the task if it's set for the stage profile. @@ -1665,7 +1644,7 @@ private[spark] class DAGScheduler( } /** Called when stage's parents are available and we can now do its task. */ - private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { + protected def submitMissingTasks(stage: Stage, jobId: Int): Unit = { logDebug("submitMissingTasks(" + stage + ")") // For statically indeterminate stages being retried, we trigger rollback BEFORE task diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 5847480979f59..52c37153ba84d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -1077,7 +1077,7 @@ private[spark] class TaskSetManager( accumUpdates, metricPeaks) val countTowardsTaskFailures = reason.countTowardsTaskFailures || - // if the query is running in real time mode, any failures should contribute the task failures + // if the query is running in real time mode, any failures should count toward the task failures // so that the query can restart. 
ConcurrentStageDAGScheduler.isConcurrentStagesEnabled(taskSet.properties) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala index 1efd73d86679a..766a68a649528 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala @@ -48,6 +48,33 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { new TestConcurrentStageDAGScheduler(sc) } + /** + * Asserts that the concurrent scheduler's internal state — `concurrentStages` and + * `dependentStageMap` — is empty. Called from `assertDataStructuresEmpty` and at the end of + * every test via `afterEach`, so every inherited test (and every locally-defined test) gets + * free regression coverage against entries leaking into these maps. + */ + override protected def extraEmptyChecks(): Unit = { + // Inherited tests sometimes replace `scheduler` with a plain MyDAGScheduler (bypassing + // createInitialScheduler), so pattern-match rather than cast. + scheduler match { + case s: TestConcurrentStageDAGScheduler => + assert(s.concurrentStages.isEmpty, + s"concurrentStages should be empty but contains: ${s.concurrentStages}") + assert(s.dependentStageMap.isEmpty, + s"dependentStageMap should be empty but contains: ${s.dependentStageMap}") + case _ => // Not a concurrent scheduler — nothing extra to assert. + } + } + + override def afterEach(): Unit = { + try { + extraEmptyChecks() + } finally { + super.afterEach() + } + } + // Catch the job failure exception with a listener. 
private class TestJobListener extends JobListener { private var failureException: Option[Exception] = None @@ -69,7 +96,6 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { private val testProperties: Properties = { val properties = new Properties() properties.setProperty("sql.streaming.queryId", "test_query_id") - properties.setProperty("sql.streaming.runId", "test_run_id") properties.setProperty("streaming.sql.batchId", "5") properties.setProperty(ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY, "true") new Properties(properties) { @@ -224,7 +250,7 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { // Now complete stage D. This is similar to completing C. The tasks are enqueued. assert(depStageMap(sD).delayedTaskCompletionEvents.size === 0) completeShuffleMapStageSuccessfully(sD.id, 0, 5) // Complete stage D - assert(depStageMap(sD).delayedTaskCompletionEvents.size === 4) // 4 tasks in stage C. + assert(depStageMap(sD).delayedTaskCompletionEvents.size === 4) // 4 tasks in stage D. // Complete stage E. This is a root node and should complete normally. assert(depStageMap(sF).parents.contains(sE)) // E is one of the stages that F waits for. @@ -260,9 +286,127 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { assertDataStructuresEmpty() } - test("Should fail if speculative execution is enabled") { - // Try to a run job with two stages with speculative execution. It should fail the job with - // exception. + test("dependentStageMap entry is cleaned up when a dependent stage aborts and its " + + "parent stage is shared with another job") { + // This exercises the cleanup path in markStageAsFinished. When the dependent stage + // (here, B) aborts before its parent (A) finishes, the cascade through + // failJobAndIndependentStages only marks stages that are *independent* to the failing job + // as finished — shared stages are left alone. 
Without the explicit + // `dependentStageMap.remove(stage)` at the end of markStageAsFinished, B's entry would + // leak in dependentStageMap until A eventually finished for the other job. + + // Job 1 (regular batch): rddC depends on rddA via shuffleDepA. + val rddA = new MyRDD(sc, 1, Nil).setName("rddA") + val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1)) + val rddC = new MyRDD(sc, 1, List(shuffleDepA)).setName("rddC") + submit(rddC, Array(0)) // properties = null implies non-concurrent. + + // After job 1: rddA's stage is running, rddC's stage is waiting. + assert(scheduler.runningStages.exists(_.rdd.name == "rddA"), + "rddA's stage should be running after job 1 submission") + + // Job 2 (concurrent): rddB also depends on the same shuffleDepA → rddA's stage is shared. + val rddB = new MyRDD(sc, 3, List(shuffleDepA)).setName("rddB") + submit(rddB, Array(0), properties = testProperties) + + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + val depStageMap = concurrentScheduler.dependentStageMap + // B is in dependentStageMap with rddA's stage as a parent. + assert(depStageMap.keys.map(_.rdd.name) === Set("rddB")) + assert(depStageMap.values.flatMap(_.parents.map(_.rdd.name)).toSet === Set("rddA")) + + // Fail rddB's taskset. abortStage's failJobAndIndependentStages only marks rddB + // finished — rddA is shared with job 1, so it is NOT cancelled. + val taskSetB = taskSets.find(_.tasks.head.stageId == + scheduler.runningStages.find(_.rdd.name == "rddB").get.id).get + failed(taskSetB, "test failure: rddB aborted before parent rddA finished") + + // rddA is still running for job 1. 
+ assert(scheduler.runningStages.exists(_.rdd.name == "rddA"), + "Shared parent rddA's stage should still be running for job 1 after job 2 aborted") + + // rddB's entry must have been removed by markStageAsFinished's cleanup, even though + // rddA is still running (and would normally be the stage whose markStageAsFinished + // cleans up B's entry). + assert(depStageMap.isEmpty, + s"dependentStageMap should be empty after rddB aborted, but contains: $depStageMap") + } + + test("concurrentStages is empty after slot-check failure") { + // The DAG walk in onFinalStageCreated accumulates stages into a local set and only commits + // them to `concurrentStages` after the slot check passes. This test verifies that a slot- + // check failure leaves `concurrentStages` empty (rather than leaking the visited stages). + sc.conf.set(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED, false) + try { + // local[2] gives us 2 slots; a 4-partition job exceeds that. + val rdd = new MyRDD(sc, 4, Nil) + + val jobListener = new TestJobListener() + submit(rdd, Array(0, 1, 2, 3), properties = testProperties, listener = jobListener) + + assert(jobListener.expectFailure().getMessage.contains( + "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT")) + + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + assert(concurrentScheduler.concurrentStages.isEmpty, + s"concurrentStages should be empty after slot-check failure, but contains: " + + s"${concurrentScheduler.concurrentStages}") + } finally { + sc.conf.set(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED, true) + } + } + + test("concurrentStages and dependentStageMap are cleaned up after job cancellation") { + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + val jobId = submit(resultStage, Array(0), properties = testProperties) + + val concurrentScheduler = 
scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + // Both stages running concurrently and dependentStageMap is populated. + assert(scheduler.runningStages.map(_.id) === Set(0, 1)) + assert(concurrentScheduler.dependentStageMap.nonEmpty) + + // Cancel the job mid-execution. handleJobCancellation marks all stages of the cancelled + // job finished via markStageAsFinished, which runs our cleanup. + cancel(jobId) + + assert(concurrentScheduler.concurrentStages.isEmpty, + s"concurrentStages should be empty after cancellation, " + + s"but contains: ${concurrentScheduler.concurrentStages}") + assert(concurrentScheduler.dependentStageMap.isEmpty, + s"dependentStageMap should be empty after cancellation, " + + s"but contains: ${concurrentScheduler.dependentStageMap}") + } + + test("concurrentStages and dependentStageMap are cleaned up after executor-loss " + + "induced abort") { + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + submit(resultStage, Array(0), properties = testProperties) + + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + assert(concurrentScheduler.dependentStageMap.nonEmpty) + + // In real-time mode TaskSchedulerImpl caps maxFailures at 1 and TaskSetManager counts + // ExecutorLostFailure toward task failures, so a single executor loss aborts the + // TaskSet immediately. Simulate the resulting TaskSetFailed event. 
+ failed(taskSets(1), "executor lost: simulated for test") + + assert(concurrentScheduler.concurrentStages.isEmpty, + s"concurrentStages should be empty after executor-loss abort, " + + s"but contains: ${concurrentScheduler.concurrentStages}") + assert(concurrentScheduler.dependentStageMap.isEmpty, + s"dependentStageMap should be empty after executor-loss abort, " + + s"but contains: ${concurrentScheduler.dependentStageMap}") + } + + test("Should fail if speculative execution is enabled (per-job property)") { + // Try to run a job with two stages with speculative execution as a per-job local + // property. It should fail the job with exception. val mapStage = new MyRDD(sc, 1, Nil) // stage_0 val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) @@ -277,4 +421,25 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { assert(jobListener.expectFailure().getMessage.contains( "Speculative execution is not supported with concurrent stages")) } + + test("Should fail if speculative execution is enabled (cluster-wide SparkConf)") { + // Same as the previous test, but speculation is set on SparkConf (the documented way to + // enable speculation) instead of the per-job local property. Every other consumer of + // SPECULATION_ENABLED reads it via sc.conf, so this is the common case.
+ + sc.conf.set(SPECULATION_ENABLED, true) + try { + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + val jobListener = new TestJobListener() + submit(resultStage, Array(0), properties = testProperties, listener = jobListener) + + assert(jobListener.expectFailure().getMessage.contains( + "Speculative execution is not supported with concurrent stages")) + } finally { + sc.conf.set(SPECULATION_ENABLED, false) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d09bf617df396..bc0f08e20f8fb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -567,12 +567,12 @@ abstract class DAGSchedulerSuiteBase extends SparkFunSuite with TempLocalSparkCo } /** Sends TaskSetFailed to the scheduler. */ - private def failed(taskSet: TaskSet, message: String): Unit = { + protected def failed(taskSet: TaskSet, message: String): Unit = { runEvent(TaskSetFailed(taskSet, message, None)) } /** Sends JobCancelled to the DAG scheduler. */ - private def cancel(jobId: Int): Unit = { + protected def cancel(jobId: Int): Unit = { runEvent(JobCancelled(jobId, None)) } @@ -6133,8 +6133,15 @@ abstract class DAGSchedulerSuiteBase extends SparkFunSuite with TempLocalSparkCo assert(scheduler.shuffleIdToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) assert(scheduler.outputCommitCoordinator.isEmpty) + extraEmptyChecks() } + /** + * Hook for subclasses to extend the empty-state assertions with their own state checks. + * Default is a no-op. 
+ */ + protected def extraEmptyChecks(): Unit = () + // Nothing in this test should break if the task info's fields are null, but // OutputCommitCoordinator requires the task info itself to not be null. private def createFakeTaskInfo(): TaskInfo = { From c81425ac47446e3a3ff6b7075bba8e650bdf88cd Mon Sep 17 00:00:00 2001 From: Jerry Peng Date: Thu, 28 May 2026 07:35:40 +0000 Subject: [PATCH 3/5] improving comments --- .../spark/scheduler/ConcurrentStageDAGScheduler.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala index 1bd884f29ec15..81840a5682866 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala @@ -200,8 +200,8 @@ class ConcurrentStageDAGScheduler( } } - // This is overridden to check if the task completion event should be delayed a parent stage - // till has running tasks. See comment for `dependentStageMap` for more details. + // This is overridden to check if the task completion event should be delayed because a + // parent stage still has running tasks. See comment for `dependentStageMap` for more details. override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = { val stageId = event.task.stageId val taskId = event.taskInfo.taskId @@ -249,6 +249,13 @@ class ConcurrentStageDAGScheduler( // removed the entry, so this is a no-op. On failure / cancellation / abort the entry — // and any buffered completion events — would otherwise leak for the lifetime of the // scheduler. + // + // `willRetry=true` paths (e.g. FetchFailed) also reach this cleanup. 
That is safe under + // concurrent scheduling because stage retries are not supported here: TaskSchedulerImpl + // pins `maxFailures=1` for concurrent TaskSets, and any failure restarts the streaming + // query from its checkpoint rather than retrying tasks against an in-flight streaming + // shuffle. With no retry to preserve state for, it's correct to drop the entry along + // with any buffered events. dependentStageMap.remove(stage) } From 3a19f4b36cdf751b7c87a338834233cbdeb101b2 Mon Sep 17 00:00:00 2001 From: Jerry Peng Date: Thu, 28 May 2026 22:39:56 +0000 Subject: [PATCH 4/5] fixing linter --- .../scheduler/ConcurrentStageDAGScheduler.scala | 6 +++--- .../org/apache/spark/scheduler/TaskSetManager.scala | 4 ++-- .../scheduler/ConcurrentStageDAGSchedulerSuite.scala | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala index 81840a5682866..9c6b5350758d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala @@ -139,7 +139,7 @@ class ConcurrentStageDAGScheduler( } } - // Slot check passed (or was disabled) — commit the visited stages. + // Slot check passed (or was disabled). Commit the visited stages. concurrentStages ++= visitedStages } else { super.onFinalStageCreated(finalStage, properties) @@ -246,8 +246,8 @@ class ConcurrentStageDAGScheduler( // Drop this stage's own entry from the map. On the success path // `checkDependentStageTasks` (invoked when the stage's last parent finishes) has already - // removed the entry, so this is a no-op. On failure / cancellation / abort the entry — - // and any buffered completion events — would otherwise leak for the lifetime of the + // removed the entry, so this is a no-op. 
On failure / cancellation / abort the entry, + // and any buffered completion events, would otherwise leak for the lifetime of the // scheduler. // // `willRetry=true` paths (e.g. FetchFailed) also reach this cleanup. That is safe under diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 52c37153ba84d..dfc25c03f0842 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -1077,8 +1077,8 @@ private[spark] class TaskSetManager( accumUpdates, metricPeaks) val countTowardsTaskFailures = reason.countTowardsTaskFailures || - // if the query is running in real time mode, any failures should count toward the task failures - // so that the query can restart. + // in real-time mode, any failure should count toward the task failures so that the + // query can restart. ConcurrentStageDAGScheduler.isConcurrentStagesEnabled(taskSet.properties) if (!isZombie && countTowardsTaskFailures) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala index 766a68a649528..e3b6b8ae8f7b2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala @@ -49,8 +49,8 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { } /** - * Asserts that the concurrent scheduler's internal state — `concurrentStages` and - * `dependentStageMap` — is empty. Called from `assertDataStructuresEmpty` and at the end of + * Asserts that the concurrent scheduler's internal state - `concurrentStages` and + * `dependentStageMap` - is empty. 
Called from `assertDataStructuresEmpty` and at the end of * every test via `afterEach`, so every inherited test (and every locally-defined test) gets * free regression coverage against entries leaking into these maps. */ @@ -63,7 +63,7 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { s"concurrentStages should be empty but contains: ${s.concurrentStages}") assert(s.dependentStageMap.isEmpty, s"dependentStageMap should be empty but contains: ${s.dependentStageMap}") - case _ => // Not a concurrent scheduler — nothing extra to assert. + case _ => // Not a concurrent scheduler - nothing extra to assert. } } @@ -291,7 +291,7 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { // This exercises the cleanup path in markStageAsFinished. When the dependent stage // (here, B) aborts before its parent (A) finishes, the cascade through // failJobAndIndependentStages only marks stages that are *independent* to the failing job - // as finished — shared stages are left alone. Without the explicit + // as finished - shared stages are left alone. Without the explicit // `dependentStageMap.remove(stage)` at the end of markStageAsFinished, B's entry would // leak in dependentStageMap until A eventually finished for the other job. @@ -305,7 +305,7 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { assert(scheduler.runningStages.exists(_.rdd.name == "rddA"), "rddA's stage should be running after job 1 submission") - // Job 2 (concurrent): rddB also depends on the same shuffleDepA → rddA's stage is shared. + // Job 2 (concurrent): rddB also depends on the same shuffleDepA, so rddA's stage is shared. val rddB = new MyRDD(sc, 3, List(shuffleDepA)).setName("rddB") submit(rddB, Array(0), properties = testProperties) @@ -316,7 +316,7 @@ class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { assert(depStageMap.values.flatMap(_.parents.map(_.rdd.name)).toSet === Set("rddA")) // Fail rddB's taskset.
abortStage's failJobAndIndependentStages only marks rddB - // finished — rddA is shared with job 1, so it is NOT cancelled. + // finished - rddA is shared with job 1, so it is NOT cancelled. val taskSetB = taskSets.find(_.tasks.head.stageId == scheduler.runningStages.find(_.rdd.name == "rddB").get.id).get failed(taskSetB, "test failure: rddB aborted before parent rddA finished") From 67b57cc196bafc4b26c60d9d275b1c24044d6bc0 Mon Sep 17 00:00:00 2001 From: Jerry Peng Date: Tue, 2 Jun 2026 01:41:41 +0000 Subject: [PATCH 5/5] address comment --- .../apache/spark/scheduler/ConcurrentStageDAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala index 9c6b5350758d1..c3c24c2014238 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala @@ -263,7 +263,7 @@ class ConcurrentStageDAGScheduler( // enqueues any saved task completion event (if any). private def checkDependentStageTasks(stage: Stage): Unit = { val dependentStageInfo = dependentStageMap.getOrElse( - stage, throw new RuntimeException(s"Stage $stage is not in dependentStageMap") + stage, throw new IllegalStateException(s"Stage $stage is not in dependentStageMap") ) if (dependentStageInfo.parents.isEmpty) {