Skip to content

Commit 1e35882

Browse files
yusinnmao and claude committed
Support three-argument levenshtein with threshold and add collation check
- Implement threshold semantics: levenshtein(str1, str2, threshold) returns -1 when distance exceeds threshold, matching Spark behavior
- Add CometLevenshtein serde with getSupportLevel that falls back for non-default collations (Spark 4 StringTypeWithCollation)
- Add Rust tests for threshold and NULL threshold cases
- Add SLT tests for threshold variants
- Add Scala integration test for threshold
- Add levenshtein_threshold to benchmark

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 578477d commit 1e35882

6 files changed

Lines changed: 125 additions & 8 deletions

File tree

native/spark-expr/src/string_funcs/levenshtein.rs

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
//! matching Apache Spark's `levenshtein(str1, str2)` semantics.
2222
2323
use arrow::array::{as_string_array, Array, ArrayRef, Int32Array};
24-
use datafusion::common::{DataFusionError, Result};
24+
use datafusion::common::{DataFusionError, Result, ScalarValue};
2525
use datafusion::physical_plan::ColumnarValue;
2626
use std::sync::Arc;
2727

@@ -78,19 +78,40 @@ fn levenshtein_distance(s: &str, t: &str) -> i32 {
7878

7979
/// Spark-compatible levenshtein scalar function.
8080
///
81-
/// Accepts two string arguments and returns an Int32 array of edit distances.
82-
/// NULL inputs produce NULL outputs.
81+
/// Accepts two or three arguments:
82+
/// - `levenshtein(str1, str2)` → edit distance
83+
/// - `levenshtein(str1, str2, threshold)` → edit distance if <= threshold, else -1
84+
///
85+
/// NULL inputs produce NULL outputs. NULL threshold produces NULL output.
8386
pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
84-
if args.len() != 2 {
87+
if args.len() < 2 || args.len() > 3 {
8588
return Err(DataFusionError::Internal(format!(
86-
"levenshtein requires exactly 2 arguments, got {}",
89+
"levenshtein requires 2 or 3 arguments, got {}",
8790
args.len()
8891
)));
8992
}
9093

94+
// Extract optional threshold (3rd argument must be a scalar Int32)
95+
let threshold: Option<i32> = if args.len() == 3 {
96+
match &args[2] {
97+
ColumnarValue::Scalar(ScalarValue::Int32(t)) => match t {
98+
Some(val) => Some(*val),
99+
None => return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))),
100+
},
101+
_ => {
102+
return Err(DataFusionError::Internal(
103+
"levenshtein threshold must be an Int32 scalar".to_string(),
104+
));
105+
}
106+
}
107+
} else {
108+
None
109+
};
110+
91111
// Expand scalars to arrays for uniform processing
92112
let len = args
93113
.iter()
114+
.take(2)
94115
.find_map(|arg| match arg {
95116
ColumnarValue::Array(a) => Some(a.len()),
96117
_ => None,
@@ -107,7 +128,13 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
107128
.iter()
108129
.zip(right_arr.iter())
109130
.map(|(l, r)| match (l, r) {
110-
(Some(l), Some(r)) => Some(levenshtein_distance(l, r)),
131+
(Some(l), Some(r)) => {
132+
let dist = levenshtein_distance(l, r);
133+
match threshold {
134+
Some(t) if dist > t => Some(-1),
135+
_ => Some(dist),
136+
}
137+
}
111138
_ => None, // NULL propagation
112139
})
113140
.collect();
@@ -161,4 +188,43 @@ mod tests {
161188
_ => panic!("Expected array result"),
162189
}
163190
}
191+
192+
#[test]
193+
fn test_spark_levenshtein_with_threshold() {
194+
let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![
195+
Some("kitten"),
196+
Some("abc"),
197+
Some("frog"),
198+
])));
199+
let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
200+
Some("sitting"),
201+
Some("adc"),
202+
Some("fog"),
203+
])));
204+
let threshold = ColumnarValue::Scalar(ScalarValue::Int32(Some(2)));
205+
206+
let result = spark_levenshtein(&[left, right, threshold]).unwrap();
207+
match result {
208+
ColumnarValue::Array(arr) => {
209+
let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
210+
assert_eq!(int_arr.value(0), -1); // kitten->sitting=3 > 2, return -1
211+
assert_eq!(int_arr.value(1), 1); // abc->adc=1 <= 2, return 1
212+
assert_eq!(int_arr.value(2), 1); // frog->fog=1 <= 2, return 1
213+
}
214+
_ => panic!("Expected array result"),
215+
}
216+
}
217+
218+
#[test]
219+
fn test_spark_levenshtein_null_threshold() {
220+
let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![Some("abc")])));
221+
let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![Some("adc")])));
222+
let threshold = ColumnarValue::Scalar(ScalarValue::Int32(None));
223+
224+
let result = spark_levenshtein(&[left, right, threshold]).unwrap();
225+
match result {
226+
ColumnarValue::Scalar(ScalarValue::Int32(None)) => {} // NULL threshold -> NULL
227+
_ => panic!("Expected NULL scalar result for NULL threshold"),
228+
}
229+
}
164230
}

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
166166
classOf[GetJsonObject] -> CometGetJsonObject,
167167
classOf[InitCap] -> CometInitCap,
168168
classOf[Length] -> CometLength,
169-
classOf[Levenshtein] -> CometScalarFunction("levenshtein"),
169+
classOf[Levenshtein] -> CometLevenshtein,
170170
classOf[Like] -> CometLike,
171171
classOf[Lower] -> CometLower,
172172
classOf[OctetLength] -> CometScalarFunction("octet_length"),

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Levenshtein, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
2525
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
2626
import org.apache.spark.unsafe.types.UTF8String
2727

