Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,12 @@ object CometArrayRemove
with CometExprShim
with ArraysBase {

/**
 * Reports the support level for an `ArrayRemove` expression. The decision is delegated
 * entirely to `childTypesSupportLevel`, i.e. it depends only on the data types of the
 * expression's children.
 */
override def getSupportLevel(expr: ArrayRemove): SupportLevel = {
  childTypesSupportLevel(expr)
}

override def convert(
expr: ArrayRemove,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet
for (dt <- inputTypes) {
if (!isTypeSupported(dt)) {
withFallbackReason(expr, s"data type not supported: $dt")
return None
}
}
val arrayExprProto = exprToProto(expr.left, inputs, binding)
val keyExprProto = exprToProto(expr.right, inputs, binding)

Expand Down Expand Up @@ -156,7 +151,11 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
SupportLevel
.strictFloatingPointReason(elementType, "Sorting on floating-point")
.map(reason => Incompatible(Some(reason)))
.getOrElse(Compatible())
.getOrElse(expr.ascendingOrder match {
case Literal(_: Boolean, BooleanType) => Compatible()
case other =>
Unsupported(Some(s"ascendingOrder must be a boolean literal: $other"))
})
}
}

Expand All @@ -172,8 +171,8 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
(
exprToProtoInternal(Literal(direction), inputs, binding),
exprToProtoInternal(Literal(nullOrdering), inputs, binding))
case other =>
withFallbackReason(expr, s"ascendingOrder must be a boolean literal: $other")
case _ =>
// Unreachable: getSupportLevel gates a non-boolean-literal ascendingOrder.
(None, None)
}

Expand Down Expand Up @@ -583,6 +582,14 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {
override def getUnsupportedReasons(): Seq[String] = Seq(
"Input must be an array. `Map` inputs are not supported.")

override def getSupportLevel(expr: ElementAt): SupportLevel = {
if (expr.left.dataType.isInstanceOf[ArrayType]) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Apache Spark, `ElementAt` also supports `Map` input, whereas Comet gates on `Array` only; we should support `Map` as well.

  override def inputTypes: Seq[AbstractDataType] = {
    (left.dataType, right.dataType) match {
      case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) =>
        Seq(arr, IntegerType)
      case (MapType(keyType, valueType, hasNull), e2) =>
        TypeCoercion.findTightestCommonType(keyType, e2) match {
          case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt)
          case _ => Seq.empty
        }
      case (l, r) => Seq.empty

    }
  }

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compatible()
} else {
Unsupported(Some("Input is not an array"))
}
}

override def convert(
expr: ElementAt,
inputs: Seq[Attribute],
Expand All @@ -591,11 +598,6 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {
val ordinalExpr = exprToProtoInternal(expr.right, inputs, binding)
val defaultExpr = expr.defaultValueOutOfBound.flatMap(exprToProtoInternal(_, inputs, binding))

if (!expr.left.dataType.isInstanceOf[ArrayType]) {
withFallbackReason(expr, "Input is not an array")
return None
}

if (childExpr.isDefined && ordinalExpr.isDefined &&
defaultExpr.isDefined == expr.defaultValueOutOfBound.isDefined) {
val arrayExtractBuilder = ExprOuterClass.ListExtract
Expand All @@ -621,17 +623,12 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {

object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase {

/**
 * Reports the support level for a `Flatten` expression. The decision is delegated entirely
 * to `childTypesSupportLevel`, i.e. it depends only on the data types of the expression's
 * children.
 */
override def getSupportLevel(expr: Flatten): SupportLevel = {
  childTypesSupportLevel(expr)
}

override def convert(
expr: Flatten,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val inputTypes = expr.children.map(_.dataType).toSet
for (dt <- inputTypes) {
if (!isTypeSupported(dt)) {
withFallbackReason(expr, s"data type not supported: $dt")
return None
}
}
val flattenExprProto = exprToProto(expr.child, inputs, binding)
val flattenScalarExpr = scalarFunctionExprToProto("flatten", flattenExprProto)
optExprWithFallbackReason(flattenScalarExpr, expr, expr.children: _*)
Expand Down Expand Up @@ -719,25 +716,19 @@ object CometSize extends CometExpressionSerde[Size] {

object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase {

override def getSupportLevel(expr: ArrayPosition): SupportLevel = Compatible()
/**
 * Reports the support level for an `ArrayPosition` expression.
 *
 * If every child is foldable the expression is reported as `Unsupported`; otherwise the
 * result comes from `childTypesSupportLevel`, which checks the children's data types.
 */
override def getSupportLevel(expr: ArrayPosition): SupportLevel = {
  val hasNonFoldableChild = expr.children.exists(child => !child.foldable)
  if (hasNonFoldableChild) {
    childTypesSupportLevel(expr)
  } else {
    // All-literal arguments: leave the expression to Spark so ConstantFolding can handle it.
    Unsupported(Some("all arguments are literals, falling back to Spark"))
  }
}

override def convert(
expr: ArrayPosition,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
if (expr.children.forall(_.foldable)) {
withFallbackReason(expr, "all arguments are literals, falling back to Spark")
return None
}
// Check if input types are supported
val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet
for (dt <- inputTypes) {
if (!isTypeSupported(dt)) {
withFallbackReason(expr, s"data type not supported: $dt")
return None
}
}

val arrayExprProto = exprToProto(expr.left, inputs, binding)
val elementExprProto = exprToProto(expr.right, inputs, binding)

Expand Down Expand Up @@ -837,6 +828,17 @@ trait ArraysBase {
case _ => false
}
}

/**
 * Derives a support level from the data types of `expr`'s children. The first child type (in
 * child order) that `isTypeSupported` rejects produces an `Unsupported` result naming that
 * type; if every child type is accepted the result is `Compatible`.
 */
def childTypesSupportLevel(expr: Expression): SupportLevel = {
  val firstUnsupported = expr.children.map(_.dataType).find(dt => !isTypeSupported(dt))
  firstUnsupported match {
    case Some(dt) => Unsupported(Some(s"data type not supported: $dt"))
    case None => Compatible()
  }
}
}

object CometArrayTransform extends CometCodegenDispatch[ArrayTransform]
Expand Down
Loading