diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index cf392e4214..6221452c86 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -24,9 +24,8 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType} +import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType, ShortType, StringType} -import org.apache.comet.CometConf import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason} import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType} @@ -36,27 +35,15 @@ object CometMin extends CometAggregateExpressionSerde[Min] { override def supportsMixedPartialFinal: Boolean = true + override def getSupportLevel(expr: Min): SupportLevel = + AggSerde.minMaxSupportLevel(expr.dataType) + override def convert( aggExpr: AggregateExpression, expr: Min, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - if (!AggSerde.minMaxDataTypeSupported(expr.dataType)) { - withFallbackReason(aggExpr, s"Unsupported data type: ${expr.dataType}") - return None - } - - if (expr.dataType == DataTypes.FloatType || expr.dataType == DataTypes.DoubleType) { - if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get()) { - // https://github.com/apache/datafusion-comet/issues/2448 - withFallbackReason( - aggExpr, - s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true") - return None - } - } - val child = expr.children.head val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(expr.dataType) @@ -85,27 +72,15 @@ object CometMax extends CometAggregateExpressionSerde[Max] { override def supportsMixedPartialFinal: Boolean = true + override def getSupportLevel(expr: Max): SupportLevel = + AggSerde.minMaxSupportLevel(expr.dataType) + override def convert( aggExpr: AggregateExpression, expr: Max, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - if (!AggSerde.minMaxDataTypeSupported(expr.dataType)) { - withFallbackReason(aggExpr, s"Unsupported data type: ${expr.dataType}") - return None - } - - if (expr.dataType == DataTypes.FloatType || expr.dataType == DataTypes.DoubleType) { - if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get()) { - // https://github.com/apache/datafusion-comet/issues/2448 - withFallbackReason( - aggExpr, - s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true") - return None - } - } - val child = expr.children.head val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(expr.dataType) @@ -158,6 +133,13 @@ object CometAverage extends CometAggregateExpressionSerde[Average] { override def getUnsupportedReasons(): Seq[String] = Seq( "YearMonthIntervalType and DayTimeIntervalType inputs are not supported") + override def getSupportLevel(expr: Average): SupportLevel = + if (AggSerde.avgDataTypeSupported(expr.dataType)) { + Compatible() + } else { + Unsupported(Some(s"Unsupported data type: ${expr.dataType}")) + } + override def convert( aggExpr: AggregateExpression, avg: Average, @@ -165,11 +147,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] { binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - if (!AggSerde.avgDataTypeSupported(avg.dataType)) { - withFallbackReason(aggExpr, s"Unsupported data type: ${avg.dataType}") - return None - } - val child = avg.child val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(avg.dataType) @@ -209,6 +186,13 @@ object CometAverage extends CometAggregateExpressionSerde[Average] { object CometSum extends CometAggregateExpressionSerde[Sum] { + override def getSupportLevel(expr: Sum): SupportLevel = + if (AggSerde.sumDataTypeSupported(expr.dataType)) { + Compatible() + } else { + Unsupported(Some(s"Unsupported data type: ${expr.dataType}")) + } + override def convert( aggExpr: AggregateExpression, sum: Sum, @@ -216,11 +200,6 @@ object CometSum extends CometAggregateExpressionSerde[Sum] { binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - if (!AggSerde.sumDataTypeSupported(sum.dataType)) { - withFallbackReason(aggExpr, s"Unsupported data type: ${sum.dataType}") - return None - } - val evalMode = CometEvalModeUtil.sumEvalMode(sum) val childExpr = exprToProto(sum.child, inputs, binding) @@ -323,16 +302,19 @@ object CometLast extends CometAggregateExpressionSerde[Last] { object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { override def supportsMixedPartialFinal: Boolean = true + override def getSupportLevel(expr: BitAndAgg): SupportLevel = + if (AggSerde.bitwiseAggTypeSupported(expr.dataType)) { + Compatible() + } else { + Unsupported(Some(s"Unsupported data type: ${expr.dataType}")) + } + override def convert( aggExpr: AggregateExpression, bitAnd: BitAndAgg, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - if (!AggSerde.bitwiseAggTypeSupported(bitAnd.dataType)) { - withFallbackReason(aggExpr, s"Unsupported data type: ${bitAnd.dataType}") - return None - } val child = bitAnd.child val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(bitAnd.dataType) @@ -359,16 +341,19 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { override def supportsMixedPartialFinal: Boolean = true + override def getSupportLevel(expr: BitOrAgg): SupportLevel = + if (AggSerde.bitwiseAggTypeSupported(expr.dataType)) { + Compatible() + } else { + Unsupported(Some(s"Unsupported data type: ${expr.dataType}")) + } + override def convert( aggExpr: AggregateExpression, bitOr: BitOrAgg, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - if (!AggSerde.bitwiseAggTypeSupported(bitOr.dataType)) { - withFallbackReason(aggExpr, s"Unsupported data type: ${bitOr.dataType}") - return None - } val child = bitOr.child val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(bitOr.dataType) @@ -395,16 +380,19 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] { override def supportsMixedPartialFinal: Boolean = true + override def getSupportLevel(expr: BitXorAgg): SupportLevel = + if (AggSerde.bitwiseAggTypeSupported(expr.dataType)) { + Compatible() + } else { + Unsupported(Some(s"Unsupported data type: ${expr.dataType}")) + } + override def convert( aggExpr: AggregateExpression, bitXor: BitXorAgg, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - if (!AggSerde.bitwiseAggTypeSupported(bitXor.dataType)) { - withFallbackReason(aggExpr, s"Unsupported data type: ${bitXor.dataType}") - return None - } val child = bitXor.child val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(bitXor.dataType) @@ -638,6 +626,14 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt override def supportsMixedPartialFinal: Boolean = true + override def getSupportLevel(expr: BloomFilterAggregate): SupportLevel = + expr.child.dataType match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: StringType => + Compatible() + case other => + Unsupported(Some(s"Unsupported data type for bloom_filter_agg child: $other")) + } + override def convert( aggExpr: AggregateExpression, bloomFilter: BloomFilterAggregate, @@ -664,16 +660,6 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt val dataType = serializeDataType(bloomFilter.dataType) if (childExpr.isDefined && - (bloomFilter.child.dataType - .isInstanceOf[ByteType] || - bloomFilter.child.dataType - .isInstanceOf[ShortType] || - bloomFilter.child.dataType - .isInstanceOf[IntegerType] || - bloomFilter.child.dataType - .isInstanceOf[LongType] || - bloomFilter.child.dataType - .isInstanceOf[StringType]) && numItemsExpr.isDefined && numBitsExpr.isDefined && dataType.isDefined) { @@ -793,4 +779,18 @@ object AggSerde { } } + /** Shared support level for `Min` / `Max` based on the result data type. */ + def minMaxSupportLevel(dt: DataType): SupportLevel = { + if (!minMaxDataTypeSupported(dt)) { + Unsupported(Some(s"Unsupported data type: $dt")) + } else if ((dt == FloatType || dt == DoubleType) && + COMET_EXEC_STRICT_FLOATING_POINT.get()) { + // https://github.com/apache/datafusion-comet/issues/2448 + Unsupported( + Some(s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true")) + } else { + Compatible() + } + } + }