@@ -131,6 +131,29 @@ object QueryPlanSerde extends Logging with CometExprShim {
131131 classOf [SparkPartitionID ] -> CometSparkPartitionId ,
132132 classOf [MonotonicallyIncreasingID ] -> CometMonotonicallyIncreasingId )
133133
134+ /**
135+ * Mapping of Spark aggregate expression class to Comet expression handler.
136+ */
137+ private val aggrSerdeMap : Map [Class [_], CometAggregateExpressionSerde ] = Map (
138+ classOf [Sum ] -> CometSum ,
139+ classOf [Average ] -> CometAverage ,
140+ classOf [Count ] -> CometCount ,
141+ classOf [Min ] -> CometMin ,
142+ classOf [Max ] -> CometMax ,
143+ classOf [First ] -> CometFirst ,
144+ classOf [Last ] -> CometLast ,
145+ classOf [BitAndAgg ] -> CometBitAndAgg ,
146+ classOf [BitOrAgg ] -> CometBitOrAgg ,
147+ classOf [BitXorAgg ] -> CometBitXOrAgg ,
148+ classOf [CovSample ] -> CometCovSample ,
149+ classOf [CovPopulation ] -> CometCovPopulation ,
150+ classOf [VarianceSamp ] -> CometVarianceSamp ,
151+ classOf [VariancePop ] -> CometVariancePop ,
152+ classOf [StddevSamp ] -> CometStddevSamp ,
153+ classOf [StddevPop ] -> CometStddevPop ,
154+ classOf [Corr ] -> CometCorr ,
155+ classOf [BloomFilterAggregate ] -> CometBloomFilterAggregate )
156+
134157 def emitWarning (reason : String ): Unit = {
135158 logWarning(s " Comet native execution is disabled due to: $reason" )
136159 }
@@ -436,33 +459,17 @@ object QueryPlanSerde extends Logging with CometExprShim {
436459 return None
437460 }
438461
439- val cometExpr : CometAggregateExpressionSerde = aggExpr.aggregateFunction match {
440- case _ : Sum => CometSum
441- case _ : Average => CometAverage
442- case _ : Count => CometCount
443- case _ : Min => CometMin
444- case _ : Max => CometMax
445- case _ : First => CometFirst
446- case _ : Last => CometLast
447- case _ : BitAndAgg => CometBitAndAgg
448- case _ : BitOrAgg => CometBitOrAgg
449- case _ : BitXorAgg => CometBitXOrAgg
450- case _ : CovSample => CometCovSample
451- case _ : CovPopulation => CometCovPopulation
452- case _ : VarianceSamp => CometVarianceSamp
453- case _ : VariancePop => CometVariancePop
454- case _ : StddevSamp => CometStddevSamp
455- case _ : StddevPop => CometStddevPop
456- case _ : Corr => CometCorr
457- case _ : BloomFilterAggregate => CometBloomFilterAggregate
458- case fn =>
462+ val fn = aggExpr.aggregateFunction
463+ val cometExpr = aggrSerdeMap.get(fn.getClass)
464+ cometExpr match {
465+ case Some (handler) =>
466+ handler.convert(aggExpr, fn, inputs, binding, conf)
467+ case _ =>
459468 val msg = s " unsupported Spark aggregate function: ${fn.prettyName}"
460469 emitWarning(msg)
461470 withInfo(aggExpr, msg, fn.children: _* )
462- return None
463-
471+ None
464472 }
465- cometExpr.convert(aggExpr, aggExpr.aggregateFunction, inputs, binding, conf)
466473 }
467474
468475 def evalModeToProto (evalMode : CometEvalMode .Value ): ExprOuterClass .EvalMode = {
0 commit comments