Skip to content

Commit 6010c5b

Browse files
authored
refactor: move array expression support checks to getSupportLevel (#4677)
1 parent 2c5b842 commit 6010c5b

1 file changed

Lines changed: 38 additions & 36 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,12 @@ object CometArrayRemove
3737
with CometExprShim
3838
with ArraysBase {
3939

40+
override def getSupportLevel(expr: ArrayRemove): SupportLevel = childTypesSupportLevel(expr)
41+
4042
override def convert(
4143
expr: ArrayRemove,
4244
inputs: Seq[Attribute],
4345
binding: Boolean): Option[ExprOuterClass.Expr] = {
44-
val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet
45-
for (dt <- inputTypes) {
46-
if (!isTypeSupported(dt)) {
47-
withFallbackReason(expr, s"data type not supported: $dt")
48-
return None
49-
}
50-
}
5146
val arrayExprProto = exprToProto(expr.left, inputs, binding)
5247
val keyExprProto = exprToProto(expr.right, inputs, binding)
5348

@@ -156,7 +151,11 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
156151
SupportLevel
157152
.strictFloatingPointReason(elementType, "Sorting on floating-point")
158153
.map(reason => Incompatible(Some(reason)))
159-
.getOrElse(Compatible())
154+
.getOrElse(expr.ascendingOrder match {
155+
case Literal(_: Boolean, BooleanType) => Compatible()
156+
case other =>
157+
Unsupported(Some(s"ascendingOrder must be a boolean literal: $other"))
158+
})
160159
}
161160
}
162161

@@ -172,8 +171,8 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
172171
(
173172
exprToProtoInternal(Literal(direction), inputs, binding),
174173
exprToProtoInternal(Literal(nullOrdering), inputs, binding))
175-
case other =>
176-
withFallbackReason(expr, s"ascendingOrder must be a boolean literal: $other")
174+
case _ =>
175+
// Unreachable: getSupportLevel gates a non-boolean-literal ascendingOrder.
177176
(None, None)
178177
}
179178

@@ -583,6 +582,14 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {
583582
override def getUnsupportedReasons(): Seq[String] = Seq(
584583
"Input must be an array. `Map` inputs are not supported.")
585584

585+
override def getSupportLevel(expr: ElementAt): SupportLevel = {
586+
if (expr.left.dataType.isInstanceOf[ArrayType]) {
587+
Compatible()
588+
} else {
589+
Unsupported(Some("Input is not an array"))
590+
}
591+
}
592+
586593
override def convert(
587594
expr: ElementAt,
588595
inputs: Seq[Attribute],
@@ -591,11 +598,6 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {
591598
val ordinalExpr = exprToProtoInternal(expr.right, inputs, binding)
592599
val defaultExpr = expr.defaultValueOutOfBound.flatMap(exprToProtoInternal(_, inputs, binding))
593600

594-
if (!expr.left.dataType.isInstanceOf[ArrayType]) {
595-
withFallbackReason(expr, "Input is not an array")
596-
return None
597-
}
598-
599601
if (childExpr.isDefined && ordinalExpr.isDefined &&
600602
defaultExpr.isDefined == expr.defaultValueOutOfBound.isDefined) {
601603
val arrayExtractBuilder = ExprOuterClass.ListExtract
@@ -621,17 +623,12 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {
621623

622624
object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase {
623625

626+
override def getSupportLevel(expr: Flatten): SupportLevel = childTypesSupportLevel(expr)
627+
624628
override def convert(
625629
expr: Flatten,
626630
inputs: Seq[Attribute],
627631
binding: Boolean): Option[ExprOuterClass.Expr] = {
628-
val inputTypes = expr.children.map(_.dataType).toSet
629-
for (dt <- inputTypes) {
630-
if (!isTypeSupported(dt)) {
631-
withFallbackReason(expr, s"data type not supported: $dt")
632-
return None
633-
}
634-
}
635632
val flattenExprProto = exprToProto(expr.child, inputs, binding)
636633
val flattenScalarExpr = scalarFunctionExprToProto("flatten", flattenExprProto)
637634
optExprWithFallbackReason(flattenScalarExpr, expr, expr.children: _*)
@@ -719,25 +716,19 @@ object CometSize extends CometExpressionSerde[Size] {
719716

720717
object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase {
721718

722-
override def getSupportLevel(expr: ArrayPosition): SupportLevel = Compatible()
719+
override def getSupportLevel(expr: ArrayPosition): SupportLevel = {
720+
if (expr.children.forall(_.foldable)) {
721+
// Fall back to Spark for all-literal args so ConstantFolding can handle it.
722+
Unsupported(Some("all arguments are literals, falling back to Spark"))
723+
} else {
724+
childTypesSupportLevel(expr)
725+
}
726+
}
723727

724728
override def convert(
725729
expr: ArrayPosition,
726730
inputs: Seq[Attribute],
727731
binding: Boolean): Option[ExprOuterClass.Expr] = {
728-
if (expr.children.forall(_.foldable)) {
729-
withFallbackReason(expr, "all arguments are literals, falling back to Spark")
730-
return None
731-
}
732-
// Check if input types are supported
733-
val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet
734-
for (dt <- inputTypes) {
735-
if (!isTypeSupported(dt)) {
736-
withFallbackReason(expr, s"data type not supported: $dt")
737-
return None
738-
}
739-
}
740-
741732
val arrayExprProto = exprToProto(expr.left, inputs, binding)
742733
val elementExprProto = exprToProto(expr.right, inputs, binding)
743734

@@ -837,6 +828,17 @@ trait ArraysBase {
837828
case _ => false
838829
}
839830
}
831+
832+
/**
833+
* Support level based on whether every input data type is supported. Returns `Unsupported` for
834+
* the first unsupported input type, otherwise `Compatible`.
835+
*/
836+
def childTypesSupportLevel(expr: Expression): SupportLevel =
837+
expr.children
838+
.map(_.dataType)
839+
.collectFirst { case dt if !isTypeSupported(dt) => dt }
840+
.map(dt => Unsupported(Some(s"data type not supported: $dt")))
841+
.getOrElse(Compatible())
840842
}
841843

842844
object CometArrayTransform extends CometCodegenDispatch[ArrayTransform]

0 commit comments

Comments
 (0)