Skip to content

Commit 7a07db2

Browse files
authored
feat: Support right expression (#3207)
1 parent 7943199 commit 7a07db2

3 files changed

Lines changed: 117 additions & 1 deletion

File tree

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
173173
classOf[StringTrimLeft] -> CometScalarFunction("ltrim"),
174174
classOf[StringTrimRight] -> CometScalarFunction("rtrim"),
175175
classOf[Left] -> CometLeft,
176+
classOf[Right] -> CometRight,
176177
classOf[Substring] -> CometSubstring,
177178
classOf[Upper] -> CometUpper)
178179

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, InitCap, Left, Length, Like, Literal, Lower, RegExpReplace, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
2525
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
26+
import org.apache.spark.unsafe.types.UTF8String
2627

2728
import org.apache.comet.CometConf
2829
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -143,6 +144,49 @@ object CometLeft extends CometExpressionSerde[Left] {
143144
}
144145
}
145146

147+
/**
 * Serde for Spark's `RIGHT(str, len)` expression.
 *
 * Only a non-null integer literal `len` is converted:
 *   - `len <= 0` is serialized as `If(IsNull(str), NULL, "")`, so a NULL input stays NULL and any
 *     other input yields the empty string.
 *   - `len > 0` is serialized as a `Substring` of `str` with start `-len` and length `len`.
 *
 * Every other `len` (a non-literal expression, a NULL literal, or a literal whose value is not an
 * `Int`) is tagged with a reason and left to Spark by returning `None`.
 */
object CometRight extends CometExpressionSerde[Right] {

  /**
   * Converts `expr` to its protobuf form.
   *
   * @param expr
   *   the RIGHT expression to serialize
   * @param inputs
   *   attributes passed through to `exprToProtoInternal` when serializing children
   * @param binding
   *   passed through unchanged to `exprToProtoInternal`
   * @return
   *   the serialized expression, or `None` when this serde does not handle the expression
   */
  override def convert(expr: Right, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
    expr.len match {
      // The typed pattern matches only a non-null Int value. An unchecked `asInstanceOf[Int]`
      // would unbox a NULL literal to 0 and wrongly take this branch (returning "" instead of
      // leaving the decision to Spark), and would throw on any non-Int literal value.
      case Literal(lenInt: Int, _) if lenInt <= 0 =>
        // Match Spark's behavior: If(IsNull(str), NULL, "")
        // This ensures NULL propagation: RIGHT(NULL, 0) -> NULL, RIGHT("hello", 0) -> ""
        val nullLiteral = Literal.create(null, StringType)
        val emptyStringLiteral = Literal(UTF8String.EMPTY_UTF8, StringType)
        val ifExpr = If(IsNull(expr.str), nullLiteral, emptyStringLiteral)

        // Serialize the If expression using existing infrastructure
        exprToProtoInternal(ifExpr, inputs, binding)

      case Literal(lenInt: Int, _) =>
        exprToProtoInternal(expr.str, inputs, binding) match {
          case Some(strExpr) =>
            // A negative start combined with the same length selects the trailing `lenInt`
            // characters.
            val builder = ExprOuterClass.Substring.newBuilder()
            builder.setChild(strExpr)
            builder.setStart(-lenInt)
            builder.setLen(lenInt)
            Some(ExprOuterClass.Expr.newBuilder().setSubstring(builder).build())
          case None =>
            withInfo(expr, expr.str)
            None
        }

      case _: Literal =>
        // NULL literal or unexpected literal value type: let Spark evaluate it.
        // NOTE(review): Spark is expected to return NULL for RIGHT(str, NULL) - confirm.
        withInfo(expr, "RIGHT len must be a non-null integer literal")
        None

      case _ =>
        withInfo(expr, "RIGHT len must be a literal")
        None
    }
  }

  /**
   * Reports `Compatible` when the string argument has a `StringType` data type and `Unsupported`
   * (with the offending type in the message) otherwise.
   */
  override def getSupportLevel(expr: Right): SupportLevel = {
    expr.str.dataType match {
      case _: StringType => Compatible()
      case _ => Unsupported(Some(s"RIGHT does not support ${expr.str.dataType}"))
    }
  }
}
189+
146190
object CometConcat extends CometScalarFunction[Concat]("concat") {
147191
val unsupportedReason = "CONCAT supports only string input parameters"
148192

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,77 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
569569
}
570570
}
571571

