Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -290,21 +290,22 @@ object CometSparkSessionExtensions extends Logging {
* @return
* `node` with fallback reasons attached (as a side effect on its tag map).
*/
def withInfo[T <: TreeNode[_]](node: T, info: String, exprs: T*): T = {
def withFallbackReason[T <: TreeNode[_]](node: T, info: String, exprs: T*): T = {
// support existing approach of passing in multiple infos in a newline-delimited string
val infoSet = if (info == null || info.isEmpty) {
Set.empty[String]
} else {
info.split("\n").toSet
}
withInfos(node, infoSet, exprs: _*)
withFallbackReasons(node, infoSet, exprs: _*)
}

/**
* Record one or more fallback reasons on a `TreeNode` and roll up reasons from any child nodes.
* This is the set-valued form of [[withInfo]]; see that overload for the full contract.
* This is the set-valued form of [[withFallbackReason]]; see that overload for the full
* contract.
*
* Reasons are accumulated (never overwritten) on the node's `EXTENSION_INFO` tag and are
* Reasons are accumulated (never overwritten) on the node's `FALLBACK_REASONS` tag and are
* surfaced in extended explain output. When `COMET_LOG_FALLBACK_REASONS` is enabled, each new
* reason is also emitted as a warning.
*
Expand All @@ -320,16 +321,16 @@ object CometSparkSessionExtensions extends Logging {
* @return
* `node` with fallback reasons attached (as a side effect on its tag map).
*/
def withInfos[T <: TreeNode[_]](node: T, info: Set[String], exprs: T*): T = {
def withFallbackReasons[T <: TreeNode[_]](node: T, info: Set[String], exprs: T*): T = {
if (CometConf.COMET_LOG_FALLBACK_REASONS.get()) {
for (reason <- info) {
logWarning(s"Comet cannot accelerate ${node.getClass.getSimpleName} because: $reason")
}
}
val existingNodeInfos = node.getTagValue(CometExplainInfo.EXTENSION_INFO)
val existingNodeInfos = node.getTagValue(CometExplainInfo.FALLBACK_REASONS)
val newNodeInfo = (existingNodeInfos ++ exprs
.flatMap(_.getTagValue(CometExplainInfo.EXTENSION_INFO))).flatten.toSet
node.setTagValue(CometExplainInfo.EXTENSION_INFO, newNodeInfo ++ info)
.flatMap(_.getTagValue(CometExplainInfo.FALLBACK_REASONS))).flatten.toSet
node.setTagValue(CometExplainInfo.FALLBACK_REASONS, newNodeInfo ++ info)
node
}

Expand All @@ -347,17 +348,17 @@ object CometSparkSessionExtensions extends Logging {
* @return
* `node` with the rolled-up reasons attached (as a side effect on its tag map).
*/
def withInfo[T <: TreeNode[_]](node: T, exprs: T*): T = {
withInfos(node, Set.empty, exprs: _*)
def withFallbackReason[T <: TreeNode[_]](node: T, exprs: T*): T = {
withFallbackReasons(node, Set.empty, exprs: _*)
}

/**
* True if any fallback reason has been recorded on `node` (via [[withInfo]] / [[withInfos]]).
* Callers that need to short-circuit when a prior rule pass has already decided a node falls
* back can use this as the sticky signal.
* True if any fallback reason has been recorded on `node` (via [[withFallbackReason]] /
* [[withFallbackReasons]]). Callers that need to short-circuit when a prior rule pass has
* already decided a node falls back can use this as the sticky signal.
*/
def hasExplainInfo(node: TreeNode[_]): Boolean = {
node.getTagValue(CometExplainInfo.EXTENSION_INFO).exists(_.nonEmpty)
def hasFallbackReason(node: TreeNode[_]): Boolean = {
node.getTagValue(CometExplainInfo.FALLBACK_REASONS).exists(_.nonEmpty)
}

}
12 changes: 7 additions & 5 deletions spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,17 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator {
}

def getFallbackReasons(plan: SparkPlan): Seq[String] = {
extensionInfo(plan).toSeq.sorted
fallbackReasons(plan).toSeq.sorted
}

private[comet] def extensionInfo(node: TreeNode[_]): Set[String] = {
private[comet] def fallbackReasons(node: TreeNode[_]): Set[String] = {
var info = mutable.Seq[String]()
val sorted = sortup(node)
sorted.foreach { p =>
val all: Set[String] =
getActualPlan(p).getTagValue(CometExplainInfo.EXTENSION_INFO).getOrElse(Set.empty[String])
getActualPlan(p)
.getTagValue(CometExplainInfo.FALLBACK_REASONS)
.getOrElse(Set.empty[String])
for (s <- all) {
info = info :+ s
}
Expand Down Expand Up @@ -120,7 +122,7 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator {
outString.append(if (lastChildren.last) "+- " else ":- ")
}

val tagValue = node.getTagValue(CometExplainInfo.EXTENSION_INFO)
val tagValue = node.getTagValue(CometExplainInfo.FALLBACK_REASONS)
val str = if (tagValue.nonEmpty) {
s" ${node.nodeName} [COMET: ${tagValue.get.mkString(", ")}]"
} else {
Expand Down Expand Up @@ -212,7 +214,7 @@ object CometCoverageStats {
}

object CometExplainInfo {
val EXTENSION_INFO = new TreeNodeTag[Set[String]]("CometExtensionInfo")
val FALLBACK_REASONS = new TreeNodeTag[Set[String]]("CometFallbackReasons")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like CometMetricsListener is still referencing CometExplainInfo.EXTENSION_INFO. CometMetricsListener is not being used (?) but this might break CI


def getActualPlan(node: TreeNode[_]): TreeNode[_] = {
node match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {

/**
* Plan-time predicate. `None` greenlights the serde to emit the codegen proto; `Some(reason)`
* forces a Spark fallback (typically `withInfo(...) + None`) so the operator falls back cleanly
* rather than crashing the Janino compile at execute time.
* forces a Spark fallback (typically `withFallbackReason(...) + None`) so the operator falls
* back cleanly rather than crashing the Janino compile at execute time.
*
* Checks every `BoundReference`'s data type and the root `expr.dataType` against
* [[isSupportedDataType]], rejects aggregates / generators / `CodegenFallback` (other than
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType, TimestampNTZType, TimestampType}

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.{isSpark40Plus, withInfo}
import org.apache.comet.CometSparkSessionExtensions.{isSpark40Plus, withFallbackReason}
import org.apache.comet.serde.{CometExpressionSerde, Compatible, ExprOuterClass, Incompatible, SupportLevel, Unsupported}
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProtoInternal, serializeDataType}
Expand Down Expand Up @@ -81,7 +81,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
if (childExpr.isDefined) {
castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get, cometEvalMode)
} else {
withInfo(cast, cast.child)
withFallbackReason(cast, cast.child)
None
}
}
Expand Down Expand Up @@ -131,7 +131,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
.setCast(castBuilder)
.build())
case _ =>
withInfo(expr, s"Unsupported datatype in castToProto: $dt")
withFallbackReason(expr, s"Unsupported datatype in castToProto: $dt")
None
}
}
Expand Down
20 changes: 11 additions & 9 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ case class CometExecRule(session: SparkSession)
} else {
// copy fallback reasons to the original plan
newPlan
.getTagValue(CometExplainInfo.EXTENSION_INFO)
.foreach(reasons => withInfos(plan, reasons))
.getTagValue(CometExplainInfo.FALLBACK_REASONS)
.foreach(reasons => withFallbackReasons(plan, reasons))
// return the original plan
plan
}
Expand Down Expand Up @@ -382,8 +382,8 @@ case class CometExecRule(session: SparkSession)
// reasons.
// 3. The operator has children that could not be converted, so execution
// has already fallen back to Spark.
if (op.children.forall(_.isInstanceOf[CometNativeExec]) && !hasExplainInfo(op)) {
withInfo(op, s"${op.nodeName} is not supported")
if (op.children.forall(_.isInstanceOf[CometNativeExec]) && !hasFallbackReason(op)) {
withFallbackReason(op, s"${op.nodeName} is not supported")
} else {
op
}
Expand Down Expand Up @@ -587,7 +587,7 @@ case class CometExecRule(session: SparkSession)
// config is enabled)
if (CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.get()) {
val info = new ExtendedExplainInfo()
if (info.extensionInfo(newPlan).nonEmpty) {
if (info.fallbackReasons(newPlan).nonEmpty) {
logWarning(
"Comet cannot execute some parts of this plan natively " +
s"(set ${CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key}=false " +
Expand Down Expand Up @@ -693,7 +693,9 @@ case class CometExecRule(session: SparkSession)
case other => Seq(other)
}
if (!dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])) {
withInfo(op, "Cannot perform native operation because input is not in Arrow format")
withFallbackReason(
op,
"Cannot perform native operation because input is not in Arrow format")
return None
}
}
Expand Down Expand Up @@ -721,7 +723,7 @@ case class CometExecRule(session: SparkSession)
if (handler.enabledConfig.forall(_.get(op.conf))) {
handler.getSupportLevel(op) match {
case Unsupported(notes) =>
withInfo(op, notes.getOrElse(""))
withFallbackReason(op, notes.getOrElse(""))
false
case Incompatible(notes) =>
val allowIncompat = CometConf.isOperatorAllowIncompat(opName)
Expand All @@ -735,7 +737,7 @@ case class CometExecRule(session: SparkSession)
true
} else {
val optionalNotes = notes.map(str => s" ($str)").getOrElse("")
withInfo(
withFallbackReason(
op,
s"$opName is not fully compatible with Spark$optionalNotes. " +
s"To enable it anyway, set $incompatConf=true. " +
Expand All @@ -749,7 +751,7 @@ case class CometExecRule(session: SparkSession)
true
}
} else {
withInfo(
withFallbackReason(
op,
s"Native support for operator $opName is disabled. " +
s"Set ${handler.enabledConfig.get.key}=true to enable it.")
Expand Down
Loading
Loading