Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.gluten.extension.JoinKeysTag
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.shuffle.NeedCustomColumnarBatchSerializer
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}

import org.apache.spark.{ShuffleDependency, SparkEnv, SparkException}
Expand Down Expand Up @@ -53,7 +55,8 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
import org.apache.spark.sql.expression.{UDFExpression, UDFResolver, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.HiveUDAFInspector
import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand All @@ -64,6 +67,7 @@ import org.apache.commons.lang3.ClassUtils

import javax.ws.rs.core.UriBuilder

import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Locale

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -1266,4 +1270,143 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
RaiseErrorRestrictions.ONLY_SUPPORT_ERROR_MESSAGE)
}
}

override def genWindowFunctionsNode(
windowExpression: Seq[NamedExpression],
windowExpressionNodes: JList[WindowFunctionNode],
originalInputAttributes: Seq[Attribute],
context: SubstraitContext): Unit = {
windowExpression.foreach {
windowExpr =>
val aliasExpr = windowExpr.asInstanceOf[Alias]
val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}"
val wExpression = aliasExpr.child.asInstanceOf[WindowExpression]
wExpression.windowFunction match {
case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() | PercentRank(_)) =>
val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(context, aggWindowFunc).toInt,
new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
frame.upper,
frame.lower,
frame.frameType.sql,
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case aggExpression: AggregateExpression =>
val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
val originalAggFunc = aggExpression.aggregateFunction
val aggregateFunc =
try {
AggregateFunctionsBuilder.getSubstraitFunctionName(originalAggFunc)
originalAggFunc
} catch {
case e: GlutenNotSupportException =>
HiveUDAFInspector.getUDAFClassName(originalAggFunc) match {
case Some(udafClass) if UDFResolver.UDAFNames.contains(udafClass) =>
UDFResolver.getUdafExpression(udafClass)(originalAggFunc.children)
case _ => throw e
}
}

val childrenNodeList = aggregateFunc.children
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(context))
.asJava

val functionId = VeloxAggregateFunctionsBuilder
.create(context, aggregateFunc, aggExpression.mode)
.toInt
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
functionId,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(aggExpression.dataType, aggExpression.nullable),
frame.upper,
frame.lower,
frame.frameType.sql,
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ (_: Lead | _: Lag) =>
val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame]
val childrenNodeList = new JArrayList[ExpressionNode]()
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
offsetWf.input,
attributeSeq = originalInputAttributes)
.doTransform(context))
val offset = offsetWf.offset.eval(EmptyRow).asInstanceOf[Int]
val offsetNode = ExpressionBuilder.makeLiteral(Math.abs(offset.toLong), LongType, false)
childrenNodeList.add(offsetNode)
if (offsetWf.default.dataType != NullType) {
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
offsetWf.default,
attributeSeq = originalInputAttributes)
.doTransform(context))
}
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(context, offsetWf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
frame.upper,
frame.lower,
frame.frameType.sql,
offsetWf.ignoreNulls,
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ NthValue(input, offset: Literal, ignoreNulls: Boolean) =>
val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
val childrenNodeList = new JArrayList[ExpressionNode]()
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes)
.doTransform(context))
childrenNodeList.add(LiteralTransformer(offset).doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(context, wf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
frame.upper,
frame.lower,
frame.frameType.sql,
ignoreNulls,
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case wf @ NTile(buckets: Expression) =>
val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
val childrenNodeList = new JArrayList[ExpressionNode]()
val literal = buckets.asInstanceOf[Literal]
childrenNodeList.add(LiteralTransformer(literal).doTransform(context))
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(context, wf).toInt,
childrenNodeList,
columnName,
ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
frame.upper,
frame.lower,
frame.frameType.sql,
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
case _ =>
throw new GlutenNotSupportException(
"unsupported window function type: " +
wExpression.windowFunction)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -344,15 +344,15 @@ object UDFResolver extends Logging {

val allowTypeConversion = checkAllowTypeConversion
val signatures =
UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage))
signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
case Some(sig) =>
UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage)).toSeq
tryBind(signatures, children.map(_.dataType), allowTypeConversion) match {
case Some((sig, withTypeConversion)) =>
UDFExpression(
name,
alias,
sig.expressionType.dataType,
sig.expressionType.nullable,
if (!allowTypeConversion && !sig.allowTypeConversion) children
if (!withTypeConversion) children
else applyCast(children, sig)
)
case None =>
Expand All @@ -366,17 +366,15 @@ object UDFResolver extends Logging {

val allowTypeConversion = checkAllowTypeConversion
val signatures =
UDAFMap.getOrElse(
name,
throw new GlutenNotSupportException(errorMessage)
)
signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
case Some(sig) =>
UDAFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage)).toSeq

tryBind(signatures, children.map(_.dataType), allowTypeConversion) match {
case Some((sig, withTypeConversion)) =>
UserDefinedAggregateFunction(
name,
sig.expressionType.dataType,
sig.expressionType.nullable,
if (!allowTypeConversion && !sig.allowTypeConversion) children
if (!withTypeConversion) children
else applyCast(children, sig),
sig.intermediateAttrs
)
Expand All @@ -385,16 +383,23 @@ object UDFResolver extends Logging {
}
}

private def tryBind(
sig: UDFSignatureBase,
private def tryBind[U <: UDFSignatureBase](
signatures: Seq[U],
requiredDataTypes: Seq[DataType],
allowTypeConversion: Boolean): Boolean = {
if (
!tryBindStrict(sig, requiredDataTypes) && (allowTypeConversion || sig.allowTypeConversion)
) {
tryBindWithTypeConversion(sig, requiredDataTypes)
} else {
true
allowTypeConversion: Boolean): Option[(U, Boolean)] = {
signatures.find(sig => tryBindStrict(sig, requiredDataTypes)) match {
case Some(sig) => Some((sig, false))
case None =>
val allowTypeConversionSignatures = if (allowTypeConversion) {
signatures
} else {
signatures.filter(_.allowTypeConversion)
}
allowTypeConversionSignatures.find(
sig => tryBindWithTypeConversion(sig, requiredDataTypes)) match {
case Some(sig) => Some((sig, true))
case None => None
}
}
}

Expand Down
Loading
Loading