|
19 | 19 |
|
20 | 20 | package org.apache.spark.sql.comet |
21 | 21 |
|
22 | | -import org.apache.spark.TaskContext |
| 22 | +import java.util.UUID |
| 23 | +import java.util.concurrent.{Future, TimeoutException, TimeUnit} |
| 24 | + |
| 25 | +import scala.concurrent.Promise |
| 26 | +import scala.util.control.NonFatal |
| 27 | + |
| 28 | +import org.apache.spark.{broadcast, SparkException, TaskContext} |
23 | 29 | import org.apache.spark.rdd.RDD |
24 | 30 | import org.apache.spark.sql.catalyst.InternalRow |
25 | 31 | import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} |
26 | 32 | import org.apache.spark.sql.catalyst.plans.physical.Partitioning |
27 | | -import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} |
| 33 | +import org.apache.spark.sql.comet.util.{Utils => CometUtils} |
| 34 | +import org.apache.spark.sql.errors.QueryExecutionErrors |
| 35 | +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan, SQLExecution} |
| 36 | +import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec |
| 37 | +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec |
28 | 38 | import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} |
29 | 39 | import org.apache.spark.sql.types.StructType |
30 | | -import org.apache.spark.util.Utils |
| 40 | +import org.apache.spark.util.{SparkFatalException, Utils} |
31 | 41 |
|
32 | 42 | import org.apache.comet.{CometConf, NativeColumnarToRowConverter} |
33 | 43 |
|
@@ -64,6 +74,116 @@ case class CometNativeColumnarToRowExec(child: SparkPlan) |
64 | 74 | "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"), |
65 | 75 | "convertTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time in conversion")) |
66 | 76 |
|
/**
 * Completed from inside `relationFuture`: with the broadcast on success, or with the
 * failure cause when building the relation throws.
 */
@transient private lazy val promise: Promise[broadcast.Broadcast[Any]] =
  Promise[broadcast.Broadcast[Any]]()

/** Broadcast timeout taken from the session conf; `doExecuteBroadcast` waits this many seconds. */
@transient private val timeout: Long = conf.broadcastTimeout

/** Identifier used as the Spark job-group id so the broadcast job can be cancelled by group. */
private val runId: UUID = UUID.randomUUID()

/** The Comet broadcast exchange located under `child`, if there is one. */
private lazy val cometBroadcastExchange: Option[CometBroadcastExchangeExec] =
  findCometBroadcastExchange(child)

| 86 | + |
| 87 | + @transient |
/**
 * Future that builds the row-based broadcast relation for this operator.
 *
 * The work runs on `CometBroadcastExchangeExec.executionContext` with the session's
 * thread-local state captured. It executes the child's broadcast, decodes the serialized
 * columnar batches, converts them to rows with a [[NativeColumnarToRowConverter]], hands the
 * rows to the broadcast mode of the underlying Comet broadcast exchange and broadcasts the
 * resulting relation. `promise` is completed with the same outcome (success or failure).
 *
 * Fatal errors are wrapped in [[SparkFatalException]] before being rethrown; all other
 * throwables are rethrown unchanged.
 */
