Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
private[comet] val negativeScaleDecimalToStringReason: String =
"Negative-scale decimal requires spark.sql.legacy.allowNegativeScaleOfDecimal=true"

// When `spark.sql.legacy.castComplexTypesToString.enabled` is true, Spark wraps maps and
// structs with `[]` (instead of `{}`) when casting to string, and omits NULL elements of
// structs/maps/arrays (instead of rendering them as the literal "null"). Comet only
// implements the default formatting, so fall back to Spark for any array/map/struct to-string
// cast when the flag is enabled. The flag is internal in Spark 4.0 and defaults to false.
private[comet] val legacyCastComplexTypesToStringReason: String =
"spark.sql.legacy.castComplexTypesToString.enabled=true is not supported"

private def legacyCastComplexTypesToString: Boolean =
SQLConf.get
.getConfString("spark.sql.legacy.castComplexTypesToString.enabled", "false")
.toBoolean

def supportedTypes: Seq[DataType] =
Seq(
DataTypes.BooleanType,
Expand Down Expand Up @@ -150,6 +163,12 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
return Compatible()
}

if (toType == DataTypes.StringType && legacyCastComplexTypesToString && (fromType
.isInstanceOf[ArrayType] || fromType.isInstanceOf[StructType] ||
fromType.isInstanceOf[MapType])) {
return Unsupported(Some(legacyCastComplexTypesToStringReason))
}

(fromType, toType) match {
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
case (ArrayType(DataTypes.DateType, _), ArrayType(toElementType, _))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,10 @@ case class CometExecRule(session: SparkSession)

if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) return false

if (groupingExpressions.exists(e => QueryPlanSerde.containsMapType(e.dataType))) return false
if (groupingExpressions.exists(e =>
SupportLevel.containsType(e.dataType, classOf[MapType]))) {
return false
}

if (!groupingExpressions.forall(e =>
QueryPlanSerde.exprToProto(e, agg.child.output).isDefined)) {
Expand Down
17 changes: 5 additions & 12 deletions spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,11 @@ object CometSortOrder extends CometExpressionSerde[SortOrder] {
" floating-point types is not 100% compatible with Spark")

override def getSupportLevel(expr: SortOrder): SupportLevel = {

if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
SupportLevel.containsFloatingPoint(expr.child.dataType)) {
// https://github.com/apache/datafusion-comet/issues/2626
Incompatible(
Some(
"Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
s"${CometConf.COMPAT_GUIDE}"))
} else {
Compatible()
}
// https://github.com/apache/datafusion-comet/issues/2626
SupportLevel
.strictFloatingPointReason(expr.child.dataType, "Sorting on floating-point")
.map(reason => Incompatible(Some(reason)))
.getOrElse(Compatible())
}

override def convert(
Expand Down
13 changes: 0 additions & 13 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -490,19 +490,6 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
false
}

/**
* Returns true if the given data type is or contains a `MapType` at any nesting level. Arrow's
* row format (used by DataFusion's grouped hash aggregate for composite group keys) does not
* support `Map`, so grouping on any type that transitively contains a map would crash in native
* execution.
*/
def containsMapType(dt: DataType): Boolean = dt match {
case _: MapType => true
case a: ArrayType => containsMapType(a.elementType)
case s: StructType => s.fields.exists(f => containsMapType(f.dataType))
case _ => false
}

/**
* Serializes Spark datatype to protobuf. Note that, a datatype can be serialized by this method
* doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return
Expand Down
46 changes: 38 additions & 8 deletions spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ package org.apache.comet.serde

import org.apache.spark.sql.types._

import org.apache.comet.CometConf
import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT

sealed trait SupportLevel

/**
Expand All @@ -46,14 +49,41 @@ case class Unsupported(notes: Option[String] = None) extends SupportLevel
object SupportLevel {

/**
* Returns true if the given data type contains FloatType or DoubleType at any nesting level.
* Returns true if `dt` is, or transitively contains, an instance of any of the given `DataType`
* classes. Walks `ArrayType` element, `StructType` fields, and `MapType` key/value at every
* nesting level.
*/
def containsType(dt: DataType, classes: Class[_ <: DataType]*): Boolean = {
if (classes.exists(_.isInstance(dt))) {
true
} else {
dt match {
case ArrayType(elementType, _) => containsType(elementType, classes: _*)
case StructType(fields) => fields.exists(f => containsType(f.dataType, classes: _*))
case MapType(keyType, valueType, _) =>
containsType(keyType, classes: _*) || containsType(valueType, classes: _*)
case _ => false
}
}
}

