diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index faef13e075..3a6b79321d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -37,17 +37,12 @@ object CometArrayRemove with CometExprShim with ArraysBase { + override def getSupportLevel(expr: ArrayRemove): SupportLevel = childTypesSupportLevel(expr) + override def convert( expr: ArrayRemove, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet - for (dt <- inputTypes) { - if (!isTypeSupported(dt)) { - withFallbackReason(expr, s"data type not supported: $dt") - return None - } - } val arrayExprProto = exprToProto(expr.left, inputs, binding) val keyExprProto = exprToProto(expr.right, inputs, binding) @@ -156,7 +151,11 @@ object CometSortArray extends CometExpressionSerde[SortArray] { SupportLevel .strictFloatingPointReason(elementType, "Sorting on floating-point") .map(reason => Incompatible(Some(reason))) - .getOrElse(Compatible()) + .getOrElse(expr.ascendingOrder match { + case Literal(_: Boolean, BooleanType) => Compatible() + case other => + Unsupported(Some(s"ascendingOrder must be a boolean literal: $other")) + }) } } @@ -172,8 +171,8 @@ object CometSortArray extends CometExpressionSerde[SortArray] { ( exprToProtoInternal(Literal(direction), inputs, binding), exprToProtoInternal(Literal(nullOrdering), inputs, binding)) - case other => - withFallbackReason(expr, s"ascendingOrder must be a boolean literal: $other") + case _ => + // Unreachable: getSupportLevel gates a non-boolean-literal ascendingOrder. (None, None) } @@ -583,6 +582,14 @@ object CometElementAt extends CometExpressionSerde[ElementAt] { override def getUnsupportedReasons(): Seq[String] = Seq( "Input must be an array. `Map` inputs are not supported.") + override def getSupportLevel(expr: ElementAt): SupportLevel = { + if (expr.left.dataType.isInstanceOf[ArrayType]) { + Compatible() + } else { + Unsupported(Some("Input is not an array")) + } + } + override def convert( expr: ElementAt, inputs: Seq[Attribute], @@ -591,11 +598,6 @@ object CometElementAt extends CometExpressionSerde[ElementAt] { val ordinalExpr = exprToProtoInternal(expr.right, inputs, binding) val defaultExpr = expr.defaultValueOutOfBound.flatMap(exprToProtoInternal(_, inputs, binding)) - if (!expr.left.dataType.isInstanceOf[ArrayType]) { - withFallbackReason(expr, "Input is not an array") - return None - } - if (childExpr.isDefined && ordinalExpr.isDefined && defaultExpr.isDefined == expr.defaultValueOutOfBound.isDefined) { val arrayExtractBuilder = ExprOuterClass.ListExtract @@ -621,17 +623,12 @@ object CometElementAt extends CometExpressionSerde[ElementAt] { object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase { + override def getSupportLevel(expr: Flatten): SupportLevel = childTypesSupportLevel(expr) + override def convert( expr: Flatten, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - val inputTypes = expr.children.map(_.dataType).toSet - for (dt <- inputTypes) { - if (!isTypeSupported(dt)) { - withFallbackReason(expr, s"data type not supported: $dt") - return None - } - } val flattenExprProto = exprToProto(expr.child, inputs, binding) val flattenScalarExpr = scalarFunctionExprToProto("flatten", flattenExprProto) optExprWithFallbackReason(flattenScalarExpr, expr, expr.children: _*) @@ -719,25 +716,19 @@ object CometSize extends CometExpressionSerde[Size] { object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase { - override def getSupportLevel(expr: ArrayPosition): SupportLevel = Compatible() + override def getSupportLevel(expr: ArrayPosition): SupportLevel = { + if (expr.children.forall(_.foldable)) { + // Fall back to Spark for all-literal args so ConstantFolding can handle it. + Unsupported(Some("all arguments are literals, falling back to Spark")) + } else { + childTypesSupportLevel(expr) + } + } override def convert( expr: ArrayPosition, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - if (expr.children.forall(_.foldable)) { - withFallbackReason(expr, "all arguments are literals, falling back to Spark") - return None - } - // Check if input types are supported - val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet - for (dt <- inputTypes) { - if (!isTypeSupported(dt)) { - withFallbackReason(expr, s"data type not supported: $dt") - return None - } - } - val arrayExprProto = exprToProto(expr.left, inputs, binding) val elementExprProto = exprToProto(expr.right, inputs, binding) @@ -837,6 +828,17 @@ trait ArraysBase { case _ => false } } + + /** + * Support level based on whether every input data type is supported. Returns `Unsupported` for + * the first unsupported input type, otherwise `Compatible`. + */ + def childTypesSupportLevel(expr: Expression): SupportLevel = + expr.children + .map(_.dataType) + .collectFirst { case dt if !isTypeSupported(dt) => dt } + .map(dt => Unsupported(Some(s"data type not supported: $dt"))) + .getOrElse(Compatible()) } object CometArrayTransform extends CometCodegenDispatch[ArrayTransform]