Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions spark/src/main/scala/org/apache/comet/serde/strings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ object CometStringReplace extends CometScalarFunction[StringReplace]("replace")

object CometSubstring extends CometExpressionSerde[Substring] {

override def getSupportLevel(expr: Substring): SupportLevel = (expr.pos, expr.len) match {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be fixed after DF54, #4161

case (_: Literal, _: Literal) => Compatible()
case _ => Unsupported(Some("Substring pos and len must be literals"))
}

override def convert(
expr: Substring,
inputs: Seq[Attribute],
Expand All @@ -170,7 +175,7 @@ object CometSubstring extends CometExpressionSerde[Substring] {
None
}
case _ =>
withFallbackReason(expr, "Substring pos and len must be literals")
// Unreachable: getSupportLevel gates non-literal pos/len.
None
}
}
Expand Down Expand Up @@ -213,14 +218,18 @@ object CometLeft extends CometExpressionSerde[Left] {
None
}
case _ =>
withFallbackReason(expr, "LEFT len must be a literal")
// Unreachable: getSupportLevel gates a non-literal length.
None
}
}

override def getSupportLevel(expr: Left): SupportLevel = {
expr.str.dataType match {
case _: BinaryType | _: StringType => Compatible()
case _: BinaryType | _: StringType =>
expr.len match {
case _: Literal => Compatible()
case _ => Unsupported(Some("LEFT len must be a literal"))
}
case _ => Unsupported(Some(s"LEFT does not support ${expr.str.dataType}"))
}
}
Expand Down Expand Up @@ -256,7 +265,7 @@ object CometRight extends CometExpressionSerde[Right] {
}
}
case _ =>
withFallbackReason(expr, "RIGHT len must be a literal")
// Unreachable: getSupportLevel gates a non-literal length.
None
}
}
Expand All @@ -265,7 +274,11 @@ object CometRight extends CometExpressionSerde[Right] {

override def getSupportLevel(expr: Right): SupportLevel = {
expr.str.dataType match {
case _: StringType => Compatible()
case _: StringType =>
expr.len match {
case _: Literal => Compatible()
case _ => Unsupported(Some("RIGHT len must be a literal"))
}
case _ => Unsupported(Some(s"RIGHT does not support ${expr.str.dataType}"))
}
}
Expand Down Expand Up @@ -303,18 +316,22 @@ object CometConcat extends CometScalarFunction[Concat]("concat") with CometTypeS

object CometConcatWs extends CometExpressionSerde[ConcatWs] {

override def getSupportLevel(expr: ConcatWs): SupportLevel = expr.children.headOption match {
// A NULL separator converts directly to a NULL result, so it stays supported.
case Some(Literal(null, _)) => Compatible()
// Fall back to Spark for all-literal args so ConstantFolding can handle it.
case _ if expr.children.forall(_.foldable) =>
Unsupported(Some("all arguments are foldable"))
case _ => Compatible()
}

override def convert(expr: ConcatWs, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
expr.children.headOption match {
// Match Spark behavior: when the separator is NULL, the result of concat_ws is NULL.
case Some(Literal(null, _)) =>
val nullLiteral = Literal.create(null, expr.dataType)
exprToProtoInternal(nullLiteral, inputs, binding)

case _ if expr.children.forall(_.foldable) =>
// Fall back to Spark for all-literal args so ConstantFolding can handle it
withFallbackReason(expr, "all arguments are foldable")
None

case _ =>
// For all other cases, use the generic scalar function implementation.
CometScalarFunction[ConcatWs]("concat_ws").convert(expr, inputs, binding)
Expand All @@ -324,22 +341,23 @@ object CometConcatWs extends CometExpressionSerde[ConcatWs] {

object CometLike extends CometExpressionSerde[Like] {

override def convert(expr: Like, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
override def getSupportLevel(expr: Like): SupportLevel = {
if (expr.escapeChar == '\\') {
createBinaryExpr(
expr,
expr.left,
expr.right,
inputs,
binding,
(builder, binaryExpr) => builder.setLike(binaryExpr))
Compatible()
} else {
withFallbackReason(
expr,
s"custom escape character ${expr.escapeChar} not supported in LIKE")
None
Unsupported(Some(s"custom escape character ${expr.escapeChar} not supported in LIKE"))
}
}

override def convert(expr: Like, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe one day we should rename this method,it confuses me sometimes what is this conversion, is it expression rewrite, simplification, etc

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to consider doing all the renames we want to do before we release 1.0.0

createBinaryExpr(
expr,
expr.left,
expr.right,
inputs,
binding,
(builder, binaryExpr) => builder.setLike(binaryExpr))
}
}

/**
Expand Down
Loading