Skip to content

Commit 015b062

Browse files
author
yusinnmao
committed
Support threshold as column (not just literal) in levenshtein
- Refactored spark_levenshtein to handle threshold as either ColumnarValue::Scalar or ColumnarValue::Array using into_array() - NULL threshold in a column produces NULL result for that row - Added Rust unit tests for array threshold, nulls, and negative values - Added Scala integration tests with threshold as column reference - Added SQL file tests with threshold column and NULL scenarios
1 parent ab8f9e5 commit 015b062

3 files changed

Lines changed: 199 additions & 36 deletions

File tree

native/spark-expr/src/string_funcs/levenshtein.rs

Lines changed: 153 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ fn levenshtein_distance(s: &str, t: &str) -> i32 {
8282
/// - `levenshtein(str1, str2)` → edit distance
8383
/// - `levenshtein(str1, str2, threshold)` → edit distance if <= threshold, else -1
8484
///
85-
/// NULL inputs produce NULL outputs. NULL threshold produces NULL output.
85+
/// The threshold argument can be either a scalar or a column (array).
86+
/// NULL inputs produce NULL outputs. NULL threshold produces NULL output for that row.
8687
pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
8788
if args.len() < 2 || args.len() > 3 {
8889
return Err(DataFusionError::Internal(format!(
@@ -91,27 +92,9 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
9192
)));
9293
}
9394

94-
// Extract optional threshold (3rd argument must be a scalar Int32)
95-
let threshold: Option<i32> = if args.len() == 3 {
96-
match &args[2] {
97-
ColumnarValue::Scalar(ScalarValue::Int32(t)) => match t {
98-
Some(val) => Some(*val),
99-
None => return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))),
100-
},
101-
_ => {
102-
return Err(DataFusionError::Internal(
103-
"levenshtein threshold must be an Int32 scalar".to_string(),
104-
));
105-
}
106-
}
107-
} else {
108-
None
109-
};
110-
111-
// Expand scalars to arrays for uniform processing
95+
// Determine array length from any array argument
11296
let len = args
11397
.iter()
114-
.take(2)
11598
.find_map(|arg| match arg {
11699
ColumnarValue::Array(a) => Some(a.len()),
117100
_ => None,
@@ -124,22 +107,56 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
124107
let left_arr = as_string_array(&left);
125108
let right_arr = as_string_array(&right);
126109

127-
let result: Int32Array = left_arr
128-
.iter()
129-
.zip(right_arr.iter())
130-
.map(|(l, r)| match (l, r) {
131-
(Some(l), Some(r)) => {
132-
let dist = levenshtein_distance(l, r);
133-
match threshold {
134-
Some(t) if dist > t => Some(-1),
135-
_ => Some(dist),
110+
// Handle the optional threshold argument (scalar or array)
111+
if args.len() == 3 {
112+
let threshold_array = args[2].clone().into_array(len)?;
113+
let threshold_arr = threshold_array
114+
.as_any()
115+
.downcast_ref::<Int32Array>()
116+
.ok_or_else(|| {
117+
DataFusionError::Internal(
118+
"levenshtein threshold must be Int32".to_string(),
119+
)
120+
})?;
121+
122+
let result: Int32Array = left_arr
123+
.iter()
124+
.zip(right_arr.iter())
125+
.enumerate()
126+
.map(|(i, (l, r))| {
127+
// If threshold is NULL for this row, result is NULL
128+
if threshold_arr.is_null(i) {
129+
return None;
136130
}
137-
}
138-
_ => None, // NULL propagation
139-
})
140-
.collect();
131+
match (l, r) {
132+
(Some(l), Some(r)) => {
133+
let dist = levenshtein_distance(l, r);
134+
let t = threshold_arr.value(i);
135+
if dist > t {
136+
Some(-1)
137+
} else {
138+
Some(dist)
139+
}
140+
}
141+
_ => None, // NULL propagation
142+
}
143+
})
144+
.collect();
141145

142-
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
146+
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
147+
} else {
148+
// No threshold: just compute distance
149+
let result: Int32Array = left_arr
150+
.iter()
151+
.zip(right_arr.iter())
152+
.map(|(l, r)| match (l, r) {
153+
(Some(l), Some(r)) => Some(levenshtein_distance(l, r)),
154+
_ => None, // NULL propagation
155+
})
156+
.collect();
157+
158+
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
159+
}
143160
}
144161

145162
#[cfg(test)]
@@ -223,8 +240,108 @@ mod tests {
223240

224241
let result = spark_levenshtein(&[left, right, threshold]).unwrap();
225242
match result {
226-
ColumnarValue::Scalar(ScalarValue::Int32(None)) => {} // NULL threshold -> NULL
227-
_ => panic!("Expected NULL scalar result for NULL threshold"),
243+
ColumnarValue::Array(arr) => {
244+
let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
245+
assert_eq!(int_arr.len(), 1);
246+
assert!(int_arr.is_null(0)); // NULL threshold -> NULL result
247+
}
248+
_ => panic!("Expected array result with NULL for NULL threshold"),
249+
}
250+
}
251+
252+
#[test]
253+
fn test_spark_levenshtein_threshold_as_array() {
254+
// threshold is a column (array) with per-row values
255+
let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![
256+
Some("kitten"),
257+
Some("frog"),
258+
Some("abc"),
259+
Some("hello"),
260+
])));
261+
let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
262+
Some("sitting"),
263+
Some("fog"),
264+
Some("abc"),
265+
Some("world"),
266+
])));
267+
// Per-row thresholds: 2, 5, 0, 3
268+
let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
269+
Some(2),
270+
Some(5),
271+
Some(0),
272+
Some(3),
273+
])));
274+
275+
let result = spark_levenshtein(&[left, right, threshold]).unwrap();
276+
match result {
277+
ColumnarValue::Array(arr) => {
278+
let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
279+
assert_eq!(int_arr.value(0), -1); // kitten->sitting=3 > 2, return -1
280+
assert_eq!(int_arr.value(1), 1); // frog->fog=1 <= 5, return 1
281+
assert_eq!(int_arr.value(2), 0); // abc->abc=0 <= 0, return 0
282+
assert_eq!(int_arr.value(3), -1); // hello->world=4 > 3, return -1
283+
}
284+
_ => panic!("Expected array result"),
285+
}
286+
}
287+
288+
#[test]
289+
fn test_spark_levenshtein_threshold_array_with_nulls() {
290+
// threshold array where some values are NULL
291+
let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![
292+
Some("abc"),
293+
Some("hello"),
294+
Some("frog"),
295+
])));
296+
let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
297+
Some("adc"),
298+
Some("world"),
299+
Some("fog"),
300+
])));
301+
let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
302+
Some(2),
303+
None, // NULL threshold for this row
304+
Some(0),
305+
])));
306+
307+
let result = spark_levenshtein(&[left, right, threshold]).unwrap();
308+
match result {
309+
ColumnarValue::Array(arr) => {
310+
let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
311+
assert_eq!(int_arr.value(0), 1); // abc->adc=1 <= 2, return 1
312+
assert!(int_arr.is_null(1)); // NULL threshold -> NULL
313+
assert_eq!(int_arr.value(2), -1); // frog->fog=1 > 0, return -1
314+
}
315+
_ => panic!("Expected array result"),
316+
}
317+
}
318+
319+
#[test]
320+
fn test_spark_levenshtein_threshold_negative() {
321+
// Negative threshold means distance always exceeds threshold → return -1
322+
let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![
323+
Some("abc"),
324+
Some("abc"),
325+
])));
326+
let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
327+
Some("abc"),
328+
Some("adc"),
329+
])));
330+
let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
331+
Some(-1),
332+
Some(-5),
333+
])));
334+
335+
let result = spark_levenshtein(&[left, right, threshold]).unwrap();
336+
match result {
337+
ColumnarValue::Array(arr) => {
338+
let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
339+
// dist=0 > -1 is true, so return -1
340+
assert_eq!(int_arr.value(0), -1);
341+
// dist=1 > -5 is true, so return -1
342+
assert_eq!(int_arr.value(1), -1);
343+
}
344+
_ => panic!("Expected array result"),
228345
}
229346
}
230347
}