/**
* Gate for [[CometConf.COMET_EXEC_STRICT_FLOATING_POINT]]: returns the standard incompatibility
* reason when strict mode is enabled and `dt` contains a float or double (at any nesting
* level), and `None` otherwise. Callers wrap the reason with `Incompatible` or pass it to
* `withFallbackReason` as appropriate.
*
* `what` describes the operation being gated, e.g. "Sorting on floating-point" or "MapSort on
* floating-point key", and is interpolated into the returned message.
*/
def containsFloatingPoint(dt: DataType): Boolean = dt match {
case FloatType | DoubleType => true
case ArrayType(elementType, _) => containsFloatingPoint(elementType)
case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
case MapType(keyType, valueType, _) =>
containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
case _ => false
def strictFloatingPointReason(dt: DataType, what: String): Option[String] = {
if (COMET_EXEC_STRICT_FLOATING_POINT.get() &&
containsType(dt, classOf[FloatType], classOf[DoubleType])) {
Some(
s"$what is not 100% compatible with Spark, and Comet is running with " +
s"${COMET_EXEC_STRICT_FLOATING_POINT.key}=true. ${CometConf.COMPAT_GUIDE}")
} else {
None
}
}
}
18 changes: 7 additions & 11 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -715,17 +715,13 @@ object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] {
" `spark.comet.expression.CollectSet.allowIncompatible=true` is set.")

override def getSupportLevel(expr: CollectSet): SupportLevel = {
if (COMET_EXEC_STRICT_FLOATING_POINT.get() &&
SupportLevel.containsFloatingPoint(expr.children.head.dataType)) {
Incompatible(
Some(
"collect_set on floating-point types is not 100% compatible with Spark " +
"(Comet deduplicates NaN values while Spark treats each NaN as distinct), " +
s"and Comet is running with ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
s"${CometConf.COMPAT_GUIDE}"))
} else {
Compatible()
}
SupportLevel
.strictFloatingPointReason(
expr.children.head.dataType,
"collect_set on floating-point types " +
"(Comet deduplicates NaN values while Spark treats each NaN as distinct)")
.map(reason => Incompatible(Some(reason)))
.getOrElse(Compatible())
}

override def convert(
Expand Down
23 changes: 5 additions & 18 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,11 @@ object CometSortArray extends CometExpressionSerde[SortArray] {

if (!supportedSortArrayElementType(elementType)) {
Unsupported(Some(s"Sort on array element type $elementType is not supported"))
} else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
SupportLevel.containsFloatingPoint(elementType)) {
Incompatible(
Some(
"Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
s"${CometConf.COMPAT_GUIDE}"))
} else {
Compatible()
SupportLevel
.strictFloatingPointReason(elementType, "Sorting on floating-point")
.map(reason => Incompatible(Some(reason)))
.getOrElse(Compatible())
}
}

Expand Down Expand Up @@ -553,17 +549,8 @@ object CometArrayReverse extends CometExpressionSerde[Reverse] with ArraysBase {

override def getIncompatibleReasons(): Seq[String] = Seq(unsupportedReason)

@tailrec
private def containsBinary(dt: DataType): Boolean = {
dt match {
case BinaryType => true
case ArrayType(elementType, _) => containsBinary(elementType)
case _ => false
}
}

override def getSupportLevel(expr: Reverse): SupportLevel = {
if (containsBinary(expr.child.dataType)) {
if (SupportLevel.containsType(expr.child.dataType, classOf[BinaryType])) {
Incompatible(Some(unsupportedReason))
} else {
Compatible(None)
Expand Down
13 changes: 2 additions & 11 deletions spark/src/main/scala/org/apache/comet/serde/maps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,11 @@ object CometMapFromEntries
override def getIncompatibleReasons(): Seq[String] =
Seq(keyUnsupportedReason, valueUnsupportedReason)

private def containsBinary(dataType: DataType): Boolean = {
dataType match {
case BinaryType => true
case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
case ArrayType(elementType, _) => containsBinary(elementType)
case _ => false
}
}

override def getSupportLevel(expr: MapFromEntries): SupportLevel = {
if (containsBinary(expr.dataType.keyType)) {
if (SupportLevel.containsType(expr.dataType.keyType, classOf[BinaryType])) {
return Incompatible(Some(keyUnsupportedReason))
}
if (containsBinary(expr.dataType.valueType)) {
if (SupportLevel.containsType(expr.dataType.valueType, classOf[BinaryType])) {
return Incompatible(Some(valueUnsupportedReason))
}
Compatible(None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1546,7 +1546,8 @@ trait CometBaseAggregate {
return None
}

if (groupingExpressions.exists(expr => QueryPlanSerde.containsMapType(expr.dataType))) {
if (groupingExpressions.exists(expr =>
SupportLevel.containsType(expr.dataType, classOf[MapType]))) {
withFallbackReason(aggregate, "Grouping on map-containing types is not supported")
return None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,11 @@ object CometMapSort extends CometExpressionSerde[MapSort] {
val keyType = expr.dataType.asInstanceOf[MapType].keyType
if (!supportedScalarSortElementType(keyType)) {
Unsupported(Some(s"MapSort on map with key type $keyType is not supported"))
} else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
SupportLevel.containsFloatingPoint(keyType)) {
Incompatible(
Some(
"MapSort on floating-point key is not 100% compatible with Spark, and Comet is " +
s"running with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
s"${CometConf.COMPAT_GUIDE}"))
} else {
Compatible(None)
SupportLevel
.strictFloatingPointReason(keyType, "MapSort on floating-point key")
.map(reason => Incompatible(Some(reason)))
.getOrElse(Compatible(None))
}
}

Expand Down
Loading
Loading