diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 05448ba653..b0f2446354 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -152,6 +152,11 @@ object CometStringReplace extends CometScalarFunction[StringReplace]("replace") object CometSubstring extends CometExpressionSerde[Substring] { + override def getSupportLevel(expr: Substring): SupportLevel = (expr.pos, expr.len) match { + case (_: Literal, _: Literal) => Compatible() + case _ => Unsupported(Some("Substring pos and len must be literals")) + } + override def convert( expr: Substring, inputs: Seq[Attribute], @@ -170,7 +175,7 @@ object CometSubstring extends CometExpressionSerde[Substring] { None } case _ => - withFallbackReason(expr, "Substring pos and len must be literals") + // Unreachable: getSupportLevel gates non-literal pos/len. None } } @@ -213,14 +218,18 @@ object CometLeft extends CometExpressionSerde[Left] { None } case _ => - withFallbackReason(expr, "LEFT len must be a literal") + // Unreachable: getSupportLevel gates a non-literal length. None } } override def getSupportLevel(expr: Left): SupportLevel = { expr.str.dataType match { - case _: BinaryType | _: StringType => Compatible() + case _: BinaryType | _: StringType => + expr.len match { + case _: Literal => Compatible() + case _ => Unsupported(Some("LEFT len must be a literal")) + } case _ => Unsupported(Some(s"LEFT does not support ${expr.str.dataType}")) } } @@ -256,7 +265,7 @@ object CometRight extends CometExpressionSerde[Right] { } } case _ => - withFallbackReason(expr, "RIGHT len must be a literal") + // Unreachable: getSupportLevel gates a non-literal length. None } } @@ -265,7 +274,11 @@ object CometRight extends CometExpressionSerde[Right] { override def getSupportLevel(expr: Right): SupportLevel = { expr.str.dataType match { - case _: StringType => Compatible() + case _: StringType => + expr.len match { + case _: Literal => Compatible() + case _ => Unsupported(Some("RIGHT len must be a literal")) + } case _ => Unsupported(Some(s"RIGHT does not support ${expr.str.dataType}")) } } @@ -303,6 +316,15 @@ object CometConcat extends CometScalarFunction[Concat]("concat") with CometTypeS object CometConcatWs extends CometExpressionSerde[ConcatWs] { + override def getSupportLevel(expr: ConcatWs): SupportLevel = expr.children.headOption match { + // A NULL separator converts directly to a NULL result, so it stays supported. + case Some(Literal(null, _)) => Compatible() + // Fall back to Spark for all-literal args so ConstantFolding can handle it. + case _ if expr.children.forall(_.foldable) => + Unsupported(Some("all arguments are foldable")) + case _ => Compatible() + } + override def convert(expr: ConcatWs, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { expr.children.headOption match { // Match Spark behavior: when the separator is NULL, the result of concat_ws is NULL. @@ -310,11 +332,6 @@ object CometConcatWs extends CometExpressionSerde[ConcatWs] { val nullLiteral = Literal.create(null, expr.dataType) exprToProtoInternal(nullLiteral, inputs, binding) - case _ if expr.children.forall(_.foldable) => - // Fall back to Spark for all-literal args so ConstantFolding can handle it - withFallbackReason(expr, "all arguments are foldable") - None - case _ => // For all other cases, use the generic scalar function implementation. CometScalarFunction[ConcatWs]("concat_ws").convert(expr, inputs, binding) @@ -324,22 +341,23 @@ object CometConcatWs extends CometExpressionSerde[ConcatWs] { object CometLike extends CometExpressionSerde[Like] { - override def convert(expr: Like, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + override def getSupportLevel(expr: Like): SupportLevel = { if (expr.escapeChar == '\\') { - createBinaryExpr( - expr, - expr.left, - expr.right, - inputs, - binding, - (builder, binaryExpr) => builder.setLike(binaryExpr)) + Compatible() } else { - withFallbackReason( - expr, - s"custom escape character ${expr.escapeChar} not supported in LIKE") - None + Unsupported(Some(s"custom escape character ${expr.escapeChar} not supported in LIKE")) } } + + override def convert(expr: Like, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + createBinaryExpr( + expr, + expr.left, + expr.right, + inputs, + binding, + (builder, binaryExpr) => builder.setLike(binaryExpr)) + } } /**