572+
// Runs RIGHT over a small string column with positive, zero, negative and oversized lengths,
// plus a NULL string input, comparing Comet against Spark for each query.
test("RIGHT function") {
  val rows = (0 until 10).map(i => (s"test$i", i))
  withParquetTable(rows, "tbl") {
    val queries = Seq(
      "SELECT _1, RIGHT(_1, 2) FROM tbl",
      "SELECT _1, RIGHT(_1, 4) FROM tbl",
      "SELECT _1, RIGHT(_1, 0) FROM tbl",
      "SELECT _1, RIGHT(_1, -1) FROM tbl",
      "SELECT _1, RIGHT(_1, 100) FROM tbl",
      "SELECT RIGHT(CAST(NULL AS STRING), 2) FROM tbl LIMIT 1")
    queries.foreach { query =>
      checkSparkAnswerAndOperator(query)
    }
  }
}
582+
583+
// Exercises RIGHT on strings that contain non-ASCII characters, comparing Comet against Spark.
test("RIGHT function with unicode") {
  val unicodeRows = Seq("café", "hello世界", "😀emoji", "తెలుగు").zipWithIndex
  withParquetTable(unicodeRows, "unicode_tbl") {
    val queries = Seq(
      "SELECT _1, RIGHT(_1, 2) FROM unicode_tbl",
      "SELECT _1, RIGHT(_1, 3) FROM unicode_tbl",
      "SELECT _1, RIGHT(_1, 0) FROM unicode_tbl")
    queries.foreach { query =>
      checkSparkAnswerAndOperator(query)
    }
  }
}
591+
592+
// Checks that RIGHT(_1, 3) and SUBSTRING(_1, -3, 3) agree on every row of a 20-row string table.
// The query result is filtered down to rows where the two columns differ, or where exactly one
// of them is NULL, and that filtered result is asserted to be empty.
// Note: this uses checkAnswer on the filtered DataFrame rather than checkSparkAnswerAndOperator.
test("RIGHT function equivalence with SUBSTRING negative pos") {
593+
withParquetTable((0 until 20).map(i => Tuple1(s"test$i")), "equiv_tbl") {
594+
val df = spark.sql("""
595+
SELECT _1,
596+
RIGHT(_1, 3) as right_result,
597+
SUBSTRING(_1, -3, 3) as substring_result
598+
FROM equiv_tbl
599+
""")
600+
checkAnswer(
601+
df.filter(
602+
"right_result != substring_result OR " +
603+
"(right_result IS NULL AND substring_result IS NOT NULL) OR " +
604+
"(right_result IS NOT NULL AND substring_result IS NULL)"),
605+
Seq.empty)
606+
}
607+
}
608+
609+
// Runs RIGHT over 1000 rows drawn from only five distinct string values
// (presumably so the Parquet column is dictionary-encoded - verify against withParquetTable).
test("RIGHT function with dictionary") {
  val repeatedValues = (0 until 1000).map { i =>
    val bucket = i % 5
    s"value$bucket"
  }
  withParquetTable(repeatedValues.zipWithIndex, "dict_tbl") {
    checkSparkAnswerAndOperator("SELECT _1, RIGHT(_1, 3) FROM dict_tbl")
  }
}
617+
618+
// Covers RIGHT with len <= 0 for NULL inputs, non-NULL inputs, and a table mixing both.
test("RIGHT function NULL handling") {
  def sampleRows = (0 until 5).map(i => (s"test$i", i))

  // NULL string input with len <= 0 (critical edge case for NULL propagation)
  withParquetTable(sampleRows, "null_tbl") {
    Seq(
      "SELECT RIGHT(CAST(NULL AS STRING), 0) FROM null_tbl LIMIT 1",
      "SELECT RIGHT(CAST(NULL AS STRING), -1) FROM null_tbl LIMIT 1",
      "SELECT RIGHT(CAST(NULL AS STRING), -5) FROM null_tbl LIMIT 1").foreach { query =>
      checkSparkAnswerAndOperator(query)
    }
  }

  // Non-NULL strings with len <= 0 (should return empty string)
  withParquetTable(sampleRows, "edge_tbl") {
    Seq("SELECT _1, RIGHT(_1, 0) FROM edge_tbl", "SELECT _1, RIGHT(_1, -1) FROM edge_tbl")
      .foreach { query =>
        checkSparkAnswerAndOperator(query)
      }
  }

  // Mixed NULL, empty and non-empty values read from a regular Parquet table
  val table = "right_null_edge"
  withTable(table) {
    sql(s"create table $table(str string) using parquet")
    sql(s"insert into $table values('hello'), (NULL), (''), ('world')")
    Seq(
      s"SELECT str, RIGHT(str, 0) FROM $table",
      s"SELECT str, RIGHT(str, -1) FROM $table",
      s"SELECT str, RIGHT(str, 2) FROM $table").foreach { query =>
      checkSparkAnswerAndOperator(query)
    }
  }
}
642+
572643
test("hour, minute, second") {
573644
Seq(true, false).foreach { dictionaryEnabled =>
574645
withTempDir { dir =>

0 commit comments

Comments
 (0)