Skip to content

Commit c4cb2d4

Browse files
committed
support udaf in window
1 parent 97632d8 commit c4cb2d4

7 files changed

Lines changed: 282 additions & 199 deletions

File tree

backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ import org.apache.gluten.extension.JoinKeysTag
2626
import org.apache.gluten.extension.columnar.FallbackTags
2727
import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer
2828
import org.apache.gluten.sql.shims.SparkShimLoader
29+
import org.apache.gluten.substrait.SubstraitContext
30+
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
2931
import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}
3032

3133
import org.apache.spark.{ShuffleDependency, SparkEnv, SparkException}
@@ -53,7 +55,8 @@ import org.apache.spark.sql.execution.metric.SQLMetric
5355
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
5456
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
5557
import org.apache.spark.sql.execution.utils.ExecUtil
56-
import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
58+
import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction}
59+
import org.apache.spark.sql.hive.HiveUDAFInspector
5760
import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
5861
import org.apache.spark.sql.internal.SQLConf
5962
import org.apache.spark.sql.types._
@@ -64,6 +67,7 @@ import org.apache.commons.lang3.ClassUtils
6467

6568
import javax.ws.rs.core.UriBuilder
6669

70+
import java.util.{ArrayList => JArrayList, List => JList}
6771
import java.util.Locale
6872

6973
import scala.collection.JavaConverters._
@@ -1266,4 +1270,143 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
12661270
RaiseErrorRestrictions.ONLY_SUPPORT_ERROR_MESSAGE)
12671271
}
12681272
}
1273+
1274+
/**
 * Builds one Substrait [[WindowFunctionNode]] per window expression and appends it to
 * `windowExpressionNodes`.
 *
 * Every element of `windowExpression` is cast to an `Alias` whose child is a
 * `WindowExpression`; the output column is named `<alias name>_<exprId id>`.
 *
 * @param windowExpression        aliased window expressions to convert
 * @param windowExpressionNodes   output list; one node is appended per input expression
 * @param originalInputAttributes attributes used to bind the function arguments
 * @param context                 Substrait context handed to the builders and transformers
 * @throws GlutenNotSupportException when the window function type is not handled here, or
 *         when an aggregate has neither a Substrait name nor a UDAF class listed in
 *         `UDFResolver.UDAFNames`
 */
override def genWindowFunctionsNode(
    windowExpression: Seq[NamedExpression],
    windowExpressionNodes: JList[WindowFunctionNode],
    originalInputAttributes: Seq[Attribute],
    context: SubstraitContext): Unit = {

  // Binds `expr` against the window input attributes and emits its Substrait node.
  def toSubstraitNode(expr: Expression): ExpressionNode =
    ExpressionConverter
      .replaceWithExpressionTransformer(expr, originalInputAttributes)
      .doTransform(context)

  for (namedExpr <- windowExpression) {
    val alias = namedExpr.asInstanceOf[Alias]
    val outputName = s"${alias.name}_${alias.exprId.id}"
    val windowExpr = alias.child.asInstanceOf[WindowExpression]

    // Each case yields the finished node; it is appended once after the match.
    // Note that every call passes the frame's upper bound before its lower bound.
    val node = windowExpr.windowFunction match {
      case rankLike @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() | PercentRank(_)) =>
        // These functions take no arguments, so the argument list stays empty.
        val rankFunc = rankLike.asInstanceOf[AggregateWindowFunction]
        val frame = rankFunc.frame.asInstanceOf[SpecifiedWindowFrame]
        ExpressionBuilder.makeWindowFunction(
          WindowFunctionsBuilder.create(context, rankFunc).toInt,
          new JArrayList[ExpressionNode](),
          outputName,
          ConverterUtils.getTypeNode(rankFunc.dataType, rankFunc.nullable),
          frame.upper,
          frame.lower,
          frame.frameType.sql,
          originalInputAttributes.asJava
        )

      case aggExpr: AggregateExpression =>
        val frame = windowExpr.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
        val nativeFunc = aggExpr.aggregateFunction
        val resolvedFunc =
          try {
            // Called only for its failure: it throws when no Substrait name is available.
            AggregateFunctionsBuilder.getSubstraitFunctionName(nativeFunc)
            nativeFunc
          } catch {
            case notSupported: GlutenNotSupportException =>
              // Fall back to a UDAF listed in UDFResolver.UDAFNames, else rethrow.
              HiveUDAFInspector
                .getUDAFClassName(nativeFunc)
                .filter(udafClass => UDFResolver.UDAFNames.contains(udafClass))
                .map(udafClass => UDFResolver.getUdafExpression(udafClass)(nativeFunc.children))
                .getOrElse(throw notSupported)
          }

        val argNodes = resolvedFunc.children.map(toSubstraitNode).asJava
        val functionId =
          VeloxAggregateFunctionsBuilder.create(context, resolvedFunc, aggExpr.mode).toInt
        ExpressionBuilder.makeWindowFunction(
          functionId,
          argNodes,
          outputName,
          ConverterUtils.getTypeNode(aggExpr.dataType, aggExpr.nullable),
          frame.upper,
          frame.lower,
          frame.frameType.sql,
          originalInputAttributes.asJava
        )

      case offsetLike @ (_: Lead | _: Lag) =>
        val offsetFunc = offsetLike.asInstanceOf[FrameLessOffsetWindowFunction]
        val frame = offsetFunc.frame.asInstanceOf[SpecifiedWindowFrame]
        val argNodes = new JArrayList[ExpressionNode]()
        argNodes.add(toSubstraitNode(offsetFunc.input))
        // The offset is evaluated up front and passed as a non-nullable, absolute-valued
        // long literal; widening to Long happens before taking the absolute value.
        val rawOffset = offsetFunc.offset.eval(EmptyRow).asInstanceOf[Int]
        val offsetLiteral = ExpressionBuilder.makeLiteral(Math.abs(rawOffset.toLong), LongType, false)
        argNodes.add(offsetLiteral)
        // The default value is passed only when it is not of NullType.
        if (offsetFunc.default.dataType != NullType) {
          argNodes.add(toSubstraitNode(offsetFunc.default))
        }
        ExpressionBuilder.makeWindowFunction(
          WindowFunctionsBuilder.create(context, offsetFunc).toInt,
          argNodes,
          outputName,
          ConverterUtils.getTypeNode(offsetFunc.dataType, offsetFunc.nullable),
          frame.upper,
          frame.lower,
          frame.frameType.sql,
          offsetFunc.ignoreNulls,
          originalInputAttributes.asJava
        )

      case nth @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
        val frame = windowExpr.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
        val argNodes = new JArrayList[ExpressionNode]()
        argNodes.add(toSubstraitNode(input))
        argNodes.add(LiteralTransformer(offset).doTransform(context))
        ExpressionBuilder.makeWindowFunction(
          WindowFunctionsBuilder.create(context, nth).toInt,
          argNodes,
          outputName,
          ConverterUtils.getTypeNode(nth.dataType, nth.nullable),
          frame.upper,
          frame.lower,
          frame.frameType.sql,
          ignoreNulls,
          originalInputAttributes.asJava
        )

      case ntile @ NTile(buckets: Expression) =>
        val frame = windowExpr.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
        val argNodes = new JArrayList[ExpressionNode]()
        // The bucket count is cast to a Literal; any other expression fails the cast.
        val bucketLiteral = buckets.asInstanceOf[Literal]
        argNodes.add(LiteralTransformer(bucketLiteral).doTransform(context))
        ExpressionBuilder.makeWindowFunction(
          WindowFunctionsBuilder.create(context, ntile).toInt,
          argNodes,
          outputName,
          ConverterUtils.getTypeNode(ntile.dataType, ntile.nullable),
          frame.upper,
          frame.lower,
          frame.frameType.sql,
          originalInputAttributes.asJava
        )

      case _ =>
        throw new GlutenNotSupportException(
          "unsupported window function type: " + windowExpr.windowFunction)
    }

    windowExpressionNodes.add(node)
  }
}
12691412
}

backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -344,15 +344,15 @@ object UDFResolver extends Logging {
344344

345345
val allowTypeConversion = checkAllowTypeConversion
346346
val signatures =
347-
UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage))
348-
signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
349-
case Some(sig) =>
347+
UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage)).toSeq
348+
tryBind(signatures, children.map(_.dataType), allowTypeConversion) match {
349+
case Some((sig, withTypeConversion)) =>
350350
UDFExpression(
351351
name,
352352
alias,
353353
sig.expressionType.dataType,
354354
sig.expressionType.nullable,
355-
if (!allowTypeConversion && !sig.allowTypeConversion) children
355+
if (!withTypeConversion) children
356356
else applyCast(children, sig)
357357
)
358358
case None =>
@@ -366,17 +366,15 @@ object UDFResolver extends Logging {
366366

367367
val allowTypeConversion = checkAllowTypeConversion
368368
val signatures =
369-
UDAFMap.getOrElse(
370-
name,
371-
throw new GlutenNotSupportException(errorMessage)
372-
)
373-
signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
374-
case Some(sig) =>
369+
UDAFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage)).toSeq
370+
371+
tryBind(signatures, children.map(_.dataType), allowTypeConversion) match {
372+
case Some((sig, withTypeConversion)) =>
375373
UserDefinedAggregateFunction(
376374
name,
377375
sig.expressionType.dataType,
378376
sig.expressionType.nullable,
379-
if (!allowTypeConversion && !sig.allowTypeConversion) children
377+
if (!withTypeConversion) children
380378
else applyCast(children, sig),
381379
sig.intermediateAttrs
382380
)
@@ -385,16 +383,23 @@ object UDFResolver extends Logging {
385383
}
386384
}
387385

388-
/**
 * Selects the first signature that can bind to `requiredDataTypes`.
 *
 * Strict binding is tried across all signatures first. Only when none binds strictly are
 * signatures retried with type conversion: all of them if `allowTypeConversion` is set,
 * otherwise only those whose own `allowTypeConversion` flag is set.
 *
 * @param signatures          candidate signatures, tried in order
 * @param requiredDataTypes   data types of the actual arguments
 * @param allowTypeConversion whether every signature may be bound with type conversion
 * @return the bound signature paired with `true` if it was bound with type conversion and
 *         `false` if it bound strictly, or `None` when nothing binds
 */
private def tryBind[U <: UDFSignatureBase](
    signatures: Seq[U],
    requiredDataTypes: Seq[DataType],
    allowTypeConversion: Boolean): Option[(U, Boolean)] = {
  val strictMatch =
    signatures.find(tryBindStrict(_, requiredDataTypes)).map(sig => (sig, false))

  // `orElse` takes its argument by name, so the fallback runs only without a strict match.
  strictMatch.orElse {
    val convertible =
      if (allowTypeConversion) signatures else signatures.filter(_.allowTypeConversion)
    convertible
      .find(tryBindWithTypeConversion(_, requiredDataTypes))
      .map(sig => (sig, true))
  }
}
400405

0 commit comments

Comments
 (0)