spark/src/test/resources/sql-tests/expressions/string/levenshtein_threshold.sql

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,23 @@ SELECT levenshtein('abc', 'abc', 0), levenshtein('abc', 'adc', 0), levenshtein('
3838
-- threshold with NULL
3939
query
4040
SELECT levenshtein('abc', 'adc', NULL), levenshtein(NULL, 'test', 2)
41+
42+
-- threshold as column
43+
statement
44+
CREATE TABLE test_levenshtein_col(s1 string, s2 string, threshold int) USING parquet
45+
46+
statement
47+
INSERT INTO test_levenshtein_col VALUES ('kitten', 'sitting', 2), ('frog', 'fog', 5), ('abc', 'abc', 0), ('hello', 'world', 3)
48+
49+
query
50+
SELECT levenshtein(s1, s2, threshold) FROM test_levenshtein_col
51+
52+
-- threshold as column with NULLs
53+
statement
54+
CREATE TABLE test_levenshtein_col_nulls(s1 string, s2 string, threshold int) USING parquet
55+
56+
statement
57+
INSERT INTO test_levenshtein_col_nulls VALUES ('abc', 'adc', 2), ('hello', 'world', NULL), (NULL, 'test', 3), ('frog', 'fog', -1)
58+
59+
query
60+
SELECT levenshtein(s1, s2, threshold) FROM test_levenshtein_col_nulls

spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,4 +750,30 @@ class CometStringExpressionSuite extends CometTestBase {
750750
}
751751
}
752752

753+
test("levenshtein with threshold as column") {
754+
assume(isSpark35Plus, "levenshtein with threshold requires Spark 3.5+")
755+
val table = "levenshtein_col_threshold_test"
756+
withTable(table) {
757+
sql(s"CREATE TABLE $table(s1 STRING, s2 STRING, threshold INT) USING parquet")
758+
sql(
759+
s"INSERT INTO $table VALUES " +
760+
"('kitten', 'sitting', 2), ('frog', 'fog', 5), ('abc', 'abc', 0), ('hello', 'world', 3)")
761+
// threshold as column reference
762+
checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2, threshold) FROM $table")
763+
}
764+
}
765+
766+
test("levenshtein with threshold as column with nulls") {
767+
assume(isSpark35Plus, "levenshtein with threshold requires Spark 3.5+")
768+
val table = "levenshtein_col_threshold_null_test"
769+
withTable(table) {
770+
sql(s"CREATE TABLE $table(s1 STRING, s2 STRING, threshold INT) USING parquet")
771+
sql(
772+
s"INSERT INTO $table VALUES " +
773+
"('abc', 'adc', 2), ('hello', 'world', NULL), (NULL, 'test', 3), ('frog', 'fog', -1)")
774+
// NULL threshold and NULL strings should produce NULL; negative threshold returns -1
775+
checkSparkAnswerAndOperator(s"SELECT levenshtein(s1, s2, threshold) FROM $table")
776+
}
777+
}
778+
753779
}

0 commit comments

Comments
 (0)