Skip to content

Commit 3200c9f

Browse files
yusinnmaoclaude
andcommitted
Use scalarFunctionExprToProtoWithReturnType for levenshtein
Override convert() to use scalarFunctionExprToProtoWithReturnType with IntegerType, so the native planner skips the DataFusion registry lookup and does not conflict with DataFusion's built-in 2-arg levenshtein function when 3 args (threshold) are passed. Also fix Spotless: remove unnecessary string interpolation prefix. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d552d67 commit 3200c9f

2 files changed

Lines changed: 16 additions & 3 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package org.apache.comet.serde
2222
import java.util.Locale
2323

2424
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Levenshtein, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
25-
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
25+
import org.apache.spark.sql.types.{BinaryType, DataTypes, IntegerType, LongType, StringType}
2626
import org.apache.spark.unsafe.types.UTF8String
2727

2828
import org.apache.comet.CometConf
@@ -84,7 +84,7 @@ object CometLength extends CometScalarFunction[Length]("length") {
8484
}
8585
}
8686

87-
object CometLevenshtein extends CometScalarFunction[Levenshtein]("levenshtein") {
87+
object CometLevenshtein extends CometExpressionSerde[Levenshtein] {
8888

8989
override def getUnsupportedReasons(): Seq[String] = Seq(
9090
"Non-default collation (non-UTF8_BINARY) is not supported")
@@ -96,6 +96,19 @@ object CometLevenshtein extends CometScalarFunction[Levenshtein]("levenshtein")
9696
case _ => Compatible()
9797
}
9898
}
99+
100+
override def convert(
101+
expr: Levenshtein,
102+
inputs: Seq[Attribute],
103+
binding: Boolean): Option[Expr] = {
104+
val childExprs = expr.children.map(exprToProtoInternal(_, inputs, binding))
105+
val optExpr = scalarFunctionExprToProtoWithReturnType(
106+
"levenshtein",
107+
IntegerType,
108+
false,
109+
childExprs: _*)
110+
optExprWithInfo(optExpr, expr, expr.children: _*)
111+
}
99112
}
100113

101114
object CometInitCap extends CometScalarFunction[InitCap]("initcap") {

spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ class CometStringExpressionSuite extends CometTestBase {
722722
sql(s"CREATE TABLE $table(s1 STRING, s2 STRING) USING parquet")
723723
sql(
724724
s"INSERT INTO $table VALUES " +
725-
s"('abc', 'adc'), (NULL, 'test'), ('hello', NULL), (NULL, NULL)")
725+
"('abc', 'adc'), (NULL, 'test'), ('hello', NULL), (NULL, NULL)")
726726
checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2) FROM $table")
727727
}
728728
}

0 commit comments

Comments
 (0)