Skip to content

Commit 2c5b842

Browse files
authored
refactor: move aggregate expression support checks to getSupportLevel (#4678)
1 parent 17dddd0 commit 2c5b842

1 file changed

Lines changed: 64 additions & 64 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 64 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,8 @@ import scala.jdk.CollectionConverters._
2424
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
2626
import org.apache.spark.sql.internal.SQLConf
27-
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
27+
import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType, ShortType, StringType}
2828

29-
import org.apache.comet.CometConf
3029
import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
3130
import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withFallbackReason}
3231
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
@@ -36,27 +35,15 @@ object CometMin extends CometAggregateExpressionSerde[Min] {
3635

3736
override def supportsMixedPartialFinal: Boolean = true
3837

38+
/**
 * Support level for a `Min` aggregate.
 *
 * Delegates to `AggSerde.minMaxSupportLevel`, passing the expression's result data type.
 */
override def getSupportLevel(expr: Min): SupportLevel =
39+
AggSerde.minMaxSupportLevel(expr.dataType)
40+
3941
override def convert(
4042
aggExpr: AggregateExpression,
4143
expr: Min,
4244
inputs: Seq[Attribute],
4345
binding: Boolean,
4446
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
45-
if (!AggSerde.minMaxDataTypeSupported(expr.dataType)) {
46-
withFallbackReason(aggExpr, s"Unsupported data type: ${expr.dataType}")
47-
return None
48-
}
49-
50-
if (expr.dataType == DataTypes.FloatType || expr.dataType == DataTypes.DoubleType) {
51-
if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get()) {
52-
// https://github.com/apache/datafusion-comet/issues/2448
53-
withFallbackReason(
54-
aggExpr,
55-
s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true")
56-
return None
57-
}
58-
}
59-
6047
val child = expr.children.head
6148
val childExpr = exprToProto(child, inputs, binding)
6249
val dataType = serializeDataType(expr.dataType)
@@ -85,27 +72,15 @@ object CometMax extends CometAggregateExpressionSerde[Max] {
8572

8673
override def supportsMixedPartialFinal: Boolean = true
8774

75+
/**
 * Support level for a `Max` aggregate.
 *
 * Delegates to `AggSerde.minMaxSupportLevel`, passing the expression's result data type.
 */
override def getSupportLevel(expr: Max): SupportLevel =
76+
AggSerde.minMaxSupportLevel(expr.dataType)
77+
8878
override def convert(
8979
aggExpr: AggregateExpression,
9080
expr: Max,
9181
inputs: Seq[Attribute],
9282
binding: Boolean,
9383
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
94-
if (!AggSerde.minMaxDataTypeSupported(expr.dataType)) {
95-
withFallbackReason(aggExpr, s"Unsupported data type: ${expr.dataType}")
96-
return None
97-
}
98-
99-
if (expr.dataType == DataTypes.FloatType || expr.dataType == DataTypes.DoubleType) {
100-
if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get()) {
101-
// https://github.com/apache/datafusion-comet/issues/2448
102-
withFallbackReason(
103-
aggExpr,
104-
s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true")
105-
return None
106-
}
107-
}
108-
10984
val child = expr.children.head
11085
val childExpr = exprToProto(child, inputs, binding)
11186
val dataType = serializeDataType(expr.dataType)
@@ -158,18 +133,20 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
158133
override def getUnsupportedReasons(): Seq[String] = Seq(
159134
"YearMonthIntervalType and DayTimeIntervalType inputs are not supported")
160135

136+
/**
 * Support level for an `Average` aggregate, based on its result data type.
 *
 * Returns `Compatible()` when `AggSerde.avgDataTypeSupported` accepts `expr.dataType`;
 * otherwise returns `Unsupported` with a reason string that names the rejected type.
 */
override def getSupportLevel(expr: Average): SupportLevel =
137+
if (AggSerde.avgDataTypeSupported(expr.dataType)) {
138+
Compatible()
139+
} else {
140+
Unsupported(Some(s"Unsupported data type: ${expr.dataType}"))
141+
}
142+
161143
override def convert(
162144
aggExpr: AggregateExpression,
163145
avg: Average,
164146
inputs: Seq[Attribute],
165147
binding: Boolean,
166148
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
167149

168-
if (!AggSerde.avgDataTypeSupported(avg.dataType)) {
169-
withFallbackReason(aggExpr, s"Unsupported data type: ${avg.dataType}")
170-
return None
171-
}
172-
173150
val child = avg.child
174151
val childExpr = exprToProto(child, inputs, binding)
175152
val dataType = serializeDataType(avg.dataType)
@@ -209,18 +186,20 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
209186

210187
object CometSum extends CometAggregateExpressionSerde[Sum] {
211188

189+
/**
 * Support level for a `Sum` aggregate, based on its result data type.
 *
 * Returns `Compatible()` when `AggSerde.sumDataTypeSupported` accepts `expr.dataType`;
 * otherwise returns `Unsupported` with a reason string that names the rejected type.
 */
override def getSupportLevel(expr: Sum): SupportLevel =
190+
if (AggSerde.sumDataTypeSupported(expr.dataType)) {
191+
Compatible()
192+
} else {
193+
Unsupported(Some(s"Unsupported data type: ${expr.dataType}"))
194+
}
195+
212196
override def convert(
213197
aggExpr: AggregateExpression,
214198
sum: Sum,
215199
inputs: Seq[Attribute],
216200
binding: Boolean,
217201
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
218202

219-
if (!AggSerde.sumDataTypeSupported(sum.dataType)) {
220-
withFallbackReason(aggExpr, s"Unsupported data type: ${sum.dataType}")
221-
return None
222-
}
223-
224203
val evalMode = CometEvalModeUtil.sumEvalMode(sum)
225204

226205
val childExpr = exprToProto(sum.child, inputs, binding)
@@ -323,16 +302,19 @@ object CometLast extends CometAggregateExpressionSerde[Last] {
323302
object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
324303
override def supportsMixedPartialFinal: Boolean = true
325304

305+
/**
 * Support level for a `BitAndAgg` aggregate, based on its result data type.
 *
 * Returns `Compatible()` when `AggSerde.bitwiseAggTypeSupported` accepts `expr.dataType`;
 * otherwise returns `Unsupported` with a reason string that names the rejected type.
 */
override def getSupportLevel(expr: BitAndAgg): SupportLevel =
306+
if (AggSerde.bitwiseAggTypeSupported(expr.dataType)) {
307+
Compatible()
308+
} else {
309+
Unsupported(Some(s"Unsupported data type: ${expr.dataType}"))
310+
}
311+
326312
override def convert(
327313
aggExpr: AggregateExpression,
328314
bitAnd: BitAndAgg,
329315
inputs: Seq[Attribute],
330316
binding: Boolean,
331317
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
332-
if (!AggSerde.bitwiseAggTypeSupported(bitAnd.dataType)) {
333-
withFallbackReason(aggExpr, s"Unsupported data type: ${bitAnd.dataType}")
334-
return None
335-
}
336318
val child = bitAnd.child
337319
val childExpr = exprToProto(child, inputs, binding)
338320
val dataType = serializeDataType(bitAnd.dataType)
@@ -359,16 +341,19 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
359341
object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
360342
override def supportsMixedPartialFinal: Boolean = true
361343

344+
/**
 * Support level for a `BitOrAgg` aggregate, based on its result data type.
 *
 * Returns `Compatible()` when `AggSerde.bitwiseAggTypeSupported` accepts `expr.dataType`;
 * otherwise returns `Unsupported` with a reason string that names the rejected type.
 */
override def getSupportLevel(expr: BitOrAgg): SupportLevel =
345+
if (AggSerde.bitwiseAggTypeSupported(expr.dataType)) {
346+
Compatible()
347+
} else {
348+
Unsupported(Some(s"Unsupported data type: ${expr.dataType}"))
349+
}
350+
362351
override def convert(
363352
aggExpr: AggregateExpression,
364353
bitOr: BitOrAgg,
365354
inputs: Seq[Attribute],
366355
binding: Boolean,
367356
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
368-
if (!AggSerde.bitwiseAggTypeSupported(bitOr.dataType)) {
369-
withFallbackReason(aggExpr, s"Unsupported data type: ${bitOr.dataType}")
370-
return None
371-
}
372357
val child = bitOr.child
373358
val childExpr = exprToProto(child, inputs, binding)
374359
val dataType = serializeDataType(bitOr.dataType)
@@ -395,16 +380,19 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
395380
object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] {
396381
override def supportsMixedPartialFinal: Boolean = true
397382

383+
/**
 * Support level for a `BitXorAgg` aggregate, based on its result data type.
 *
 * Returns `Compatible()` when `AggSerde.bitwiseAggTypeSupported` accepts `expr.dataType`;
 * otherwise returns `Unsupported` with a reason string that names the rejected type.
 */
override def getSupportLevel(expr: BitXorAgg): SupportLevel =
384+
if (AggSerde.bitwiseAggTypeSupported(expr.dataType)) {
385+
Compatible()
386+
} else {
387+
Unsupported(Some(s"Unsupported data type: ${expr.dataType}"))
388+
}
389+
398390
override def convert(
399391
aggExpr: AggregateExpression,
400392
bitXor: BitXorAgg,
401393
inputs: Seq[Attribute],
402394
binding: Boolean,
403395
conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
404-
if (!AggSerde.bitwiseAggTypeSupported(bitXor.dataType)) {
405-
withFallbackReason(aggExpr, s"Unsupported data type: ${bitXor.dataType}")
406-
return None
407-
}
408396
val child = bitXor.child
409397
val childExpr = exprToProto(child, inputs, binding)
410398
val dataType = serializeDataType(bitXor.dataType)
@@ -638,6 +626,14 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
638626

639627
override def supportsMixedPartialFinal: Boolean = true
640628

629+
/**
 * Support level for a `BloomFilterAggregate`, decided by the data type of its child
 * expression (not the aggregate's own result type).
 *
 * `ByteType`, `ShortType`, `IntegerType`, `LongType` and `StringType` children are
 * reported as `Compatible()`; any other child type is reported as `Unsupported` with a
 * reason string that names the type.
 */
override def getSupportLevel(expr: BloomFilterAggregate): SupportLevel =
630+
expr.child.dataType match {
631+
// The accepted child types, matched by type rather than by value.
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: StringType =>
632+
Compatible()
633+
// Everything else is rejected; `other` is interpolated into the reason string.
case other =>
634+
Unsupported(Some(s"Unsupported data type for bloom_filter_agg child: $other"))
635+
}
636+
641637
override def convert(
642638
aggExpr: AggregateExpression,
643639
bloomFilter: BloomFilterAggregate,
@@ -664,16 +660,6 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
664660
val dataType = serializeDataType(bloomFilter.dataType)
665661

666662
if (childExpr.isDefined &&
667-
(bloomFilter.child.dataType
668-
.isInstanceOf[ByteType] ||
669-
bloomFilter.child.dataType
670-
.isInstanceOf[ShortType] ||
671-
bloomFilter.child.dataType
672-
.isInstanceOf[IntegerType] ||
673-
bloomFilter.child.dataType
674-
.isInstanceOf[LongType] ||
675-
bloomFilter.child.dataType
676-
.isInstanceOf[StringType]) &&
677663
numItemsExpr.isDefined &&
678664
numBitsExpr.isDefined &&
679665
dataType.isDefined) {
@@ -793,4 +779,18 @@ object AggSerde {
793779
}
794780
}
795781

782+
/**
 * Shared support level for `Min` / `Max` based on the result data type.
 *
 * The checks run in this order:
 *   1. `Unsupported` if `minMaxDataTypeSupported(dt)` is false.
 *   2. `Unsupported` if `dt` is `FloatType` or `DoubleType` and
 *      `COMET_EXEC_STRICT_FLOATING_POINT` is currently true.
 *   3. `Compatible()` otherwise.
 *
 * @param dt result data type of the `Min` / `Max` expression
 * @return the support level; each `Unsupported` result carries a reason string
 */
783+
def minMaxSupportLevel(dt: DataType): SupportLevel = {
784+
if (!minMaxDataTypeSupported(dt)) {
785+
Unsupported(Some(s"Unsupported data type: $dt"))
786+
} else if ((dt == FloatType || dt == DoubleType) &&
787+
COMET_EXEC_STRICT_FLOATING_POINT.get()) {
788+
// Floating-point min/max is rejected in strict mode; the linked issue gives the background.
// https://github.com/apache/datafusion-comet/issues/2448
789+
Unsupported(
790+
Some(s"floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true"))
791+
} else {
792+
Compatible()
793+
}
794+
}
795+
796796
}

0 commit comments

Comments
 (0)