diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index e92ef6f462a3f..fc3777d1a93fd 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -577,6 +577,7 @@ public enum LogKeys implements LogKey { OUTPUT_BUFFER, OVERHEAD_MEMORY_SIZE, PAGE_SIZE, + PARENT_STAGE, PARENT_STAGES, PARSE_MODE, PARTITIONED_FILE_READER, @@ -792,6 +793,7 @@ public enum LogKeys implements LogKey { STREAMING_DATA_SOURCE_NAME, STREAMING_OFFSETS_END, STREAMING_OFFSETS_START, + STREAMING_QUERY_ID, STREAMING_QUERY_PROGRESS, STREAMING_SOURCE, STREAMING_TABLE, diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index f1e162a6260f7..ef0e540a7b430 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -890,6 +890,12 @@ ], "sqlState" : "0A000" }, + "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT" : { + "message" : [ + "The minimum number of free slots required in the cluster is <numTasks>, however, the cluster has only <numSlots> slots free. Query will stall or fail. Increase cluster size to proceed." + ], + "sqlState" : "53000" + }, "CONCURRENT_STREAM_LOG_UPDATE" : { "message" : [ "Concurrent update to the log. 
Multiple streaming jobs detected for <batchId>.", diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0262144490ce8..6cd9d3895e9ac 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -600,7 +600,11 @@ class SparkContext(config: SparkConf) extends Logging { val (sched, ts) = SparkContext.createTaskScheduler(this, master) _schedulerBackend = sched _taskScheduler = ts - _dagScheduler = new DAGScheduler(this) + _dagScheduler = conf.get(DAG_SCHEDULER_TYPE) match { + case "ConcurrentStageDAGScheduler" => + new ConcurrentStageDAGScheduler(this) + case _ => new DAGScheduler(this) + } _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) if (_conf.get(EXECUTOR_ALLOW_SYNC_LOG_LEVEL)) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 86e5422a85515..0ea0b1a46f691 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2396,6 +2396,26 @@ package object config { .booleanConf .createWithDefault(true) + private[spark] val STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED = + ConfigBuilder("spark.scheduler.realtimeModeSlotsCheck.disabled") + .internal() + .doc("For query running in real time mode, disable the check if the number of slots" + + " required by all concurrent stages is available before submit the query" ) + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .version("4.2.0") + .booleanConf + .createWithDefault(false) + + private[spark] val DAG_SCHEDULER_TYPE = + ConfigBuilder("spark.scheduler.dagSchedulerType") + .internal() + .doc("The DAGScheduler implementation to use. 
Set to 'ConcurrentStageDAGScheduler' to " + + "enable real-time mode, which runs stages concurrently for low-latency streaming queries.") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .version("4.2.0") + .stringConf + .createWithDefault("DAGScheduler") + private[spark] val STREAMING_ID_AWARE_SCHEDULER_LOGGING_QUERY_ID_LENGTH = ConfigBuilder("spark.scheduler.streaming.idAwareLogging.queryIdLength") .doc("Maximum number of characters of the streaming query ID to include " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala new file mode 100644 index 0000000000000..c3c24c2014238 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.util.Properties + +import scala.collection.mutable + +import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv, SparkException, SparkRuntimeException, Success} +import org.apache.spark.internal.LogKeys +import org.apache.spark.internal.config.{SPECULATION_ENABLED, STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED} +import org.apache.spark.resource.ResourceProfile +import org.apache.spark.storage.BlockManagerMaster +import org.apache.spark.util.Clock +import org.apache.spark.util.SystemClock + +/** + * A [[DAGScheduler]] that runs all the stages in a job without waiting for their parents + * to complete. This, combined with streaming shuffle between the stages, allows for low latency + * execution of streaming queries in real-time mode. + */ +class ConcurrentStageDAGScheduler( + sc: SparkContext, + taskScheduler: TaskScheduler, + listenerBus: LiveListenerBus, + mapOutputTracker: MapOutputTrackerMaster, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv, + clock: Clock = new SystemClock()) + extends DAGScheduler( + sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, clock) { + + import ConcurrentStageDAGScheduler._ + + def this(sc: SparkContext, taskScheduler: TaskScheduler) = { + this( + sc, + taskScheduler, + sc.listenerBus, + sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + sc.env.blockManager.master, + sc.env + ) + } + + def this(sc: SparkContext) = this(sc, sc.taskScheduler) + + // This contains all the concurrent stages that are yet to be scheduled across all the jobs. + private[spark] val concurrentStages = new mutable.HashSet[Stage] + + private[scheduler] case class DependentStageInfo( + parents: mutable.HashSet[Stage] = mutable.HashSet.empty, + delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] = mutable.ListBuffer.empty) + + // This map holds parents of concurrently scheduled stages. 
When tasks for such a stage complete, + // and if any of the parents are still running, we delay processing of such events until parent + // stages are complete. We save these events in this map until then. + private[spark] val dependentStageMap = new mutable.HashMap[Stage, DependentStageInfo] + + private def totalNumCoreForStage(stage: Stage): Int = { + val numTask = stage match { + case r: ResultStage => r.partitions.length + case m: ShuffleMapStage => m.numPartitions + } + val resourceProfile = sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId) + val taskCpus = ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf) + taskCpus * numTask + } + + /** + * Hook invoked after the final stage is created. Registers stages reachable from + * the final stage as concurrent so they can be submitted in parallel. + */ + override def onFinalStageCreated(finalStage: Stage, properties: Properties): Unit = { + + val queryBatchId = getStreamingBatchIdFromProperties(properties) + + if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) { + // Speculation is not supported with concurrent stages. Check both the per-job local + // property (for jobs that override the cluster default via setLocalProperty) and the + // SparkConf (the documented way to enable speculation cluster-wide). + if (properties.getProperty(SPECULATION_ENABLED.key) == "true" || + sc.conf.get(SPECULATION_ENABLED)) { + throw new SparkException( + "Speculative execution is not supported with concurrent stages " + + s"(streaming query: $queryBatchId). Please disable ${SPECULATION_ENABLED.key} config." + ) + } + + logInfo(log"Concurrent stages is enabled for [query ${MDC(LogKeys.STREAMING_QUERY_ID, + queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID, queryBatchId.get.batchId)}]") + + // Mark current stage and all its ancestors as concurrent. + // Collect into a local set first so a slot-check failure below does not leak partial + // state into concurrentStages. 
+ val visitedStages = new mutable.HashSet[Stage] + var totalCoresNeeded = 0 + def visit(stage: Stage): Unit = { + if (!visitedStages.contains(stage)) { + logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent for [query ${MDC( + LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC( + LogKeys.BATCH_ID, queryBatchId.get.batchId)}]") + visitedStages += stage + totalCoresNeeded += totalNumCoreForStage(stage) + stage.parents.foreach(visit) + } + } + visit(finalStage) + + if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) { + try { + val totalSlots = sc.schedulerBackend.defaultParallelism() + val coresInUse = runningStages.toArray.map(totalNumCoreForStage(_)).sum + if (totalSlots - coresInUse < totalCoresNeeded) { + throw new SparkRuntimeException( + errorClass = "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT", + messageParameters = Map( + "numSlots" -> (totalSlots - coresInUse).toString, + "numTasks" -> totalCoresNeeded.toString)) + } + } catch { + case e: UnsupportedOperationException => + logWarning(log"${MDC(LogKeys.ERROR, e)}. Skipping slot check for RTM.") + } + } + + // Slot check passed (or was disabled). Commit the visited stages. + concurrentStages ++= visitedStages + } else { + super.onFinalStageCreated(finalStage, properties) + } + } + + override def submitStage(stage: Stage): Unit = { + super.submitStage(stage) + + if (!waitingStages.contains(stage) && concurrentStages.contains(stage)) { + // The current stage is not registered in waitingStages, which means it has + // no parents. This case we should remove it from concurrentStages since it is already + // running. + assert(runningStages.contains(stage), "stage should be running if not in waitingStages") + logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from concurrentStages") + concurrentStages -= stage + } + + // Find the stages that should be submitted concurrently with this stage. 
+ waitingStages.intersect(concurrentStages).foreach { stage => + logInfo(log"Submitting stage concurrently: ${MDC(LogKeys.STAGE, stage)}") + concurrentStages -= stage // Don't submit this stage concurrently for subsequent attempts. + stage.parents.foreach { parent => + if (isRunningStage(parent)) { + logInfo(log"Updating dependent map for stage ${MDC(LogKeys.STAGE, stage)} with parent ${ + MDC(LogKeys.PARENT_STAGE, parent)}") + dependentStageMap.getOrElseUpdate(stage, DependentStageInfo()).parents += parent + } + } + // Remove stage and its parents from concurrentStages + def removeFromConcurrentStages(stage: Stage): Unit = { + if (concurrentStages.contains(stage)) { + logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from concurrentStages") + concurrentStages -= stage + } + stage.parents.foreach { parent => + assert(!waitingStages.contains(parent), "Parent stage should not still be waiting") + removeFromConcurrentStages(parent) + } + } + removeFromConcurrentStages(stage) + submitConcurrentStage(stage) + } + } + + /** + * Submits a child stage even while its parents are still running. Distinct from + * `submitStage` in that it bypasses the missing-parent check. + */ + private def submitConcurrentStage(stage: Stage): Unit = { + assert(waitingStages.contains(stage)) + activeJobForStage(stage) match { + case Some(job) => + waitingStages -= stage + submitMissingTasks(stage, job) + case None => // Not expected. + throw new IllegalStateException(s"No active job for stage $stage") + } + } + + // This is overridden to check if the task completion event should be delayed because a + // parent stage still has running tasks. See comment for `dependentStageMap` for more details. 
+ override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = { + val stageId = event.task.stageId + val taskId = event.taskInfo.taskId + + getStage(stageId) match { + case Some(stage) if event.reason == Success && dependentStageMap.contains(stage) => + val dependentStageInfo = dependentStageMap(stage) + logInfo(log"Delaying completion event for task ${MDC(LogKeys.TASK_ID, taskId)} in stage ${ + MDC(LogKeys.STAGE, stage)}. Active parent(s): ${MDC(LogKeys.PARENT_STAGES, + dependentStageInfo.parents.mkString(", "))}") + dependentStageInfo.delayedTaskCompletionEvents += event + + case _ => // Otherwise handle the event as usual. + super.handleTaskCompletion(event) + } + } + + // This is overridden to handle any delayed task completion events for dependent stages. + override def markStageAsFinished( + stage: Stage, + errorMessage: Option[String] = None, + willRetry: Boolean = false): Unit = { + + super.markStageAsFinished(stage, errorMessage, willRetry) + + // If this is a parent of a stage in dependentStageMap, remove it from parents. + val dependentStages = dependentStageMap + .filter(_._2.parents.contains(stage)) + .keys + + dependentStages.foreach { dependent => + if (errorMessage.isEmpty) { + assert( + isRunningStage(dependent), + s"Parent stages $stage's dependent stage $dependent should be running") + } + logInfo(log"Removing parent stage ${MDC(LogKeys.PARENT_STAGE, stage)} from dependent map " + + log"for stage ${MDC(LogKeys.STAGE, dependent)}") + dependentStageMap(dependent).parents -= stage + checkDependentStageTasks(dependent) + } + + // Drop this stage's own entry from the map. On the success path + // `checkDependentStageTasks` (invoked when the stage's last parent finishes) has already + // removed the entry, so this is a no-op. On failure / cancellation / abort the entry, + // and any buffered completion events, would otherwise leak for the lifetime of the + // scheduler. + // + // `willRetry=true` paths (e.g. 
FetchFailed) also reach this cleanup. That is safe under + // concurrent scheduling because stage retries are not supported here: TaskSchedulerImpl + // pins `maxFailures=1` for concurrent TaskSets, and any failure restarts the streaming + // query from its checkpoint rather than retrying tasks against an in-flight streaming + // shuffle. With no retry to preserve state for, it's correct to drop the entry along + // with any buffered events. + dependentStageMap.remove(stage) + } + + // Checks if the dependent stage's parents are all done. If all the parents are done, + // enqueues any saved task completion event (if any). + private def checkDependentStageTasks(stage: Stage): Unit = { + val dependentStageInfo = dependentStageMap.getOrElse( + stage, throw new IllegalStateException(s"Stage $stage is not in dependentStageMap") + ) + + if (dependentStageInfo.parents.isEmpty) { + val delayedEvents = dependentStageInfo.delayedTaskCompletionEvents + logInfo(log"All the parents are done for ${MDC(LogKeys.STAGE, stage)}. Removing it from " + + log"the map. It has ${MDC(LogKeys.NUM_EVENTS, delayedEvents.size.toLong)} " + + log"task completion events") + dependentStageMap -= stage + delayedEvents.foreach { event => + logInfo(log"Posting delayed task ${MDC(LogKeys.TASK_ID, event.taskInfo.taskId)} " + + log"completion event for stage ${MDC(LogKeys.STAGE, stage)}") + eventProcessLoop.post(event) + } + } + } +} + +object ConcurrentStageDAGScheduler { + + val CONCURRENT_STAGES_ENABLED_PROPERTY: String = "streaming.concurrent.stages.enabled" + + def isConcurrentStagesEnabled(properties: Properties): Boolean = { + properties != null && + properties.getProperty(CONCURRENT_STAGES_ENABLED_PROPERTY) == "true" + } + + /** + * Extracts the [[StreamingBatchId]] from the given properties if both the streaming + * query id and batch id are present. 
+ */ + def getStreamingBatchIdFromProperties(properties: Properties): Option[StreamingBatchId] = { + if (properties == null) { + return None + } + + val queryId = Option(properties.getProperty( + StructuredStreamingIdAwareSchedulerLogging.QUERY_ID_KEY)) + val batchId = Option(properties.getProperty( + StructuredStreamingIdAwareSchedulerLogging.BATCH_ID_KEY)) + if (queryId.nonEmpty && batchId.nonEmpty) { + Some(StreamingBatchId(queryId.get, batchId.get.toLong)) + } else { + None + } + } +} + +/** + * Case class to identify a batch in a streaming query. + * + * @param queryId - Streaming query id + * @param batchId - Batch id for a micro batch in a streaming query + */ +case class StreamingBatchId(queryId: String, batchId: Long) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 22720b98aafde..2973a7639be69 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1272,12 +1272,20 @@ private[spark] class DAGScheduler( } } + private[scheduler] def isRunningStage(stage: Stage): Boolean = { + runningStages.contains(stage) + } + + private[scheduler] def getStage(stageId: Int): Option[Stage] = { + stageIdToStage.get(stageId) + } + /** Finds the earliest-created active job that needs the stage */ // TODO: Probably should actually find among the active jobs that need this // stage the one with the highest priority (highest-priority pool, earliest created). // That should take care of at least part of the priority inversion problem with // cross-job dependencies. 
- private def activeJobForStage(stage: Stage): Option[Int] = { + protected def activeJobForStage(stage: Stage): Option[Int] = { val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted jobsThatUseStage.find(jobIdToActiveJob.contains) } @@ -1385,6 +1393,8 @@ private[spark] class DAGScheduler( listenerBus.post(SparkListenerTaskGettingResult(taskInfo)) } + protected def onFinalStageCreated(finalStage: Stage, properties: Properties): Unit = {} + private def getQueryExecutionIdFromProperties(properties: Properties): Option[Long] = { try { Option(properties) @@ -1420,6 +1430,7 @@ private[spark] class DAGScheduler( // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite) + onFinalStageCreated(finalStage, properties) } catch { case e: BarrierJobSlotsNumberCheckFailed => // If jobId doesn't exist in the map, Scala coverts its value null to 0: Int automatically. @@ -1497,6 +1508,7 @@ private[spark] class DAGScheduler( // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. finalStage = getOrCreateShuffleMapStage(dependency, jobId) + onFinalStageCreated(finalStage, properties) } catch { case e: Exception => logWarning(log"Creating new stage failed due to exception - job: ${MDC(JOB_ID, jobId)}", e) @@ -1537,7 +1549,7 @@ private[spark] class DAGScheduler( } /** Submits stage, but first recursively submits any missing parents. */ - private def submitStage(stage: Stage): Unit = { + protected def submitStage(stage: Stage): Unit = { val jobId = activeJobForStage(stage) if (jobId.isDefined) { logDebug(s"submitStage($stage (name=${stage.name};" + @@ -1632,7 +1644,7 @@ private[spark] class DAGScheduler( } /** Called when stage's parents are available and we can now do its task. 
*/ - private def submitMissingTasks(stage: Stage, jobId: Int): Unit = { + protected def submitMissingTasks(stage: Stage, jobId: Int): Unit = { logDebug("submitMissingTasks(" + stage + ")") // For statically indeterminate stages being retried, we trigger rollback BEFORE task @@ -3253,7 +3265,7 @@ private[spark] class DAGScheduler( /** * Marks a stage as finished and removes it from the list of running stages. */ - private def markStageAsFinished( + protected def markStageAsFinished( stage: Stage, errorMessage: Option[String] = None, willRetry: Boolean = false): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 618c8eb459026..6078a91f0f591 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -245,8 +245,15 @@ private[spark] class TaskSchedulerImpl( logInfo(log"Adding task set " + taskSet.logId + log" with ${MDC(LogKeys.NUM_TASKS, tasks.length)} tasks resource profile " + log"${MDC(LogKeys.RESOURCE_PROFILE_ID, taskSet.resourceProfileId)}") + val maxFailures = if (ConcurrentStageDAGScheduler + .isConcurrentStagesEnabled(taskSet.properties)) { + logInfo(s"Task retries are disabled for task set ${taskSet.id} with ${tasks.length} tasks") + 1 // Concurrent stage execution does not support task retries. 
+ } else { + maxTaskFailures + } this.synchronized { - val manager = createTaskSetManager(taskSet, maxTaskFailures) + val manager = createTaskSetManager(taskSet, maxFailures) val stage = taskSet.stageId val stageTaskSets = taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 5c077a7a3bbb8..dfc25c03f0842 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -1051,10 +1051,17 @@ private[spark] class TaskSetManager( logWarning(failureReason) None - case e: ExecutorLostFailure if !e.exitCausedByApp => - logInfo(log"${MDC(TASK_NAME, taskName(tid))} failed because while it was being computed," + - log" its executor exited for a reason unrelated to the task. " + - log"Not counting this failure towards the maximum number of failures for the task.") + case e: ExecutorLostFailure => + if (!e.exitCausedByApp + // if the query is running in real time mode, any failure should be considered + // a task failure + && !ConcurrentStageDAGScheduler.isConcurrentStagesEnabled(taskSet.properties)) { + logInfo(log"${MDC(TASK_NAME, taskName(tid))} failed because while it was being " + + log"computed, its executor exited for a reason unrelated to the task. 
" + + log"Not counting this failure towards the maximum number of failures for the task.") + } else { + logWarning(failureReason) + } None case _: TaskFailedReason => // TaskResultLost and others @@ -1069,7 +1076,12 @@ private[spark] class TaskSetManager( emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null, accumUpdates, metricPeaks) - if (!isZombie && reason.countTowardsTaskFailures) { + val countTowardsTaskFailures = reason.countTowardsTaskFailures || + // in real-time mode, any failure should count toward the task failures so that the + // query can restart. + ConcurrentStageDAGScheduler.isConcurrentStagesEnabled(taskSet.properties) + + if (!isZombie && countTowardsTaskFailures) { assert (null != failureReason) taskSetExcludelistHelperOpt.foreach(_.updateExcludedForFailedTask( info.host, info.executorId, index, failureReasonString)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala new file mode 100644 index 0000000000000..e3b6b8ae8f7b2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.Properties + +import org.apache.spark.HashPartitioner +import org.apache.spark.ShuffleDependency +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.internal.config.SPECULATION_ENABLED +import org.apache.spark.internal.config.STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED + +class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase { + + // The unit-test SparkContext runs in local[2] mode, but the concurrent pipelines exercised + // here often need more slots than that. Disable the slot check so the tests aren't gated by + // executor capacity. + override def conf: SparkConf = + super.conf.set(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED, true) + + class TestConcurrentStageDAGScheduler(sc: SparkContext) + extends ConcurrentStageDAGScheduler( + sc, + taskScheduler, + sc.listenerBus, + mapOutputTracker, + blockManagerMaster, + sc.env) + with TestDAGScheduler + + override def createInitialScheduler(sc: SparkContext): DAGScheduler = { + new TestConcurrentStageDAGScheduler(sc) + } + + /** + * Asserts that the concurrent scheduler's internal state - `concurrentStages` and + * `dependentStageMap` - is empty. Called from `assertDataStructuresEmpty` and at the end of + * every test via `afterEach`, so every inherited test (and every locally-defined test) gets + * free regression coverage against entries leaking into these maps. + */ + override protected def extraEmptyChecks(): Unit = { + // Inherited tests sometimes replace `scheduler` with a plain MyDAGScheduler (bypassing + // createInitialScheduler), so pattern-match rather than cast. 
+ scheduler match { + case s: TestConcurrentStageDAGScheduler => + assert(s.concurrentStages.isEmpty, + s"concurrentStages should be empty but contains: ${s.concurrentStages}") + assert(s.dependentStageMap.isEmpty, + s"dependentStageMap should be empty but contains: ${s.dependentStageMap}") + case _ => // Not a concurrent scheduler - nothing extra to assert. + } + } + + override def afterEach(): Unit = { + try { + extraEmptyChecks() + } finally { + super.afterEach() + } + } + + // Catch the job failure exception with a listener. + private class TestJobListener extends JobListener { + private var failureException: Option[Exception] = None + + override def jobFailed(exception: Exception): Unit = { + failureException = Some(exception) + } + + override def taskSucceeded(index: Int, result: Any): Unit = { } + + def expectFailure(): Exception = { + assert(failureException.nonEmpty, "Job was expected to fail with an exception, but didn't") + failureException.get + } + } + + + /** Default job properties with query settings and concurrent stages enabled. */ + private val testProperties: Properties = { + val properties = new Properties() + properties.setProperty("sql.streaming.queryId", "test_query_id") + properties.setProperty("streaming.sql.batchId", "5") + properties.setProperty(ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY, "true") + new Properties(properties) { + // Make it read-only. + override def setProperty(key: String, value: String): AnyRef = { + throw new UnsupportedOperationException("Default properties are read-only.") + } + } + } + + test("Simple job with two concurrent stages") { + // Run a simple job with two stages. Both stages should be running concurrently. 
+ + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + // Shape: [stage_0, map stage, parent] <--- [stage_1, result stage] + + submit(resultStage, Array(0), properties = testProperties) + + assert(scheduler.waitingStages.isEmpty) // Both are submitted. + assert(scheduler.runningStages.map(_.id) === Set(0, 1)) // Both stages are running. + + // Verify concurrent scheduler specific state. + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + + assert(concurrentScheduler.concurrentStages.isEmpty) // All are already scheduled + + val depStageMap = concurrentScheduler.dependentStageMap + assert(depStageMap.keys.map(_.id) == Set(1)) // Result stage is the key. + assert(depStageMap.values.flatMap(_.parents.map(_.id)) == Seq(0)) // Map stage is the parent. + assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).isEmpty) // No completed tasks. + + // First complete the result stage. Its tasks will complete, but the actual stage would still + // be running since its parent (map stage) hasn't completed yet. + + completeNextResultStageWithSuccess(1, 0) + assert(scheduler.runningStages.map(_.id) === Set(0, 1)) // Both stages are still running. + // dependentStageMap should have the completed task from result stage enqueued. + assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).size === 1) + + // Now complete the map stage. This should complete the result stage as well. + completeShuffleMapStageSuccessfully(0, 0, 1) + + assert(scheduler.runningStages.map(_.id) === Set()) // Both stages are complete. + assert(depStageMap.isEmpty) // No more dependent stages. + + assertDataStructuresEmpty() + } + + test("Default scheduler using a simple job with concurrent stages disabled") { + // This is opposite of the previous test. 
Concurrent stages are disabled, so the stages should + // be submitted one after the other. + + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + // Shape: [stage_0, map stage, parent] <--- [stage_1, result stage] + + submit(resultStage, Array(0), properties = new Properties()) + + assert(scheduler.runningStages.map(_.id) == Set(0)) // Only the map stage is running. + assert(scheduler.waitingStages.map(_.id) == Set(1)) // Result stage is waiting. + + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + assert(concurrentScheduler.concurrentStages.isEmpty) // No concurrent stages. + assert(concurrentScheduler.dependentStageMap.isEmpty) // No dependent stages. + + // Complete the map stage. This should submit the result stage. + completeShuffleMapStageSuccessfully(0, 0, 1) + + assert(scheduler.runningStages.map(_.id) == Set(1)) // Only the result stage is running. + assert(scheduler.waitingStages.map(_.id) == Set()) // No waiting stages + + completeNextResultStageWithSuccess(1, 0) + assertDataStructuresEmpty() + } + + test("Complex pipeline with many stages") { + // Run a complex pipeline with multiple stages with multiple branches. Such a pipeline not + // common, but useful to ensure scheduler works as expected. + + // Shape: + // /<-------------------- stage_D + // stage_A <--- stage_B <--- stage_C <---\ ^ + // \ \<---------/ \ | + // \ <----------------/ \ | + // stage_E <--------------------------- stage_F + + // All of these should be running concurrently. 
+ + val rddA = new MyRDD(sc, 2, Nil).setName("rddA") + val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1)) + + val rddB = new MyRDD(sc, 1, List(shuffleDepA)).setName("rddB") + val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1)) + + val rddC = new MyRDD(sc, 1, List(shuffleDepA, shuffleDepB)).setName("rddC") + val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(5)) + + val rddD = new MyRDD(sc, 4, List(shuffleDepB)).setName("rddD") + val shuffleDepD = new ShuffleDependency(rddD, new HashPartitioner(5)) + + val rddE = new MyRDD(sc, 3, Nil).setName("rddE") + val shuffleDepE = new ShuffleDependency(rddE, new HashPartitioner(5)) + + val rddF = new MyRDD(sc, 5, List(shuffleDepC, shuffleDepD, shuffleDepE)).setName("rddF") + + submit(rddF, Array(0, 1, 2, 3, 4), properties = testProperties) + + assert(scheduler.waitingStages.isEmpty) // All the stages are submitted. + assert(scheduler.runningStages.map(_.id) == Set(0, 1, 2, 3, 4, 5)) // All 6 are running. + + // Assign stage ids corresponding to the RDDs A, B, etc + def stageFor(rddName: String): Stage = scheduler.runningStages.find(_.rdd.name == rddName).get + + val sA = stageFor("rddA") + val sB = stageFor("rddB") + val sC = stageFor("rddC") + val sD = stageFor("rddD") + val sE = stageFor("rddE") + val sF = stageFor("rddF") + + // log stage id mapping for debugging: + for (name <- List("A", "B", "C", "D", "E", "F")) { + logInfo(s"Stage id for stage $name is ${stageFor("rdd" + name).id}") + } + + // Verify concurrent scheduler specific state. + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + + assert(concurrentScheduler.concurrentStages.isEmpty) // All are already scheduled + + val depStageMap = concurrentScheduler.dependentStageMap + assert(depStageMap.keys === Set(sB, sC, sD, sF)) // All non-root stages are keys. + assert(depStageMap.values.flatMap(_.parents).toSet === Set( + sA, sB, sC, sD, sE)) // All except the results stage. 
+ assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).isEmpty) // No completed tasks. + + // Complete stages out of order and verify the state. + + // First complete C. Entry for C would be updated with the completed task. + assert(depStageMap(sC).delayedTaskCompletionEvents.size === 0) + completeShuffleMapStageSuccessfully(sC.id, 0, 5) // Complete stage C. + assert(depStageMap(sC).delayedTaskCompletionEvents.size === 1) + + // All the 6 stages are still 'running' since C's completion events are delayed. + assert(scheduler.runningStages.map(_.id) == Set(0, 1, 2, 3, 4, 5)) + + // Now complete stage D. This is similar to completing C. The tasks are enqueued. + assert(depStageMap(sD).delayedTaskCompletionEvents.size === 0) + completeShuffleMapStageSuccessfully(sD.id, 0, 5) // Complete stage D + assert(depStageMap(sD).delayedTaskCompletionEvents.size === 4) // 4 tasks in stage D. + + // Complete stage E. This is a root node and should complete normally. + assert(depStageMap(sF).parents.contains(sE)) // E is one of the stages that F waits for. + completeShuffleMapStageSuccessfully(sE.id, 0, 3) + assert(depStageMap(sF).parents.contains(sE) === false) // E is removed from F's parents. + + // One less running stage. + assert(scheduler.runningStages === Set(sA, sB, sC, sD, sF)) // E is complete. + + // Complete stage A. This is a root node and should complete normally. + assert(depStageMap(sC).parents.contains(sA)) // A is one of the stages that C waits for. + assert(depStageMap(sB).parents.contains(sA)) // Same for B + completeShuffleMapStageSuccessfully(sA.id, 0, 2) + assert(scheduler.runningStages === Set(sB, sC, sD, sF)) // A is complete. + assert(depStageMap(sC).parents.contains(sA) === false) // A is removed from C's parents. + // For B there is a bit more. Its only parent A is complete, so there is no need to + // track it in depStageMap; it is removed from `dependentStageMap` entirely. 
+ assert(depStageMap.contains(sB) === false) // B is removed from the depStageMap. + + // Complete result stage F. This will be enqueued, will complete later. + completeNextResultStageWithSuccess(sF.id, 0) + assert(scheduler.runningStages === Set(sB, sC, sD, sF)) // F is still running. + assert(depStageMap(sF).delayedTaskCompletionEvents.size === 5) + + // Complete stage B, it will complete normally. + completeShuffleMapStageSuccessfully(sB.id, 0, 1) + // This will trigger completion of C as well since both its parents A & B are done. + // This will also trigger completion of D as its parent B is done. + // Which finally completes F as well. + assert(scheduler.runningStages === Set()) // All stages are complete. + assert(depStageMap.isEmpty) // No more dependent stages. + + assertDataStructuresEmpty() + } + + test("dependentStageMap entry is cleaned up when a dependent stage aborts and its " + + "parent stage is shared with another job") { + // This exercises the cleanup path in markStageAsFinished. When the dependent stage + // (here, B) aborts before its parent (A) finishes, the cascade through + // failJobAndIndependentStages only marks stages that are *independent* to the failing job + // as finished - shared stages are left alone. Without the explicit + // `dependentStageMap.remove(stage)` at the end of markStageAsFinished, B's entry would + // leak in dependentStageMap until A eventually finished for the other job. + + // Job 1 (regular batch): rddC depends on rddA via shuffleDepA. + val rddA = new MyRDD(sc, 1, Nil).setName("rddA") + val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1)) + val rddC = new MyRDD(sc, 1, List(shuffleDepA)).setName("rddC") + submit(rddC, Array(0)) // properties = null implies non-concurrent. + + // After job 1: rddA's stage is running, rddC's stage is waiting. 
+ assert(scheduler.runningStages.exists(_.rdd.name == "rddA"), + "rddA's stage should be running after job 1 submission") + + // Job 2 (concurrent): rddB also depends on the same shuffleDepA so rddA's stage is shared. + val rddB = new MyRDD(sc, 3, List(shuffleDepA)).setName("rddB") + submit(rddB, Array(0), properties = testProperties) + + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + val depStageMap = concurrentScheduler.dependentStageMap + // B is in dependentStageMap with rddA's stage as a parent. + assert(depStageMap.keys.map(_.rdd.name) === Set("rddB")) + assert(depStageMap.values.flatMap(_.parents.map(_.rdd.name)).toSet === Set("rddA")) + + // Fail rddB's taskset. abortStage's failJobAndIndependentStages only marks rddB + // finished - rddA is shared with job 1, so it is NOT cancelled. + val taskSetB = taskSets.find(_.tasks.head.stageId == + scheduler.runningStages.find(_.rdd.name == "rddB").get.id).get + failed(taskSetB, "test failure: rddB aborted before parent rddA finished") + + // rddA is still running for job 1. + assert(scheduler.runningStages.exists(_.rdd.name == "rddA"), + "Shared parent rddA's stage should still be running for job 1 after job 2 aborted") + + // rddB's entry must have been removed by markStageAsFinished's cleanup, even though + // rddA is still running (and would normally be the stage whose markStageAsFinished + // cleans up B's entry). + assert(depStageMap.isEmpty, + s"dependentStageMap should be empty after rddB aborted, but contains: $depStageMap") + } + + test("concurrentStages is empty after slot-check failure") { + // The DAG walk in onFinalStageCreated accumulates stages into a local set and only commits + // them to `concurrentStages` after the slot check passes. This test verifies that a slot- + // check failure leaves `concurrentStages` empty (rather than leaking the visited stages). 
+ sc.conf.set(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED, false) + try { + // local[2] gives us 2 slots; a 4-partition job exceeds that. + val rdd = new MyRDD(sc, 4, Nil) + + val jobListener = new TestJobListener() + submit(rdd, Array(0, 1, 2, 3), properties = testProperties, listener = jobListener) + + assert(jobListener.expectFailure().getMessage.contains( + "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT")) + + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + assert(concurrentScheduler.concurrentStages.isEmpty, + s"concurrentStages should be empty after slot-check failure, but contains: " + + s"${concurrentScheduler.concurrentStages}") + } finally { + sc.conf.set(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED, true) + } + } + + test("concurrentStages and dependentStageMap are cleaned up after job cancellation") { + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + val jobId = submit(resultStage, Array(0), properties = testProperties) + + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + // Both stages running concurrently and dependentStageMap is populated. + assert(scheduler.runningStages.map(_.id) === Set(0, 1)) + assert(concurrentScheduler.dependentStageMap.nonEmpty) + + // Cancel the job mid-execution. handleJobCancellation marks all stages of the cancelled + // job finished via markStageAsFinished, which runs our cleanup. 
+ cancel(jobId) + + assert(concurrentScheduler.concurrentStages.isEmpty, + s"concurrentStages should be empty after cancellation, " + + s"but contains: ${concurrentScheduler.concurrentStages}") + assert(concurrentScheduler.dependentStageMap.isEmpty, + s"dependentStageMap should be empty after cancellation, " + + s"but contains: ${concurrentScheduler.dependentStageMap}") + } + + test("concurrentStages and dependentStageMap are cleaned up after executor-loss " + + "induced abort") { + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + submit(resultStage, Array(0), properties = testProperties) + + val concurrentScheduler = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler] + assert(concurrentScheduler.dependentStageMap.nonEmpty) + + // In real-time mode TaskSchedulerImpl caps maxFailures at 1 and TaskSetManager counts + // ExecutorLostFailure toward task failures, so a single executor loss aborts the + // TaskSet immediately. Simulate the resulting TaskSetFailed event. + failed(taskSets(1), "executor lost: simulated for test") + + assert(concurrentScheduler.concurrentStages.isEmpty, + s"concurrentStages should be empty after executor-loss abort, " + + s"but contains: ${concurrentScheduler.concurrentStages}") + assert(concurrentScheduler.dependentStageMap.isEmpty, + s"dependentStageMap should be empty after executor-loss abort, " + + s"but contains: ${concurrentScheduler.dependentStageMap}") + } + + test("Should fail if speculative execution is enabled (per-job property)") { + // Try to run a job with two stages with speculative execution as a per-job local + // property. It should fail the job with an exception. 
+ + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + val properties = new Properties(testProperties) + properties.setProperty(SPECULATION_ENABLED.key, "true") + + val jobListener = new TestJobListener() + submit(resultStage, Array(0), properties = properties, listener = jobListener) + + assert(jobListener.expectFailure().getMessage.contains( + "Speculative execution is not supported with concurrent stages")) + } + + test("Should fail if speculative execution is enabled (cluster-wide SparkConf)") { + // Same as the previous test, but speculation is set on SparkConf (the documented way to + // enable speculation) instead of the per-job local property. Every other consumer of + // SPECULATION_ENABLED reads it via sc.conf, so this is the common case. + + sc.conf.set(SPECULATION_ENABLED, true) + try { + val mapStage = new MyRDD(sc, 1, Nil) // stage_0 + val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1)) + val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1 + + val jobListener = new TestJobListener() + submit(resultStage, Array(0), properties = testProperties, listener = jobListener) + + assert(jobListener.expectFailure().getMessage.contains( + "Speculative execution is not supported with concurrent stages")) + } finally { + sc.conf.set(SPECULATION_ENABLED, false) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index e12348e1be2d7..bc0f08e20f8fb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -180,7 +180,8 @@ class DummyScheduledFuture( class DAGSchedulerSuiteDummyException extends Exception -class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext 
with TimeLimits { +abstract class DAGSchedulerSuiteBase extends SparkFunSuite with TempLocalSparkContext + with TimeLimits { import DAGSchedulerSuite._ @@ -368,18 +369,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } - class MyDAGScheduler( - sc: SparkContext, - taskScheduler: TaskScheduler, - listenerBus: LiveListenerBus, - mapOutputTracker: MapOutputTrackerMaster, - blockManagerMaster: BlockManagerMaster, - env: SparkEnv, - clock: Clock = new SystemClock(), - shuffleMergeFinalize: Boolean = true, - shuffleMergeRegister: Boolean = true - ) extends DAGScheduler( - sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, clock) { + trait TestDAGScheduler extends DAGScheduler { + + protected def getShuffleMergeFinalize: Boolean = true + protected def getShuffleMergeRegister: Boolean = true + /** * Schedules shuffle merge finalize. */ @@ -387,7 +381,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti shuffleMapStage: ShuffleMapStage, delay: Long, registerMergeResults: Boolean = true): Unit = { - if (shuffleMergeRegister && registerMergeResults) { + if (getShuffleMergeRegister && registerMergeResults) { for (part <- 0 until shuffleMapStage.shuffleDep.partitioner.numPartitions) { val mergeStatuses = Seq((part, makeMergeStatus("", shuffleMapStage.shuffleDep.shuffleMergeId))) @@ -403,7 +397,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti shuffleMapStage.shuffleDep.setFinalizeTask( new DummyScheduledFuture(delay, registerMergeResults)) - if (shuffleMergeFinalize) { + if (getShuffleMergeFinalize) { handleShuffleMergeFinalized(shuffleMapStage, shuffleMapStage.shuffleDep.shuffleMergeId) } } @@ -417,6 +411,24 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } + class MyDAGScheduler( + sc: SparkContext, + taskScheduler: TaskScheduler, + listenerBus: LiveListenerBus, + mapOutputTracker: MapOutputTrackerMaster, + 
blockManagerMaster: BlockManagerMaster, + env: SparkEnv, + clock: Clock = new SystemClock(), + shuffleMergeFinalize: Boolean = true, + shuffleMergeRegister: Boolean = true + ) extends DAGScheduler( + sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, clock) + with TestDAGScheduler { + + override def getShuffleMergeFinalize: Boolean = shuffleMergeFinalize + override def getShuffleMergeRegister: Boolean = shuffleMergeRegister + } + override def beforeEach(): Unit = { super.beforeEach() firstInit = true @@ -446,15 +458,19 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti new MyMapOutputTrackerMaster(sc.getConf, broadcastManager)) blockManagerMaster = spy[MyBlockManagerMaster](new MyBlockManagerMaster(sc.getConf)) doNothing().when(blockManagerMaster).updateRDDBlockVisibility(any(), any()) - scheduler = spy[MyDAGScheduler](new MyDAGScheduler( + scheduler = spy(createInitialScheduler(sc)) + + dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) + } + + protected def createInitialScheduler(sc: SparkContext): DAGScheduler = { + new MyDAGScheduler( sc, taskScheduler, sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env)) - - dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) + sc.env) } override def afterEach(): Unit = { @@ -527,7 +543,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } /** Submits a job to the scheduler and returns the job id. */ - private def submit( + protected def submit( rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, @@ -551,12 +567,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } /** Sends TaskSetFailed to the scheduler. 
*/ - private def failed(taskSet: TaskSet, message: String): Unit = { + protected def failed(taskSet: TaskSet, message: String): Unit = { runEvent(TaskSetFailed(taskSet, message, None)) } /** Sends JobCancelled to the DAG scheduler. */ - private def cancel(jobId: Int): Unit = { + protected def cancel(jobId: Int): Unit = { runEvent(JobCancelled(jobId, None)) } @@ -1168,7 +1184,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti * @param hostNames - Host on which each task in the task set is executed. In case no hostNames * are provided, the tasks will progressively complete on hostA, hostB, etc. */ - private def completeShuffleMapStageSuccessfully( + protected def completeShuffleMapStageSuccessfully( stageId: Int, attemptIdx: Int, numShufflePartitions: Int, @@ -1219,7 +1235,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti * @param stageId - The current stageId * @param attemptIdx - The current attempt count */ - private def completeNextResultStageWithSuccess( + protected def completeNextResultStageWithSuccess( stageId: Int, attemptIdx: Int, partitionToResult: Int => Int = _ => 42): Unit = { @@ -6107,7 +6123,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } - private def assertDataStructuresEmpty(): Unit = { + protected def assertDataStructuresEmpty(): Unit = { assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) @@ -6117,8 +6133,15 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(scheduler.shuffleIdToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) assert(scheduler.outputCommitCoordinator.isEmpty) + extraEmptyChecks() } + /** + * Hook for subclasses to extend the empty-state assertions with their own state checks. + * Default is a no-op. 
+ */ + protected def extraEmptyChecks(): Unit = () + // Nothing in this test should break if the task info's fields are null, but // OutputCommitCoordinator requires the task info itself to not be null. private def createFakeTaskInfo(): TaskInfo = { @@ -6150,6 +6173,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } +class DAGSchedulerSuite extends DAGSchedulerSuiteBase + class DAGSchedulerAbortStageOffSuite extends DAGSchedulerSuite { override def conf: SparkConf = super.conf.set(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS, false) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index c9c9e529405ec..cb1f574e201cb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -2752,4 +2752,32 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext assert(!tsm.isInstanceOf[StructuredStreamingIdAwareSchedulerLogging]) } + test("maxTaskFailures is overridden to 1 when concurrent stages are enabled") { + // Concurrent stage execution does not support task retries — a failure must restart the + // streaming query rather than be silently retried against a still-running shuffle. The + // scheduler enforces this by capping the TaskSetManager's maxFailures at 1, regardless of + // the cluster-wide spark.task.maxFailures. 
+ val clusterMaxFailures = 4 + val taskScheduler = setupScheduler(config.TASK_MAX_FAILURES.key -> clusterMaxFailures.toString) + val props = new Properties() + props.setProperty( + ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY, "true") + val base = FakeTask.createTaskSet(numTasks = 1) + val taskSet = new TaskSet(base.tasks, base.stageId, base.stageAttemptId, base.priority, + props, base.resourceProfileId, base.shuffleId) + + taskScheduler.submitTasks(taskSet) + val tsm = taskScheduler.taskSetManagerForAttempt( + taskSet.stageId, taskSet.stageAttemptId).get + + assert(tsm.maxTaskFailures === 1) + + // Regression guard: a regular TaskSet should still get the cluster's maxTaskFailures. + val regular = FakeTask.createTaskSet(numTasks = 1, stageId = 1, stageAttemptId = 0) + taskScheduler.submitTasks(regular) + val regularTsm = taskScheduler.taskSetManagerForAttempt( + regular.stageId, regular.stageAttemptId).get + assert(regularTsm.maxTaskFailures === clusterMaxFailures) + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 4fd710712221b..d07e52f3fc382 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -2921,6 +2921,58 @@ class TaskSetManagerSuite s"\nCaptured logs:\n${logs.mkString("\n")}") } + /** + * Wraps an existing TaskSet with a copy that has the given properties. Used by the + * concurrent-stages tests below since `FakeTask.createTaskSet` always sets properties to null. 
+ */ + private def withProperties(taskSet: TaskSet, properties: Properties): TaskSet = { + new TaskSet(taskSet.tasks, taskSet.stageId, taskSet.stageAttemptId, taskSet.priority, + properties, taskSet.resourceProfileId, taskSet.shuffleId) + } + + private def concurrentStagesProperties: Properties = { + val props = new Properties() + props.setProperty( + ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY, "true") + props + } + + test("ExecutorLostFailure counts as task failure when concurrent stages are enabled") { + // Otherwise (the default), an executor exit that wasn't "caused by the app" is exempt + // from the maxTaskFailures count. For real-time mode that exemption is wrong: an executor + // loss is a query failure, and the query should be restarted. + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = withProperties(FakeTask.createTaskSet(1), concurrentStagesProperties) + // Concurrent stage execution does not support task retries, so maxFailures = 1. 
+ val manager = new TaskSetManager(sched, taskSet, maxTaskFailures = 1) + + val offer = manager.resourceOffer("exec1", "host1", TaskLocality.ANY)._1 + assert(offer.isDefined) + manager.handleFailedTask(offer.get.taskId, TaskState.FAILED, + ExecutorLostFailure("exec1", exitCausedByApp = false, reason = None)) + + assert(sched.taskSetsFailed.contains(taskSet.id), + "TaskSet should have been aborted after the single allowed failure") + assert(manager.isZombie) + } + + test("ExecutorLostFailure is not counted without concurrent stages (regression guard)") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(1) + val manager = new TaskSetManager(sched, taskSet, maxTaskFailures = 1) + + val offer = manager.resourceOffer("exec1", "host1", TaskLocality.ANY)._1 + assert(offer.isDefined) + manager.handleFailedTask(offer.get.taskId, TaskState.FAILED, + ExecutorLostFailure("exec1", exitCausedByApp = false, reason = None)) + + assert(!sched.taskSetsFailed.contains(taskSet.id), + "Executor loss not caused by app should not count toward task failures by default") + assert(!manager.isZombie) + } + } class FakeLongTasks(stageId: Int, partitionId: Int) extends FakeTask(stageId, partitionId) {