@@ -24,9 +24,8 @@ import scala.jdk.CollectionConverters._
2424import org .apache .spark .sql .catalyst .expressions .{Attribute , Literal }
2525import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Average , BitAndAgg , BitOrAgg , BitXorAgg , BloomFilterAggregate , CentralMomentAgg , CollectSet , Corr , Count , Covariance , CovPopulation , CovSample , First , Last , Max , Min , StddevPop , StddevSamp , Sum , VariancePop , VarianceSamp }
2626import org .apache .spark .sql .internal .SQLConf
27- import org .apache .spark .sql .types .{ByteType , DataTypes , DecimalType , IntegerType , LongType , ShortType , StringType }
27+ import org .apache .spark .sql .types .{ByteType , DecimalType , IntegerType , LongType , ShortType , StringType }
2828
29- import org .apache .comet .CometConf
3029import org .apache .comet .CometConf .COMET_EXEC_STRICT_FLOATING_POINT
3130import org .apache .comet .CometSparkSessionExtensions .{isSpark41Plus , withFallbackReason }
3231import org .apache .comet .serde .QueryPlanSerde .{evalModeToProto , exprToProto , serializeDataType }
@@ -36,27 +35,15 @@ object CometMin extends CometAggregateExpressionSerde[Min] {
3635
3736 override def supportsMixedPartialFinal : Boolean = true
3837
38+ override def getSupportLevel (expr : Min ): SupportLevel =
39+ AggSerde .minMaxSupportLevel(expr.dataType)
40+
3941 override def convert (
4042 aggExpr : AggregateExpression ,
4143 expr : Min ,
4244 inputs : Seq [Attribute ],
4345 binding : Boolean ,
4446 conf : SQLConf ): Option [ExprOuterClass .AggExpr ] = {
45- if (! AggSerde .minMaxDataTypeSupported(expr.dataType)) {
46- withFallbackReason(aggExpr, s " Unsupported data type: ${expr.dataType}" )
47- return None
48- }
49-
50- if (expr.dataType == DataTypes .FloatType || expr.dataType == DataTypes .DoubleType ) {
51- if (CometConf .COMET_EXEC_STRICT_FLOATING_POINT .get()) {
52- // https://github.com/apache/datafusion-comet/issues/2448
53- withFallbackReason(
54- aggExpr,
55- s " floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT .key}=true " )
56- return None
57- }
58- }
59-
6047 val child = expr.children.head
6148 val childExpr = exprToProto(child, inputs, binding)
6249 val dataType = serializeDataType(expr.dataType)
@@ -85,27 +72,15 @@ object CometMax extends CometAggregateExpressionSerde[Max] {
8572
8673 override def supportsMixedPartialFinal : Boolean = true
8774
75+ override def getSupportLevel (expr : Max ): SupportLevel =
76+ AggSerde .minMaxSupportLevel(expr.dataType)
77+
8878 override def convert (
8979 aggExpr : AggregateExpression ,
9080 expr : Max ,
9181 inputs : Seq [Attribute ],
9282 binding : Boolean ,
9383 conf : SQLConf ): Option [ExprOuterClass .AggExpr ] = {
94- if (! AggSerde .minMaxDataTypeSupported(expr.dataType)) {
95- withFallbackReason(aggExpr, s " Unsupported data type: ${expr.dataType}" )
96- return None
97- }
98-
99- if (expr.dataType == DataTypes .FloatType || expr.dataType == DataTypes .DoubleType ) {
100- if (CometConf .COMET_EXEC_STRICT_FLOATING_POINT .get()) {
101- // https://github.com/apache/datafusion-comet/issues/2448
102- withFallbackReason(
103- aggExpr,
104- s " floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT .key}=true " )
105- return None
106- }
107- }
108-
10984 val child = expr.children.head
11085 val childExpr = exprToProto(child, inputs, binding)
11186 val dataType = serializeDataType(expr.dataType)
@@ -158,18 +133,20 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
158133 override def getUnsupportedReasons (): Seq [String ] = Seq (
159134 " YearMonthIntervalType and DayTimeIntervalType inputs are not supported" )
160135
136+ override def getSupportLevel (expr : Average ): SupportLevel =
137+ if (AggSerde .avgDataTypeSupported(expr.dataType)) {
138+ Compatible ()
139+ } else {
140+ Unsupported (Some (s " Unsupported data type: ${expr.dataType}" ))
141+ }
142+
161143 override def convert (
162144 aggExpr : AggregateExpression ,
163145 avg : Average ,
164146 inputs : Seq [Attribute ],
165147 binding : Boolean ,
166148 conf : SQLConf ): Option [ExprOuterClass .AggExpr ] = {
167149
168- if (! AggSerde .avgDataTypeSupported(avg.dataType)) {
169- withFallbackReason(aggExpr, s " Unsupported data type: ${avg.dataType}" )
170- return None
171- }
172-
173150 val child = avg.child
174151 val childExpr = exprToProto(child, inputs, binding)
175152 val dataType = serializeDataType(avg.dataType)
@@ -209,18 +186,20 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
209186
210187object CometSum extends CometAggregateExpressionSerde [Sum ] {
211188
189+ override def getSupportLevel (expr : Sum ): SupportLevel =
190+ if (AggSerde .sumDataTypeSupported(expr.dataType)) {
191+ Compatible ()
192+ } else {
193+ Unsupported (Some (s " Unsupported data type: ${expr.dataType}" ))
194+ }
195+
212196 override def convert (
213197 aggExpr : AggregateExpression ,
214198 sum : Sum ,
215199 inputs : Seq [Attribute ],
216200 binding : Boolean ,
217201 conf : SQLConf ): Option [ExprOuterClass .AggExpr ] = {
218202
219- if (! AggSerde .sumDataTypeSupported(sum.dataType)) {
220- withFallbackReason(aggExpr, s " Unsupported data type: ${sum.dataType}" )
221- return None
222- }
223-
224203 val evalMode = CometEvalModeUtil .sumEvalMode(sum)
225204
226205 val childExpr = exprToProto(sum.child, inputs, binding)
@@ -323,16 +302,19 @@ object CometLast extends CometAggregateExpressionSerde[Last] {
323302object CometBitAndAgg extends CometAggregateExpressionSerde [BitAndAgg ] {
324303 override def supportsMixedPartialFinal : Boolean = true
325304
305+ override def getSupportLevel (expr : BitAndAgg ): SupportLevel =
306+ if (AggSerde .bitwiseAggTypeSupported(expr.dataType)) {
307+ Compatible ()
308+ } else {
309+ Unsupported (Some (s " Unsupported data type: ${expr.dataType}" ))
310+ }
311+
326312 override def convert (
327313 aggExpr : AggregateExpression ,
328314 bitAnd : BitAndAgg ,
329315 inputs : Seq [Attribute ],
330316 binding : Boolean ,
331317 conf : SQLConf ): Option [ExprOuterClass .AggExpr ] = {
332- if (! AggSerde .bitwiseAggTypeSupported(bitAnd.dataType)) {
333- withFallbackReason(aggExpr, s " Unsupported data type: ${bitAnd.dataType}" )
334- return None
335- }
336318 val child = bitAnd.child
337319 val childExpr = exprToProto(child, inputs, binding)
338320 val dataType = serializeDataType(bitAnd.dataType)
@@ -359,16 +341,19 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
359341object CometBitOrAgg extends CometAggregateExpressionSerde [BitOrAgg ] {
360342 override def supportsMixedPartialFinal : Boolean = true
361343
344+ override def getSupportLevel (expr : BitOrAgg ): SupportLevel =
345+ if (AggSerde .bitwiseAggTypeSupported(expr.dataType)) {
346+ Compatible ()
347+ } else {
348+ Unsupported (Some (s " Unsupported data type: ${expr.dataType}" ))
349+ }
350+
362351 override def convert (
363352 aggExpr : AggregateExpression ,
364353 bitOr : BitOrAgg ,
365354 inputs : Seq [Attribute ],
366355 binding : Boolean ,
367356 conf : SQLConf ): Option [ExprOuterClass .AggExpr ] = {
368- if (! AggSerde .bitwiseAggTypeSupported(bitOr.dataType)) {
369- withFallbackReason(aggExpr, s " Unsupported data type: ${bitOr.dataType}" )
370- return None
371- }
372357 val child = bitOr.child
373358 val childExpr = exprToProto(child, inputs, binding)
374359 val dataType = serializeDataType(bitOr.dataType)
@@ -395,16 +380,19 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
395380object CometBitXOrAgg extends CometAggregateExpressionSerde [BitXorAgg ] {
396381 override def supportsMixedPartialFinal : Boolean = true
397382
383+ override def getSupportLevel (expr : BitXorAgg ): SupportLevel =
384+ if (AggSerde .bitwiseAggTypeSupported(expr.dataType)) {
385+ Compatible ()
386+ } else {
387+ Unsupported (Some (s " Unsupported data type: ${expr.dataType}" ))
388+ }
389+
398390 override def convert (
399391 aggExpr : AggregateExpression ,
400392 bitXor : BitXorAgg ,
401393 inputs : Seq [Attribute ],
402394 binding : Boolean ,
403395 conf : SQLConf ): Option [ExprOuterClass .AggExpr ] = {
404- if (! AggSerde .bitwiseAggTypeSupported(bitXor.dataType)) {
405- withFallbackReason(aggExpr, s " Unsupported data type: ${bitXor.dataType}" )
406- return None
407- }
408396 val child = bitXor.child
409397 val childExpr = exprToProto(child, inputs, binding)
410398 val dataType = serializeDataType(bitXor.dataType)
@@ -638,6 +626,14 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
638626
639627 override def supportsMixedPartialFinal : Boolean = true
640628
629+ override def getSupportLevel (expr : BloomFilterAggregate ): SupportLevel =
630+ expr.child.dataType match {
631+ case _ : ByteType | _ : ShortType | _ : IntegerType | _ : LongType | _ : StringType =>
632+ Compatible ()
633+ case other =>
634+ Unsupported (Some (s " Unsupported data type for bloom_filter_agg child: $other" ))
635+ }
636+
641637 override def convert (
642638 aggExpr : AggregateExpression ,
643639 bloomFilter : BloomFilterAggregate ,
@@ -664,16 +660,6 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
664660 val dataType = serializeDataType(bloomFilter.dataType)
665661
666662 if (childExpr.isDefined &&
667- (bloomFilter.child.dataType
668- .isInstanceOf [ByteType ] ||
669- bloomFilter.child.dataType
670- .isInstanceOf [ShortType ] ||
671- bloomFilter.child.dataType
672- .isInstanceOf [IntegerType ] ||
673- bloomFilter.child.dataType
674- .isInstanceOf [LongType ] ||
675- bloomFilter.child.dataType
676- .isInstanceOf [StringType ]) &&
677663 numItemsExpr.isDefined &&
678664 numBitsExpr.isDefined &&
679665 dataType.isDefined) {
@@ -793,4 +779,18 @@ object AggSerde {
793779 }
794780 }
795781
782+ /** Shared support level for `Min` / `Max` based on the result data type. */
783+ def minMaxSupportLevel (dt : DataType ): SupportLevel = {
784+ if (! minMaxDataTypeSupported(dt)) {
785+ Unsupported (Some (s " Unsupported data type: $dt" ))
786+ } else if ((dt == FloatType || dt == DoubleType ) &&
787+ COMET_EXEC_STRICT_FLOATING_POINT .get()) {
788+ // https://github.com/apache/datafusion-comet/issues/2448
789+ Unsupported (
790+ Some (s " floating-point not supported when ${COMET_EXEC_STRICT_FLOATING_POINT .key}=true " ))
791+ } else {
792+ Compatible ()
793+ }
794+ }
795+
796796}
0 commit comments