@transient
lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
  SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
    session,
    CometBroadcastExchangeExec.executionContext) {
    try {
      // Setup a job group here so later it may get cancelled by groupId if necessary.
      sparkContext.setJobGroup(
        runId.toString,
        s"CometNativeColumnarToRow broadcast exchange (runId $runId)",
        interruptOnCancel = true)

      // Resolve the broadcast mode first: if no Comet broadcast exchange exists below this
      // node, fail with a descriptive error instead of a bare NoSuchElementException.
      val mode = cometBroadcastExchange
        .getOrElse(
          throw new SparkException(
            "CometNativeColumnarToRowExec could not find a CometBroadcastExchange under " +
              child))
        .mode

      val numOutputRows = longMetric("numOutputRows")
      val numInputBatches = longMetric("numInputBatches")
      val localSchema = this.schema
      val batchSize = CometConf.COMET_BATCH_SIZE.get()
      val broadcastColumnar = child.executeBroadcast()
      val serializedBatches =
        broadcastColumnar.value.asInstanceOf[Array[org.apache.spark.util.io.ChunkedByteBuffer]]

      // Use native converter to convert columnar data to rows
      val converter = new NativeColumnarToRowConverter(localSchema, batchSize)
      try {
        // NOTE(review): `rows` is lazy; the metrics below are only updated while
        // `mode.transform` consumes it.
        // NOTE(review): if `converter.convert` reuses row buffers across calls, a broadcast
        // mode that retains row references without copying would need `_.copy()` here -
        // confirm against the converter's contract.
        val rows = serializedBatches.iterator
          .flatMap(CometUtils.decodeBatches(_, this.getClass.getSimpleName))
          .flatMap { batch =>
            numInputBatches += 1
            numOutputRows += batch.numRows()
            val result = converter.convert(batch)
            // Wrap iterator to close batch after consumption
            new Iterator[InternalRow] {
              override def hasNext: Boolean = {
                val hasMore = result.hasNext
                if (!hasMore) {
                  batch.close()
                }
                hasMore
              }
              override def next(): InternalRow = result.next()
            }
          }

        // The total row count is unknown until `rows` has been drained, so reading
        // `numOutputRows.value` here would always yield the pre-consumption value.
        // Pass no size hint and let the broadcast mode apply its own default.
        val relation = mode.transform(rows, None)
        val broadcasted = sparkContext.broadcastInternal(relation, serializedOnly = true)
        val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
        promise.trySuccess(broadcasted)
        broadcasted
      } finally {
        converter.close()
      }
    } catch {
      // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
      // SparkFatalException, which is a subclass of Exception. ThreadUtils.awaitResult
      // will catch this exception and re-throw the wrapped fatal throwable.
      case oe: OutOfMemoryError =>
        val ex = new SparkFatalException(oe)
        promise.tryFailure(ex)
        throw ex
      case e if !NonFatal(e) =>
        val ex = new SparkFatalException(e)
        promise.tryFailure(ex)
        throw ex
      case e: Throwable =>
        promise.tryFailure(e)
        throw e
    }
  }
}
| 157 | + |
/**
 * Returns the broadcast produced by `relationFuture`, waiting at most `timeout` seconds.
 *
 * @tparam T element type the caller expects the broadcast to carry (unchecked cast)
 * @throws SparkException if no Comet broadcast exchange was found under `child`
 * @note on timeout the job group identified by `runId` and the future are cancelled when the
 *       future has not completed yet, and the error built by
 *       `QueryExecutionErrors.executeBroadcastTimeoutError` is thrown.
 */
override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
  // Guard: broadcasting is only meaningful on top of a Comet broadcast exchange.
  cometBroadcastExchange match {
    case None =>
      throw new SparkException(
        "CometNativeColumnarToRowExec only supports doExecuteBroadcast when child contains a " +
          "CometBroadcastExchange, but got " + child)
    case Some(_) => // supported, continue below
  }

  try {
    val built = relationFuture.get(timeout, TimeUnit.SECONDS)
    built.asInstanceOf[broadcast.Broadcast[T]]
  } catch {
    case timedOut: TimeoutException =>
      logError(s"Could not execute broadcast in $timeout secs.", timedOut)
      val stillRunning = !relationFuture.isDone
      if (stillRunning) {
        sparkContext.cancelJobGroup(runId.toString)
        relationFuture.cancel(true)
      }
      throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(timedOut))
  }
}
| 177 | + |
/**
 * Depth-first search for the first [[CometBroadcastExchangeExec]] at or below `op`.
 *
 * Broadcast query stages are followed through their `plan` and reused exchanges through
 * their `child`; for any other node the children are searched in order.
 *
 * @param op root of the plan subtree to search
 * @return the first exchange found, or `None` when the subtree contains none
 */
private def findCometBroadcastExchange(op: SparkPlan): Option[CometBroadcastExchangeExec] =
  op match {
    case exchange: CometBroadcastExchangeExec => Some(exchange)
    case stage: BroadcastQueryStageExec => findCometBroadcastExchange(stage.plan)
    case reused: ReusedExchangeExec => findCometBroadcastExchange(reused.child)
    case other =>
      // Lazily probe each child and stop at the first hit.
      other.children.iterator
        .map(findCometBroadcastExchange)
        .collectFirst { case Some(found) => found }
  }
| 186 | + |
67 | 187 | override def doExecute(): RDD[InternalRow] = { |
68 | 188 | val numOutputRows = longMetric("numOutputRows") |
69 | 189 | val numInputBatches = longMetric("numInputBatches") |
@@ -91,7 +211,17 @@ case class CometNativeColumnarToRowExec(child: SparkPlan) |
91 | 211 | val result = converter.convert(batch) |
92 | 212 | convertTime += System.nanoTime() - startTime |
93 | 213 |
|
94 | | - result |
| 214 | + // Wrap iterator to close batch after consumption |
| 215 | + new Iterator[InternalRow] { |
| 216 | + override def hasNext: Boolean = { |
| 217 | + val hasMore = result.hasNext |
| 218 | + if (!hasMore) { |
| 219 | + batch.close() |
| 220 | + } |
| 221 | + hasMore |
| 222 | + } |
| 223 | + override def next(): InternalRow = result.next() |
| 224 | + } |
95 | 225 | } |
96 | 226 | } |
97 | 227 | } |
|
0 commit comments