@@ -37,17 +37,12 @@ object CometArrayRemove
3737 with CometExprShim
3838 with ArraysBase {
3939
40+ override def getSupportLevel (expr : ArrayRemove ): SupportLevel = childTypesSupportLevel(expr)
41+
4042 override def convert (
4143 expr : ArrayRemove ,
4244 inputs : Seq [Attribute ],
4345 binding : Boolean ): Option [ExprOuterClass .Expr ] = {
44- val inputTypes : Set [DataType ] = expr.children.map(_.dataType).toSet
45- for (dt <- inputTypes) {
46- if (! isTypeSupported(dt)) {
47- withFallbackReason(expr, s " data type not supported: $dt" )
48- return None
49- }
50- }
5146 val arrayExprProto = exprToProto(expr.left, inputs, binding)
5247 val keyExprProto = exprToProto(expr.right, inputs, binding)
5348
@@ -156,7 +151,11 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
156151 SupportLevel
157152 .strictFloatingPointReason(elementType, " Sorting on floating-point" )
158153 .map(reason => Incompatible (Some (reason)))
159- .getOrElse(Compatible ())
154+ .getOrElse(expr.ascendingOrder match {
155+ case Literal (_ : Boolean , BooleanType ) => Compatible ()
156+ case other =>
157+ Unsupported (Some (s " ascendingOrder must be a boolean literal: $other" ))
158+ })
160159 }
161160 }
162161
@@ -172,8 +171,8 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
172171 (
173172 exprToProtoInternal(Literal (direction), inputs, binding),
174173 exprToProtoInternal(Literal (nullOrdering), inputs, binding))
175- case other =>
176- withFallbackReason(expr, s " ascendingOrder must be a boolean literal: $other " )
174+ case _ =>
175+ // Unreachable: getSupportLevel gates a non- boolean- literal ascendingOrder.
177176 (None , None )
178177 }
179178
@@ -583,6 +582,14 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {
583582 override def getUnsupportedReasons (): Seq [String ] = Seq (
584583 " Input must be an array. `Map` inputs are not supported." )
585584
585+ override def getSupportLevel (expr : ElementAt ): SupportLevel = {
586+ if (expr.left.dataType.isInstanceOf [ArrayType ]) {
587+ Compatible ()
588+ } else {
589+ Unsupported (Some (" Input is not an array" ))
590+ }
591+ }
592+
586593 override def convert (
587594 expr : ElementAt ,
588595 inputs : Seq [Attribute ],
@@ -591,11 +598,6 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {
591598 val ordinalExpr = exprToProtoInternal(expr.right, inputs, binding)
592599 val defaultExpr = expr.defaultValueOutOfBound.flatMap(exprToProtoInternal(_, inputs, binding))
593600
594- if (! expr.left.dataType.isInstanceOf [ArrayType ]) {
595- withFallbackReason(expr, " Input is not an array" )
596- return None
597- }
598-
599601 if (childExpr.isDefined && ordinalExpr.isDefined &&
600602 defaultExpr.isDefined == expr.defaultValueOutOfBound.isDefined) {
601603 val arrayExtractBuilder = ExprOuterClass .ListExtract
@@ -621,17 +623,12 @@ object CometElementAt extends CometExpressionSerde[ElementAt] {
621623
622624object CometFlatten extends CometExpressionSerde [Flatten ] with ArraysBase {
623625
626+ override def getSupportLevel (expr : Flatten ): SupportLevel = childTypesSupportLevel(expr)
627+
624628 override def convert (
625629 expr : Flatten ,
626630 inputs : Seq [Attribute ],
627631 binding : Boolean ): Option [ExprOuterClass .Expr ] = {
628- val inputTypes = expr.children.map(_.dataType).toSet
629- for (dt <- inputTypes) {
630- if (! isTypeSupported(dt)) {
631- withFallbackReason(expr, s " data type not supported: $dt" )
632- return None
633- }
634- }
635632 val flattenExprProto = exprToProto(expr.child, inputs, binding)
636633 val flattenScalarExpr = scalarFunctionExprToProto(" flatten" , flattenExprProto)
637634 optExprWithFallbackReason(flattenScalarExpr, expr, expr.children: _* )
@@ -719,25 +716,19 @@ object CometSize extends CometExpressionSerde[Size] {
719716
720717object CometArrayPosition extends CometExpressionSerde [ArrayPosition ] with ArraysBase {
721718
722- override def getSupportLevel (expr : ArrayPosition ): SupportLevel = Compatible ()
719+ override def getSupportLevel (expr : ArrayPosition ): SupportLevel = {
720+ if (expr.children.forall(_.foldable)) {
721+ // Fall back to Spark for all-literal args so ConstantFolding can handle it.
722+ Unsupported (Some (" all arguments are literals, falling back to Spark" ))
723+ } else {
724+ childTypesSupportLevel(expr)
725+ }
726+ }
723727
724728 override def convert (
725729 expr : ArrayPosition ,
726730 inputs : Seq [Attribute ],
727731 binding : Boolean ): Option [ExprOuterClass .Expr ] = {
728- if (expr.children.forall(_.foldable)) {
729- withFallbackReason(expr, " all arguments are literals, falling back to Spark" )
730- return None
731- }
732- // Check if input types are supported
733- val inputTypes : Set [DataType ] = expr.children.map(_.dataType).toSet
734- for (dt <- inputTypes) {
735- if (! isTypeSupported(dt)) {
736- withFallbackReason(expr, s " data type not supported: $dt" )
737- return None
738- }
739- }
740-
741732 val arrayExprProto = exprToProto(expr.left, inputs, binding)
742733 val elementExprProto = exprToProto(expr.right, inputs, binding)
743734
@@ -837,6 +828,17 @@ trait ArraysBase {
837828 case _ => false
838829 }
839830 }
831+
832+ /**
833+ * Support level based on whether every input data type is supported. Returns `Unsupported` for
834+ * the first unsupported input type, otherwise `Compatible`.
835+ */
836+ def childTypesSupportLevel (expr : Expression ): SupportLevel =
837+ expr.children
838+ .map(_.dataType)
839+ .collectFirst { case dt if ! isTypeSupported(dt) => dt }
840+ .map(dt => Unsupported (Some (s " data type not supported: $dt" )))
841+ .getOrElse(Compatible ())
840842}
841843
842844object CometArrayTransform extends CometCodegenDispatch [ArrayTransform ]
0 commit comments