Skip to content

Commit b26a8df

Browse files
authored
refactor: move string expression support checks to getSupportLevel (#4676)
1 parent 82bf3ae commit b26a8df

1 file changed

Lines changed: 40 additions & 22 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 40 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,11 @@ object CometStringReplace extends CometScalarFunction[StringReplace]("replace")
152152

153153
object CometSubstring extends CometExpressionSerde[Substring] {
154154

155+
/**
 * Support check for Substring, decided purely from the shape of `expr.pos` and `expr.len`.
 *
 * @param expr the Substring expression being inspected
 * @return `Compatible()` when both `pos` and `len` are `Literal` nodes; otherwise
 *         `Unsupported` carrying the reason "Substring pos and len must be literals"
 */
override def getSupportLevel(expr: Substring): SupportLevel = (expr.pos, expr.len) match {
156+
case (_: Literal, _: Literal) => Compatible()
157+
case _ => Unsupported(Some("Substring pos and len must be literals"))
158+
}
159+
155160
override def convert(
156161
expr: Substring,
157162
inputs: Seq[Attribute],
@@ -170,7 +175,7 @@ object CometSubstring extends CometExpressionSerde[Substring] {
170175
None
171176
}
172177
case _ =>
173-
withFallbackReason(expr, "Substring pos and len must be literals")
178+
// Unreachable: getSupportLevel gates non-literal pos/len.
174179
None
175180
}
176181
}
@@ -213,14 +218,18 @@ object CometLeft extends CometExpressionSerde[Left] {
213218
None
214219
}
215220
case _ =>
216-
withFallbackReason(expr, "LEFT len must be a literal")
221+
// Unreachable: getSupportLevel gates a non-literal length.
217222
None
218223
}
219224
}
220225

221226
/**
 * Support check for LEFT.
 *
 * This span is a diff view: the line marked `-` is the previous single-line case, and the
 * lines marked `+` are its replacement. In the replacement, an input of `BinaryType` or
 * `StringType` is `Compatible()` only when `expr.len` is a `Literal`; a non-literal length
 * yields `Unsupported` with "LEFT len must be a literal". Any other input data type yields
 * `Unsupported` with a message naming that data type.
 *
 * @param expr the Left expression being inspected
 * @return the support level described above
 */