@@ -84,6 +84,21 @@ object CometLength extends CometScalarFunction[Length]("length") {
8484
}
8585
}
8686

87+
object CometLevenshtein extends CometScalarFunction[Levenshtein]("levenshtein") {
88+
89+
override def getUnsupportedReasons(): Seq[String] = Seq(
90+
"Non-default collation (non-UTF8_BINARY) is not supported")
91+
92+
override def getSupportLevel(expr: Levenshtein): SupportLevel = {
93+
expr.children.headOption match {
94+
case Some(child) if QueryPlanSerde.isStringCollationType(child.dataType) =>
95+
Unsupported(Some(
96+
"Levenshtein with non-default collation is not supported"))
97+
case _ => Compatible()
98+
}
99+
}
100+
}
101+
87102
object CometInitCap extends CometScalarFunction[InitCap]("initcap") {
88103

89104
override def getIncompatibleReasons(): Seq[String] = Seq(

spark/src/test/resources/sql-tests/expressions/string/levenshtein.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,19 @@ SELECT levenshtein(s1, s1) FROM test_levenshtein
4444
-- unicode characters
4545
query
4646
SELECT levenshtein('café', 'cafe'), levenshtein('你好', '你坏')
47+
48+
-- three argument version with threshold
49+
query
50+
SELECT levenshtein('kitten', 'sitting', 2), levenshtein('kitten', 'sitting', 3), levenshtein('kitten', 'sitting', 4)
51+
52+
-- threshold with column arguments
53+
query
54+
SELECT levenshtein(s1, s2, 2) FROM test_levenshtein
55+
56+
-- threshold edge cases
57+
query
58+
SELECT levenshtein('abc', 'abc', 0), levenshtein('abc', 'adc', 0), levenshtein('', '', 0)
59+
60+
-- threshold with NULL
61+
query
62+
SELECT levenshtein('abc', 'adc', NULL), levenshtein(NULL, 'test', 2)

spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,4 +509,21 @@ class CometStringExpressionSuite extends CometTestBase {
509509
}
510510
}
511511

512+
test("levenshtein with threshold") {
513+
val data = Seq(
514+
("kitten", "sitting"),
515+
("frog", "fog"),
516+
("abc", "abc"),
517+
("hello", "world"))
518+
519+
withParquetTable(data, "tbl") {
520+
checkSparkAnswerAndOperator(
521+
"SELECT levenshtein(_1, _2, 2) FROM tbl")
522+
checkSparkAnswerAndOperator(
523+
"SELECT levenshtein(_1, _2, 0) FROM tbl")
524+
checkSparkAnswerAndOperator(
525+
"SELECT levenshtein(_1, _2, 10) FROM tbl")
526+
}
527+
}
528+
512529
}

spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ object CometStringExpressionBenchmark extends CometBenchmarkBase {
5959
StringExprConfig("length", "select length(c1) from parquetV1Table"),
6060
StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"),
6161
StringExprConfig("levenshtein", "select levenshtein(c1, 'test') from parquetV1Table"),
62+
StringExprConfig(
63+
"levenshtein_threshold",
64+
"select levenshtein(c1, 'test', 3) from parquetV1Table"),
6265
StringExprConfig("like", "select c1 like '%123%' from parquetV1Table"),
6366
StringExprConfig("lower", "select lower(c1) from parquetV1Table"),
6467
StringExprConfig("lpad", "select lpad(c1, 150, 'x') from parquetV1Table"),

0 commit comments

Comments
 (0)