@@ -26,6 +26,8 @@ import org.apache.gluten.extension.JoinKeysTag
2626import org .apache .gluten .extension .columnar .FallbackTags
2727import org .apache .gluten .shuffle .NeedCustomColumnarBatchSerializer
2828import org .apache .gluten .sql .shims .SparkShimLoader
29+ import org .apache .gluten .substrait .SubstraitContext
30+ import org .apache .gluten .substrait .expression .{ExpressionBuilder , ExpressionNode , WindowFunctionNode }
2931import org .apache .gluten .vectorized .{ColumnarBatchSerializer , ColumnarBatchSerializeResult }
3032
3133import org .apache .spark .{ShuffleDependency , SparkEnv , SparkException }
@@ -53,7 +55,8 @@ import org.apache.spark.sql.execution.metric.SQLMetric
5355import org .apache .spark .sql .execution .python .ArrowEvalPythonExec
5456import org .apache .spark .sql .execution .unsafe .UnsafeColumnarBuildSideRelation
5557import org .apache .spark .sql .execution .utils .ExecUtil
56- import org .apache .spark .sql .expression .{UDFExpression , UserDefinedAggregateFunction }
58+ import org .apache .spark .sql .expression .{UDFExpression , UDFResolver , UserDefinedAggregateFunction }
59+ import org .apache .spark .sql .hive .HiveUDAFInspector
5760import org .apache .spark .sql .hive .VeloxHiveUDFTransformer
5861import org .apache .spark .sql .internal .SQLConf
5962import org .apache .spark .sql .types ._
@@ -64,6 +67,7 @@ import org.apache.commons.lang3.ClassUtils
6467
6568import javax .ws .rs .core .UriBuilder
6669
70+ import java .util .{ArrayList => JArrayList , List => JList }
6771import java .util .Locale
6872
6973import scala .collection .JavaConverters ._
@@ -1266,4 +1270,143 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
12661270 RaiseErrorRestrictions .ONLY_SUPPORT_ERROR_MESSAGE )
12671271 }
12681272 }
1273+
1274+ override def genWindowFunctionsNode (
1275+ windowExpression : Seq [NamedExpression ],
1276+ windowExpressionNodes : JList [WindowFunctionNode ],
1277+ originalInputAttributes : Seq [Attribute ],
1278+ context : SubstraitContext ): Unit = {
1279+ windowExpression.foreach {
1280+ windowExpr =>
1281+ val aliasExpr = windowExpr.asInstanceOf [Alias ]
1282+ val columnName = s " ${aliasExpr.name}_ ${aliasExpr.exprId.id}"
1283+ val wExpression = aliasExpr.child.asInstanceOf [WindowExpression ]
1284+ wExpression.windowFunction match {
1285+ case wf @ (RowNumber () | Rank (_) | DenseRank (_) | CumeDist () | PercentRank (_)) =>
1286+ val aggWindowFunc = wf.asInstanceOf [AggregateWindowFunction ]
1287+ val frame = aggWindowFunc.frame.asInstanceOf [SpecifiedWindowFrame ]
1288+ val windowFunctionNode = ExpressionBuilder .makeWindowFunction(
1289+ WindowFunctionsBuilder .create(context, aggWindowFunc).toInt,
1290+ new JArrayList [ExpressionNode ](),
1291+ columnName,
1292+ ConverterUtils .getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
1293+ frame.upper,
1294+ frame.lower,
1295+ frame.frameType.sql,
1296+ originalInputAttributes.asJava
1297+ )
1298+ windowExpressionNodes.add(windowFunctionNode)
1299+ case aggExpression : AggregateExpression =>
1300+ val frame = wExpression.windowSpec.frameSpecification.asInstanceOf [SpecifiedWindowFrame ]
1301+ val originalAggFunc = aggExpression.aggregateFunction
1302+ val aggregateFunc =
1303+ try {
1304+ AggregateFunctionsBuilder .getSubstraitFunctionName(originalAggFunc)
1305+ originalAggFunc
1306+ } catch {
1307+ case e : GlutenNotSupportException =>
1308+ HiveUDAFInspector .getUDAFClassName(originalAggFunc) match {
1309+ case Some (udafClass) if UDFResolver .UDAFNames .contains(udafClass) =>
1310+ UDFResolver .getUdafExpression(udafClass)(originalAggFunc.children)
1311+ case _ => throw e
1312+ }
1313+ }
1314+
1315+ val childrenNodeList = aggregateFunc.children
1316+ .map(
1317+ ExpressionConverter
1318+ .replaceWithExpressionTransformer(_, originalInputAttributes)
1319+ .doTransform(context))
1320+ .asJava
1321+
1322+ val functionId = VeloxAggregateFunctionsBuilder
1323+ .create(context, aggregateFunc, aggExpression.mode)
1324+ .toInt
1325+ val windowFunctionNode = ExpressionBuilder .makeWindowFunction(
1326+ functionId,
1327+ childrenNodeList,
1328+ columnName,
1329+ ConverterUtils .getTypeNode(aggExpression.dataType, aggExpression.nullable),
1330+ frame.upper,
1331+ frame.lower,
1332+ frame.frameType.sql,
1333+ originalInputAttributes.asJava
1334+ )
1335+ windowExpressionNodes.add(windowFunctionNode)
1336+ case wf @ (_ : Lead | _ : Lag ) =>
1337+ val offsetWf = wf.asInstanceOf [FrameLessOffsetWindowFunction ]
1338+ val frame = offsetWf.frame.asInstanceOf [SpecifiedWindowFrame ]
1339+ val childrenNodeList = new JArrayList [ExpressionNode ]()
1340+ childrenNodeList.add(
1341+ ExpressionConverter
1342+ .replaceWithExpressionTransformer(
1343+ offsetWf.input,
1344+ attributeSeq = originalInputAttributes)
1345+ .doTransform(context))
1346+ val offset = offsetWf.offset.eval(EmptyRow ).asInstanceOf [Int ]
1347+ val offsetNode = ExpressionBuilder .makeLiteral(Math .abs(offset.toLong), LongType , false )
1348+ childrenNodeList.add(offsetNode)
1349+ if (offsetWf.default.dataType != NullType ) {
1350+ childrenNodeList.add(
1351+ ExpressionConverter
1352+ .replaceWithExpressionTransformer(
1353+ offsetWf.default,
1354+ attributeSeq = originalInputAttributes)
1355+ .doTransform(context))
1356+ }
1357+ val windowFunctionNode = ExpressionBuilder .makeWindowFunction(
1358+ WindowFunctionsBuilder .create(context, offsetWf).toInt,
1359+ childrenNodeList,
1360+ columnName,
1361+ ConverterUtils .getTypeNode(offsetWf.dataType, offsetWf.nullable),
1362+ frame.upper,
1363+ frame.lower,
1364+ frame.frameType.sql,
1365+ offsetWf.ignoreNulls,
1366+ originalInputAttributes.asJava
1367+ )
1368+ windowExpressionNodes.add(windowFunctionNode)
1369+ case wf @ NthValue (input, offset : Literal , ignoreNulls : Boolean ) =>
1370+ val frame = wExpression.windowSpec.frameSpecification.asInstanceOf [SpecifiedWindowFrame ]
1371+ val childrenNodeList = new JArrayList [ExpressionNode ]()
1372+ childrenNodeList.add(
1373+ ExpressionConverter
1374+ .replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes)
1375+ .doTransform(context))
1376+ childrenNodeList.add(LiteralTransformer (offset).doTransform(context))
1377+ val windowFunctionNode = ExpressionBuilder .makeWindowFunction(
1378+ WindowFunctionsBuilder .create(context, wf).toInt,
1379+ childrenNodeList,
1380+ columnName,
1381+ ConverterUtils .getTypeNode(wf.dataType, wf.nullable),
1382+ frame.upper,
1383+ frame.lower,
1384+ frame.frameType.sql,
1385+ ignoreNulls,
1386+ originalInputAttributes.asJava
1387+ )
1388+ windowExpressionNodes.add(windowFunctionNode)
1389+ case wf @ NTile (buckets : Expression ) =>
1390+ val frame = wExpression.windowSpec.frameSpecification.asInstanceOf [SpecifiedWindowFrame ]
1391+ val childrenNodeList = new JArrayList [ExpressionNode ]()
1392+ val literal = buckets.asInstanceOf [Literal ]
1393+ childrenNodeList.add(LiteralTransformer (literal).doTransform(context))
1394+ val windowFunctionNode = ExpressionBuilder .makeWindowFunction(
1395+ WindowFunctionsBuilder .create(context, wf).toInt,
1396+ childrenNodeList,
1397+ columnName,
1398+ ConverterUtils .getTypeNode(wf.dataType, wf.nullable),
1399+ frame.upper,
1400+ frame.lower,
1401+ frame.frameType.sql,
1402+ originalInputAttributes.asJava
1403+ )
1404+ windowExpressionNodes.add(windowFunctionNode)
1405+ case _ =>
1406+ throw new GlutenNotSupportException (
1407+ " unsupported window function type: " +
1408+ wExpression.windowFunction)
1409+ }
1410+ }
1411+ }
12691412}
0 commit comments