override def getSupportLevel(expr: Left): SupportLevel = {
222227
expr.str.dataType match {
223-
case _: BinaryType | _: StringType => Compatible()
228+
case _: BinaryType | _: StringType =>
229+
expr.len match {
230+
case _: Literal => Compatible()
231+
case _ => Unsupported(Some("LEFT len must be a literal"))
232+
}
224233
case _ => Unsupported(Some(s"LEFT does not support ${expr.str.dataType}"))
225234
}
226235
}
@@ -256,7 +265,7 @@ object CometRight extends CometExpressionSerde[Right] {
256265
}
257266
}
258267
case _ =>
259-
withFallbackReason(expr, "RIGHT len must be a literal")
268+
// Unreachable: getSupportLevel gates a non-literal length.
260269
None
261270
}
262271
}
@@ -265,7 +274,11 @@ object CometRight extends CometExpressionSerde[Right] {
265274

266275
/**
 * Support check for RIGHT.
 *
 * This span is a diff view: the line marked `-` is the previous single-line case, and the
 * lines marked `+` are its replacement. In the replacement, a `StringType` input is
 * `Compatible()` only when `expr.len` is a `Literal`; a non-literal length yields
 * `Unsupported` with "RIGHT len must be a literal". Any other input data type (only
 * `StringType` is matched here) yields `Unsupported` with a message naming that data type.
 *
 * @param expr the Right expression being inspected
 * @return the support level described above
 */
override def getSupportLevel(expr: Right): SupportLevel = {
267276
expr.str.dataType match {
268-
case _: StringType => Compatible()
277+
case _: StringType =>
278+
expr.len match {
279+
case _: Literal => Compatible()
280+
case _ => Unsupported(Some("RIGHT len must be a literal"))
281+
}
269282
case _ => Unsupported(Some(s"RIGHT does not support ${expr.str.dataType}"))
270283
}
271284
}
@@ -303,18 +316,22 @@ object CometConcat extends CometScalarFunction[Concat]("concat") with CometTypeS
303316

304317
object CometConcatWs extends CometExpressionSerde[ConcatWs] {
305318

319+
/**
 * Support check for concat_ws, keyed on the first child (the separator) and on whether
 * every child is foldable.
 *
 * Case order matters: the NULL-literal separator case is tested before the all-foldable
 * case, so a NULL separator is reported as `Compatible()` even when every argument is
 * foldable.
 *
 * NOTE(review): when `expr.children` is empty, `headOption` is `None` and `forall` is
 * vacuously true, so the result is `Unsupported("all arguments are foldable")` -- confirm
 * this is the intended outcome for that input.
 *
 * @param expr the ConcatWs expression being inspected
 * @return `Compatible()` for a NULL-literal separator or when at least one child is not
 *         foldable; `Unsupported` when all children are foldable
 */
override def getSupportLevel(expr: ConcatWs): SupportLevel = expr.children.headOption match {
320+
// A NULL separator converts directly to a NULL result, so it stays supported.
321+
case Some(Literal(null, _)) => Compatible()
322+
// Fall back to Spark for all-literal args so ConstantFolding can handle it.
323+
case _ if expr.children.forall(_.foldable) =>
324+
Unsupported(Some("all arguments are foldable"))
325+
case _ => Compatible()
326+
}
327+
306328
override def convert(expr: ConcatWs, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
307329
expr.children.headOption match {
308330
// Match Spark behavior: when the separator is NULL, the result of concat_ws is NULL.
309331
case Some(Literal(null, _)) =>
310332
val nullLiteral = Literal.create(null, expr.dataType)
311333
exprToProtoInternal(nullLiteral, inputs, binding)
312334

313-
case _ if expr.children.forall(_.foldable) =>
314-
// Fall back to Spark for all-literal args so ConstantFolding can handle it
315-
withFallbackReason(expr, "all arguments are foldable")
316-
None
317-
318335
case _ =>
319336
// For all other cases, use the generic scalar function implementation.
320337
CometScalarFunction[ConcatWs]("concat_ws").convert(expr, inputs, binding)
@@ -324,22 +341,23 @@ object CometConcatWs extends CometExpressionSerde[ConcatWs] {
324341

325342
object CometLike extends CometExpressionSerde[Like] {
326343

327-
override def convert(expr: Like, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
344+
/**
 * Support check for LIKE, based only on the escape character.
 *
 * This span is a diff view: lines marked `-` belong to the previous `convert` body, and
 * lines marked `+` are the new `getSupportLevel` results. In the new version, the backslash
 * escape character yields `Compatible()`; any other escape character yields `Unsupported`
 * with a message that includes the offending character.
 *
 * @param expr the Like expression being inspected
 * @return `Compatible()` when `expr.escapeChar` is a backslash, otherwise `Unsupported`
 */
override def getSupportLevel(expr: Like): SupportLevel = {
328345
if (expr.escapeChar == '\\') {
329-
createBinaryExpr(
330-
expr,
331-
expr.left,
332-
expr.right,
333-
inputs,
334-
binding,
335-
(builder, binaryExpr) => builder.setLike(binaryExpr))
346+
Compatible()
336347
} else {
337-
withFallbackReason(
338-
expr,
339-
s"custom escape character ${expr.escapeChar} not supported in LIKE")
340-
None
348+
Unsupported(Some(s"custom escape character ${expr.escapeChar} not supported in LIKE"))
341349
}
342350
}
351+
352+
/**
 * Serializes a LIKE expression by delegating to `createBinaryExpr` with `expr.left` and
 * `expr.right` as operands, storing the resulting binary expression on the builder via
 * `setLike`.
 *
 * This method does not inspect `expr.escapeChar`; that check is performed in
 * `getSupportLevel` on this object.
 *
 * @param expr    the Like expression to convert
 * @param inputs  passed through unchanged to `createBinaryExpr`
 * @param binding passed through unchanged to `createBinaryExpr`
 * @return whatever `createBinaryExpr` returns for these operands
 */
override def convert(expr: Like, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
353+
createBinaryExpr(
354+
expr,
355+
expr.left,
356+
expr.right,
357+
inputs,
358+
binding,
359+
(builder, binaryExpr) => builder.setLike(binaryExpr))
360+
}
343361
}
344362

345363
/**

0 commit comments

Comments
 (0)