From c4cb2d4a75fb070f77bf5cb1bb1b8ef26b4c513f Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Tue, 19 May 2026 22:22:44 +0100 Subject: [PATCH] support udaf in window --- .../velox/VeloxSparkPlanExecApi.scala | 145 ++++++++++++++++- .../spark/sql/expression/UDFResolver.scala | 45 +++--- .../gluten/expression/VeloxUdfSuite.scala | 146 +++++++++++++----- cpp/velox/substrait/SubstraitToVeloxPlan.cc | 3 +- cpp/velox/udf/examples/UdfCommon.h | 4 +- docs/developers/VeloxUDF.md | 1 + .../gluten/backendsapi/SparkPlanExecApi.scala | 137 +--------------- 7 files changed, 282 insertions(+), 199 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index b7a1e172b2c..f0c77e363f3 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -26,6 +26,8 @@ import org.apache.gluten.extension.JoinKeysTag import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer import org.apache.gluten.sql.shims.SparkShimLoader +import org.apache.gluten.substrait.SubstraitContext +import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode} import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult} import org.apache.spark.{ShuffleDependency, SparkEnv, SparkException} @@ -53,7 +55,8 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowEvalPythonExec import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation import org.apache.spark.sql.execution.utils.ExecUtil -import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction} +import 
org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction} +import org.apache.spark.sql.hive.HiveUDAFInspector import org.apache.spark.sql.hive.VeloxHiveUDFTransformer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -64,6 +67,7 @@ import org.apache.commons.lang3.ClassUtils import javax.ws.rs.core.UriBuilder +import java.util.{ArrayList => JArrayList, List => JList} import java.util.Locale import scala.collection.JavaConverters._ @@ -1266,4 +1270,143 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging { RaiseErrorRestrictions.ONLY_SUPPORT_ERROR_MESSAGE) } } + + override def genWindowFunctionsNode( + windowExpression: Seq[NamedExpression], + windowExpressionNodes: JList[WindowFunctionNode], + originalInputAttributes: Seq[Attribute], + context: SubstraitContext): Unit = { + windowExpression.foreach { + windowExpr => + val aliasExpr = windowExpr.asInstanceOf[Alias] + val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}" + val wExpression = aliasExpr.child.asInstanceOf[WindowExpression] + wExpression.windowFunction match { + case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() | PercentRank(_)) => + val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction] + val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame] + val windowFunctionNode = ExpressionBuilder.makeWindowFunction( + WindowFunctionsBuilder.create(context, aggWindowFunc).toInt, + new JArrayList[ExpressionNode](), + columnName, + ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable), + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava + ) + windowExpressionNodes.add(windowFunctionNode) + case aggExpression: AggregateExpression => + val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + val originalAggFunc = aggExpression.aggregateFunction + val aggregateFunc = + try { + 
AggregateFunctionsBuilder.getSubstraitFunctionName(originalAggFunc) + originalAggFunc + } catch { + case e: GlutenNotSupportException => + HiveUDAFInspector.getUDAFClassName(originalAggFunc) match { + case Some(udafClass) if UDFResolver.UDAFNames.contains(udafClass) => + UDFResolver.getUdafExpression(udafClass)(originalAggFunc.children) + case _ => throw e + } + } + + val childrenNodeList = aggregateFunc.children + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, originalInputAttributes) + .doTransform(context)) + .asJava + + val functionId = VeloxAggregateFunctionsBuilder + .create(context, aggregateFunc, aggExpression.mode) + .toInt + val windowFunctionNode = ExpressionBuilder.makeWindowFunction( + functionId, + childrenNodeList, + columnName, + ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable), + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava + ) + windowExpressionNodes.add(windowFunctionNode) + case wf @ (_: Lead | _: Lag) => + val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction] + val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame] + val childrenNodeList = new JArrayList[ExpressionNode]() + childrenNodeList.add( + ExpressionConverter + .replaceWithExpressionTransformer( + offsetWf.input, + attributeSeq = originalInputAttributes) + .doTransform(context)) + val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int] + val offsetNode = ExpressionBuilder.makeLiteral(Math.abs(offset.toLong), LongType, false) + childrenNodeList.add(offsetNode) + if (offsetWf.default.dataType != NullType) { + childrenNodeList.add( + ExpressionConverter + .replaceWithExpressionTransformer( + offsetWf.default, + attributeSeq = originalInputAttributes) + .doTransform(context)) + } + val windowFunctionNode = ExpressionBuilder.makeWindowFunction( + WindowFunctionsBuilder.create(context, offsetWf).toInt, + childrenNodeList, + columnName, + ConverterUtils.getTypeNode(offsetWf.dataType, 
offsetWf.nullable), + frame.upper, + frame.lower, + frame.frameType.sql, + offsetWf.ignoreNulls, + originalInputAttributes.asJava + ) + windowExpressionNodes.add(windowFunctionNode) + case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) => + val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + val childrenNodeList = new JArrayList[ExpressionNode]() + childrenNodeList.add( + ExpressionConverter + .replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes) + .doTransform(context)) + childrenNodeList.add(LiteralTransformer(offset).doTransform(context)) + val windowFunctionNode = ExpressionBuilder.makeWindowFunction( + WindowFunctionsBuilder.create(context, wf).toInt, + childrenNodeList, + columnName, + ConverterUtils.getTypeNode(wf.dataType, wf.nullable), + frame.upper, + frame.lower, + frame.frameType.sql, + ignoreNulls, + originalInputAttributes.asJava + ) + windowExpressionNodes.add(windowFunctionNode) + case wf @ NTile(buckets: Expression) => + val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + val childrenNodeList = new JArrayList[ExpressionNode]() + val literal = buckets.asInstanceOf[Literal] + childrenNodeList.add(LiteralTransformer(literal).doTransform(context)) + val windowFunctionNode = ExpressionBuilder.makeWindowFunction( + WindowFunctionsBuilder.create(context, wf).toInt, + childrenNodeList, + columnName, + ConverterUtils.getTypeNode(wf.dataType, wf.nullable), + frame.upper, + frame.lower, + frame.frameType.sql, + originalInputAttributes.asJava + ) + windowExpressionNodes.add(windowFunctionNode) + case _ => + throw new GlutenNotSupportException( + "unsupported window function type: " + + wExpression.windowFunction) + } + } + } } diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala index d2405f9e93e..43ee7b4f7d9 100644 --- 
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala @@ -344,15 +344,15 @@ object UDFResolver extends Logging { val allowTypeConversion = checkAllowTypeConversion val signatures = - UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage)) - signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match { - case Some(sig) => + UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage)).toSeq + tryBind(signatures, children.map(_.dataType), allowTypeConversion) match { + case Some((sig, withTypeConversion)) => UDFExpression( name, alias, sig.expressionType.dataType, sig.expressionType.nullable, - if (!allowTypeConversion && !sig.allowTypeConversion) children + if (!withTypeConversion) children else applyCast(children, sig) ) case None => @@ -366,17 +366,15 @@ object UDFResolver extends Logging { val allowTypeConversion = checkAllowTypeConversion val signatures = - UDAFMap.getOrElse( - name, - throw new GlutenNotSupportException(errorMessage) - ) - signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match { - case Some(sig) => + UDAFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage)).toSeq + + tryBind(signatures, children.map(_.dataType), allowTypeConversion) match { + case Some((sig, withTypeConversion)) => UserDefinedAggregateFunction( name, sig.expressionType.dataType, sig.expressionType.nullable, - if (!allowTypeConversion && !sig.allowTypeConversion) children + if (!withTypeConversion) children else applyCast(children, sig), sig.intermediateAttrs ) @@ -385,16 +383,23 @@ object UDFResolver extends Logging { } } - private def tryBind( - sig: UDFSignatureBase, + private def tryBind[U <: UDFSignatureBase]( + signatures: Seq[U], requiredDataTypes: Seq[DataType], - allowTypeConversion: Boolean): Boolean = { - if ( - !tryBindStrict(sig, 
requiredDataTypes) && (allowTypeConversion || sig.allowTypeConversion) - ) { - tryBindWithTypeConversion(sig, requiredDataTypes) - } else { - true + allowTypeConversion: Boolean): Option[(U, Boolean)] = { + signatures.find(sig => tryBindStrict(sig, requiredDataTypes)) match { + case Some(sig) => Some((sig, false)) + case None => + val allowTypeConversionSignatures = if (allowTypeConversion) { + signatures + } else { + signatures.filter(_.allowTypeConversion) + } + allowTypeConversionSignatures.find( + sig => tryBindWithTypeConversion(sig, requiredDataTypes)) match { + case Some(sig) => Some((sig, true)) + case None => None + } } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala index 7eb61144a96..a5128f62d99 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala @@ -18,12 +18,14 @@ package org.apache.gluten.expression import org.apache.gluten.backendsapi.velox.VeloxBackendSettings import org.apache.gluten.execution.ProjectExecTransformer +import org.apache.gluten.execution.WindowExecTransformer import org.apache.gluten.tags.{SkipTest, UDFTest} import org.apache.spark.SparkConf import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.expression.UDFResolver import java.nio.file.Paths @@ -92,16 +94,72 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper { .set("spark.memory.offHeap.enabled", "true") .set("spark.memory.offHeap.size", "1024MB") .set("spark.ui.enabled", "false") + .set("spark.sql.adaptive.enabled", "false") } - // Aggregate result can be flaky. 
- ignore("test native hive udaf") { + test("test native hive udf") { + val tbl = "test_hive_udf_replacement" + withTempPath { + dir => + try { + spark.sql(s""" + |CREATE EXTERNAL TABLE $tbl + |LOCATION 'file://$dir' + |AS select * from values (1, '1'), (2, '2'), (3, '3') + |""".stripMargin) + + // Check native hive udf has been registered. + assert( + UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString")) + + spark.sql(""" + |CREATE TEMPORARY FUNCTION hive_string_string + |AS 'org.apache.spark.sql.hive.execution.UDFStringString' + |""".stripMargin) + + val offloadWithImplicitConversionDF = + spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""") + checkGlutenPlan[ProjectExecTransformer](offloadWithImplicitConversionDF) + val offloadWithImplicitConversionResult = offloadWithImplicitConversionDF.collect() + + val offloadDF = + spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""") + checkGlutenPlan[ProjectExecTransformer](offloadDF) + val offloadResult = offloadDF.collect() + + // Unregister native hive udf to fallback. + UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString") + val fallbackDF = + spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""") + checkSparkPlan[ProjectExec](fallbackDF) + val fallbackResult = fallbackDF.collect() + assert(offloadWithImplicitConversionResult.sameElements(fallbackResult)) + assert(offloadResult.sameElements(fallbackResult)) + + // Add an unimplemented udf to the map to test fallback of registered native hive udf. 
+ UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString") + spark.sql(""" + |CREATE TEMPORARY FUNCTION hive_int_to_string + |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString' + |""".stripMargin) + val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""") + checkSparkPlan[ProjectExec](df) + checkAnswer(df, Seq(Row("1"), Row("2"), Row("3"))) + } finally { + spark.sql(s"DROP TABLE IF EXISTS $tbl") + spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string") + spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string") + } + } + } + + test("test native hive udaf") { val tbl = "test_hive_udaf_replacement" + val udafClass = "test.org.apache.spark.sql.MyDoubleAvg" withTempPath { dir => try { // Check native hive udaf has been registered. - val udafClass = "test.org.apache.spark.sql.MyDoubleAvg" assert(UDFResolver.UDAFNames.contains(udafClass)) spark.sql(s""" @@ -136,64 +194,68 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper { assert(nativeResult.sameElements(fallbackResult)) assert(nativeImplicitConversionResult.sameElements(fallbackResult)) } finally { + UDFResolver.UDAFNames.add(udafClass) spark.sql(s"DROP TABLE IF EXISTS $tbl") spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg") } } } - test("test native hive udf") { - val tbl = "test_hive_udf_replacement" + test("test native hive udaf in window") { + val tbl = "test_hive_udaf_window" + val udafClass = "test.org.apache.spark.sql.MyDoubleAvg" + val query = + s"""SELECT + | col1, + | my_double_avg(col1) OVER ( + | PARTITION BY col1 % 2 + | ORDER BY col1 + | ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS my_avg_window + |FROM $tbl + |ORDER BY col1 + |""".stripMargin + withTempPath { dir => try { + assert(UDFResolver.UDAFNames.contains(udafClass)) + + spark.sql(s""" + |CREATE TEMPORARY FUNCTION my_double_avg + |AS '$udafClass' + |""".stripMargin) + spark.sql(s""" + |DROP TABLE IF EXISTS $tbl; + |""".stripMargin) 
spark.sql(s""" |CREATE EXTERNAL TABLE $tbl |LOCATION 'file://$dir' - |AS select * from values (1, '1'), (2, '2'), (3, '3') + |AS SELECT CAST(v AS FLOAT) AS col1 + |FROM VALUES (1.0), (2.0), (3.0), (4.0) AS t(v) |""".stripMargin) - // Check native hive udf has been registered. - assert( - UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString")) - - spark.sql(""" - |CREATE TEMPORARY FUNCTION hive_string_string - |AS 'org.apache.spark.sql.hive.execution.UDFStringString' - |""".stripMargin) - - val offloadWithImplicitConversionDF = - spark.sql(s"""SELECT hive_string_string(col1, 'a') FROM $tbl""") - checkGlutenPlan[ProjectExecTransformer](offloadWithImplicitConversionDF) - val offloadWithImplicitConversionResult = offloadWithImplicitConversionDF.collect() + val offloadDF = spark.sql(query) + checkGlutenPlan[WindowExecTransformer](offloadDF) + checkAnswer( + offloadDF, + Seq( + Row(1.0f, 101.0), + Row(2.0f, 102.0), + Row(3.0f, 102.0), + Row(4.0f, 103.0) + )) + val offloadResult = offloadDF.collect() - val offloadDF = - spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""") - checkGlutenPlan[ProjectExecTransformer](offloadDF) - val offloadResult = offloadWithImplicitConversionDF.collect() - - // Unregister native hive udf to fallback. - UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString") - val fallbackDF = - spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""") - checkSparkPlan[ProjectExec](fallbackDF) + UDFResolver.UDAFNames.remove(udafClass) + val fallbackDF = spark.sql(query) + checkSparkPlan[WindowExec](fallbackDF) val fallbackResult = fallbackDF.collect() - assert(offloadWithImplicitConversionResult.sameElements(fallbackResult)) - assert(offloadResult.sameElements(fallbackResult)) - // Add an unimplemented udf to the map to test fallback of registered native hive udf. 
- UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString") - spark.sql(""" - |CREATE TEMPORARY FUNCTION hive_int_to_string - |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString' - |""".stripMargin) - val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""") - checkSparkPlan[ProjectExec](df) - checkAnswer(df, Seq(Row("1"), Row("2"), Row("3"))) + assert(offloadResult.sameElements(fallbackResult)) } finally { + UDFResolver.UDAFNames.add(udafClass) spark.sql(s"DROP TABLE IF EXISTS $tbl") - spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string") - spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string") + spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS my_double_avg") } } } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 5477176ce85..b0fc0fc4a30 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -1165,7 +1165,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: windowParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(), inputType)); } auto windowVeloxType = SubstraitParser::parseType(windowFunction.output_type()); - auto windowCall = std::make_shared<const core::CallTypedExpr>(windowVeloxType, std::move(windowParams), funcName); + auto windowCall = std::make_shared<const core::CallTypedExpr>( + windowVeloxType, std::move(windowParams), exec::sanitizeName(funcName)); auto upperBound = windowFunction.upper_bound(); auto lowerBound = windowFunction.lower_bound(); auto type = windowFunction.window_type(); diff --git a/cpp/velox/udf/examples/UdfCommon.h b/cpp/velox/udf/examples/UdfCommon.h index a68c474607c..7e73b916a2a 100644 --- a/cpp/velox/udf/examples/UdfCommon.h +++ b/cpp/velox/udf/examples/UdfCommon.h @@ -24,7 +24,7 @@ namespace gluten { class UdfRegisterer { public: - ~UdfRegisterer() = default; + virtual ~UdfRegisterer() = default; // Returns the number of UDFs in populateUdfEntries. 
virtual int getNumUdf() = 0; @@ -38,7 +38,7 @@ class UdfRegisterer { class UdafRegisterer { public: - ~UdafRegisterer() = default; + virtual ~UdafRegisterer() = default; // Returns the number of UDFs in populateUdafEntries. virtual int getNumUdaf() = 0; diff --git a/docs/developers/VeloxUDF.md b/docs/developers/VeloxUDF.md index b3154c41a78..a38f1a48db0 100644 --- a/docs/developers/VeloxUDF.md +++ b/docs/developers/VeloxUDF.md @@ -14,6 +14,7 @@ Users can implement custom functions using the UDF interface provided by Velox a At runtime, these UDFs are registered alongside their Java implementations via `CREATE TEMPORARY FUNCTION`. Once registered, Gluten can parse and offload these UDFs to Velox during execution, meanwhile ensuring proper fallback to Java UDFs when necessary. +Registered UDAFs can be used both as regular aggregate functions and as aggregate window functions. ## Create and Build UDF/UDAF library diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index ce0a79f0bc2..b8823748443 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -22,7 +22,7 @@ import org.apache.gluten.execution._ import org.apache.gluten.expression._ import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.SubstraitContext -import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode} +import org.apache.gluten.substrait.expression.WindowFunctionNode import org.apache.spark.ShuffleDependency import org.apache.spark.rdd.RDD @@ -44,13 +44,11 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowEvalPythonExec import org.apache.spark.sql.execution.window._ import 
org.apache.spark.sql.hive.HiveUDFTransformer -import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType} +import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import java.io.{ObjectInputStream, ObjectOutputStream} -import java.util.{ArrayList => JArrayList, List => JList} - -import scala.collection.JavaConverters._ +import java.util.{List => JList} trait SparkPlanExecApi { @@ -552,134 +550,7 @@ trait SparkPlanExecApi { windowExpression: Seq[NamedExpression], windowExpressionNodes: JList[WindowFunctionNode], originalInputAttributes: Seq[Attribute], - context: SubstraitContext): Unit = { - windowExpression.map { - windowExpr => - val aliasExpr = windowExpr.asInstanceOf[Alias] - val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}" - val wExpression = aliasExpr.child.asInstanceOf[WindowExpression] - wExpression.windowFunction match { - case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() | PercentRank(_)) => - val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction] - val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame] - val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(context, aggWindowFunc).toInt, - new JArrayList[ExpressionNode](), - columnName, - ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable), - frame.upper, - frame.lower, - frame.frameType.sql, - originalInputAttributes.asJava - ) - windowExpressionNodes.add(windowFunctionNode) - case aggExpression: AggregateExpression => - val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - val aggregateFunc = aggExpression.aggregateFunction - val substraitAggFuncName = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass) - if (substraitAggFuncName.isEmpty) { - throw new GlutenNotSupportException(s"Not currently supported: $aggregateFunc.") - } - - val childrenNodeList = aggregateFunc.children - 
.map( - ExpressionConverter - .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(context)) - .asJava - - val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - AggregateFunctionsBuilder.create(context, aggExpression.aggregateFunction).toInt, - childrenNodeList, - columnName, - ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable), - frame.upper, - frame.lower, - frame.frameType.sql, - originalInputAttributes.asJava - ) - windowExpressionNodes.add(windowFunctionNode) - case wf @ (_: Lead | _: Lag) => - val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction] - val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame] - val childrenNodeList = new JArrayList[ExpressionNode]() - childrenNodeList.add( - ExpressionConverter - .replaceWithExpressionTransformer( - offsetWf.input, - attributeSeq = originalInputAttributes) - .doTransform(context)) - // Spark only accepts foldable offset. Converts it to LongType literal. - val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int] - // Velox only allows negative offset. WindowFunctionsBuilder#create converts - // lag/lead with negative offset to the function with positive offset. So just - // makes offsetNode store positive value. - val offsetNode = ExpressionBuilder.makeLiteral(Math.abs(offset.toLong), LongType, false) - childrenNodeList.add(offsetNode) - // NullType means Null is the default value. Don't pass it to native. 
- if (offsetWf.default.dataType != NullType) { - childrenNodeList.add( - ExpressionConverter - .replaceWithExpressionTransformer( - offsetWf.default, - attributeSeq = originalInputAttributes) - .doTransform(context)) - } - val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(context, offsetWf).toInt, - childrenNodeList, - columnName, - ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable), - frame.upper, - frame.lower, - frame.frameType.sql, - offsetWf.ignoreNulls, - originalInputAttributes.asJava - ) - windowExpressionNodes.add(windowFunctionNode) - case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) => - val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - val childrenNodeList = new JArrayList[ExpressionNode]() - childrenNodeList.add( - ExpressionConverter - .replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes) - .doTransform(context)) - childrenNodeList.add(LiteralTransformer(offset).doTransform(context)) - val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(context, wf).toInt, - childrenNodeList, - columnName, - ConverterUtils.getTypeNode(wf.dataType, wf.nullable), - frame.upper, - frame.lower, - frame.frameType.sql, - ignoreNulls, - originalInputAttributes.asJava - ) - windowExpressionNodes.add(windowFunctionNode) - case wf @ NTile(buckets: Expression) => - val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - val childrenNodeList = new JArrayList[ExpressionNode]() - val literal = buckets.asInstanceOf[Literal] - childrenNodeList.add(LiteralTransformer(literal).doTransform(context)) - val windowFunctionNode = ExpressionBuilder.makeWindowFunction( - WindowFunctionsBuilder.create(context, wf).toInt, - childrenNodeList, - columnName, - ConverterUtils.getTypeNode(wf.dataType, wf.nullable), - frame.upper, - frame.lower, - frame.frameType.sql, - 
originalInputAttributes.asJava - ) - windowExpressionNodes.add(windowFunctionNode) - case _ => - throw new GlutenNotSupportException( - "unsupported window function type: " + - wExpression.windowFunction) - } - } - } + context: SubstraitContext): Unit def rewriteSpillPath(path: String): String = path