@@ -82,7 +82,8 @@ fn levenshtein_distance(s: &str, t: &str) -> i32 {
8282/// - `levenshtein(str1, str2)` → edit distance
8383/// - `levenshtein(str1, str2, threshold)` → edit distance if <= threshold, else -1
8484///
85- /// NULL inputs produce NULL outputs. NULL threshold produces NULL output.
85+ /// The threshold argument can be either a scalar or a column (array).
86+ /// NULL inputs produce NULL outputs. NULL threshold produces NULL output for that row.
8687pub fn spark_levenshtein ( args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
8788 if args. len ( ) < 2 || args. len ( ) > 3 {
8889 return Err ( DataFusionError :: Internal ( format ! (
@@ -91,27 +92,9 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
9192 ) ) ) ;
9293 }
9394
94- // Extract optional threshold (3rd argument must be a scalar Int32)
95- let threshold: Option < i32 > = if args. len ( ) == 3 {
96- match & args[ 2 ] {
97- ColumnarValue :: Scalar ( ScalarValue :: Int32 ( t) ) => match t {
98- Some ( val) => Some ( * val) ,
99- None => return Ok ( ColumnarValue :: Scalar ( ScalarValue :: Int32 ( None ) ) ) ,
100- } ,
101- _ => {
102- return Err ( DataFusionError :: Internal (
103- "levenshtein threshold must be an Int32 scalar" . to_string ( ) ,
104- ) ) ;
105- }
106- }
107- } else {
108- None
109- } ;
110-
111- // Expand scalars to arrays for uniform processing
95+ // Determine array length from any array argument
11296 let len = args
11397 . iter ( )
114- . take ( 2 )
11598 . find_map ( |arg| match arg {
11699 ColumnarValue :: Array ( a) => Some ( a. len ( ) ) ,
117100 _ => None ,
@@ -124,22 +107,56 @@ pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
124107 let left_arr = as_string_array ( & left) ;
125108 let right_arr = as_string_array ( & right) ;
126109
127- let result: Int32Array = left_arr
128- . iter ( )
129- . zip ( right_arr. iter ( ) )
130- . map ( |( l, r) | match ( l, r) {
131- ( Some ( l) , Some ( r) ) => {
132- let dist = levenshtein_distance ( l, r) ;
133- match threshold {
134- Some ( t) if dist > t => Some ( -1 ) ,
135- _ => Some ( dist) ,
110+ // Handle the optional threshold argument (scalar or array)
111+ if args. len ( ) == 3 {
112+ let threshold_array = args[ 2 ] . clone ( ) . into_array ( len) ?;
113+ let threshold_arr = threshold_array
114+ . as_any ( )
115+ . downcast_ref :: < Int32Array > ( )
116+ . ok_or_else ( || {
117+ DataFusionError :: Internal (
118+ "levenshtein threshold must be Int32" . to_string ( ) ,
119+ )
120+ } ) ?;
121+
122+ let result: Int32Array = left_arr
123+ . iter ( )
124+ . zip ( right_arr. iter ( ) )
125+ . enumerate ( )
126+ . map ( |( i, ( l, r) ) | {
127+ // If threshold is NULL for this row, result is NULL
128+ if threshold_arr. is_null ( i) {
129+ return None ;
136130 }
137- }
138- _ => None , // NULL propagation
139- } )
140- . collect ( ) ;
131+ match ( l, r) {
132+ ( Some ( l) , Some ( r) ) => {
133+ let dist = levenshtein_distance ( l, r) ;
134+ let t = threshold_arr. value ( i) ;
135+ if dist > t {
136+ Some ( -1 )
137+ } else {
138+ Some ( dist)
139+ }
140+ }
141+ _ => None , // NULL propagation
142+ }
143+ } )
144+ . collect ( ) ;
141145
142- Ok ( ColumnarValue :: Array ( Arc :: new ( result) as ArrayRef ) )
146+ Ok ( ColumnarValue :: Array ( Arc :: new ( result) as ArrayRef ) )
147+ } else {
148+ // No threshold: just compute distance
149+ let result: Int32Array = left_arr
150+ . iter ( )
151+ . zip ( right_arr. iter ( ) )
152+ . map ( |( l, r) | match ( l, r) {
153+ ( Some ( l) , Some ( r) ) => Some ( levenshtein_distance ( l, r) ) ,
154+ _ => None , // NULL propagation
155+ } )
156+ . collect ( ) ;
157+
158+ Ok ( ColumnarValue :: Array ( Arc :: new ( result) as ArrayRef ) )
159+ }
143160}
144161
145162#[ cfg( test) ]
@@ -223,8 +240,108 @@ mod tests {
223240
224241 let result = spark_levenshtein ( & [ left, right, threshold] ) . unwrap ( ) ;
225242 match result {
226- ColumnarValue :: Scalar ( ScalarValue :: Int32 ( None ) ) => { } // NULL threshold -> NULL
227- _ => panic ! ( "Expected NULL scalar result for NULL threshold" ) ,
243+ ColumnarValue :: Array ( arr) => {
244+ let int_arr = arr. as_any ( ) . downcast_ref :: < Int32Array > ( ) . unwrap ( ) ;
245+ assert_eq ! ( int_arr. len( ) , 1 ) ;
246+ assert ! ( int_arr. is_null( 0 ) ) ; // NULL threshold -> NULL result
247+ }
248+ _ => panic ! ( "Expected array result with NULL for NULL threshold" ) ,
249+ }
250+ }
251+
#[test]
fn test_spark_levenshtein_threshold_as_array() {
    // The threshold is supplied as a column (array), so each row is checked
    // against its own limit.
    let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![
        Some("kitten"),
        Some("frog"),
        Some("abc"),
        Some("hello"),
    ])));
    let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
        Some("sitting"),
        Some("fog"),
        Some("abc"),
        Some("world"),
    ])));
    // Per-row thresholds: 2, 5, 0, 3
    let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(2),
        Some(5),
        Some(0),
        Some(3),
    ])));

    let result = spark_levenshtein(&[left, right, threshold]).unwrap();
    match result {
        ColumnarValue::Array(arr) => {
            let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
            // Compare the whole array instead of reading `value(i)` per row:
            // `value` does not look at the validity bitmap, so an unexpected
            // NULL slot could be mistaken for a legitimate value (row 2
            // expects 0). The array comparison also pins the output length.
            let expected = Int32Array::from(vec![
                Some(-1), // kitten->sitting = 3 > 2, return -1
                Some(1),  // frog->fog = 1 <= 5, return 1
                Some(0),  // abc->abc = 0 <= 0, return 0
                Some(-1), // hello->world = 4 > 3, return -1
            ]);
            assert_eq!(int_arr, &expected);
        }
        _ => panic!("Expected array result"),
    }
}
287+
#[test]
fn test_spark_levenshtein_threshold_array_with_nulls() {
    // Threshold column in which one row is NULL; only that row's result
    // should be NULL.
    let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![
        Some("abc"),
        Some("hello"),
        Some("frog"),
    ])));
    let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
        Some("adc"),
        Some("world"),
        Some("fog"),
    ])));
    let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(2),
        None, // NULL threshold for this row
        Some(0),
    ])));

    let result = spark_levenshtein(&[left, right, threshold]).unwrap();
    match result {
        ColumnarValue::Array(arr) => {
            let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
            // Whole-array comparison checks validity, value and length for
            // every row; reading `value(i)` alone would not notice a row that
            // is unexpectedly NULL.
            let expected = Int32Array::from(vec![
                Some(1),  // abc->adc = 1 <= 2, return 1
                None,     // NULL threshold -> NULL
                Some(-1), // frog->fog = 1 > 0, return -1
            ]);
            assert_eq!(int_arr, &expected);
        }
        _ => panic!("Expected array result"),
    }
}
318+
#[test]
fn test_spark_levenshtein_threshold_negative() {
    // A negative threshold is always exceeded, because an edit distance is
    // never negative, so every non-NULL row returns -1.
    let left = ColumnarValue::Array(Arc::new(StringArray::from(vec![
        Some("abc"),
        Some("abc"),
    ])));
    let right = ColumnarValue::Array(Arc::new(StringArray::from(vec![
        Some("abc"),
        Some("adc"),
    ])));
    let threshold = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
        Some(-1),
        Some(-5),
    ])));

    let result = spark_levenshtein(&[left, right, threshold]).unwrap();
    match result {
        ColumnarValue::Array(arr) => {
            let int_arr = arr.as_any().downcast_ref::<Int32Array>().unwrap();
            // Whole-array comparison so that length and validity are checked
            // along with the values.
            let expected = Int32Array::from(vec![
                Some(-1), // dist = 0 > -1, return -1
                Some(-1), // dist = 1 > -5, return -1
            ]);
            assert_eq!(int_arr, &expected);
        }
        _ => panic!("Expected array result"),
    }
}
230347}
0 commit comments