diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index b58fa911a9e2f..0a83eb3ed61ef 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -18,10 +18,10 @@ use std::sync::Arc; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, - PrimitiveArray, + ArrayAccessor, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use arrow_buffer::NullBuffer; use crate::utils::utf8_to_int_type; use datafusion_common::{ @@ -265,53 +265,55 @@ fn find_in_set_general<'a, T, V>(string_array: V, str_list_array: V) -> Result, + V: ArrayAccessor + Copy, { - let string_iter = ArrayIter::new(string_array); - let str_list_iter = ArrayIter::new(str_list_array); - - let mut builder = PrimitiveArray::::builder(string_iter.len()); - - string_iter - .zip(str_list_iter) - .for_each( - |(string_opt, str_list_opt)| match (string_opt, str_list_opt) { - (Some(string), Some(str_list)) => { - let position = str_list - .split(',') - .position(|s| s == string) - .map_or(0, |idx| idx + 1); - builder.append_value(T::Native::from_usize(position).unwrap()); - } - _ => builder.append_null(), - }, - ); + let len = string_array.len(); + let nulls = NullBuffer::union(string_array.nulls(), str_list_array.nulls()); + let zero = T::Native::from_usize(0).unwrap(); + + let values: Vec = (0..len) + .map(|i| { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + return zero; + } + let string = string_array.value(i); + let str_list = str_list_array.value(i); + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + T::Native::from_usize(position).unwrap() + }) + .collect(); - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(PrimitiveArray::::new(values.into(), nulls)) as ArrayRef) } fn find_in_set_left_literal<'a, T, V>(string: &str, str_list_array: V) -> Result where T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, - V: ArrayAccessor, + V: ArrayAccessor + Copy, { - let mut builder = PrimitiveArray::::builder(str_list_array.len()); - - let str_list_iter = ArrayIter::new(str_list_array); - - str_list_iter.for_each(|str_list_opt| match str_list_opt { - Some(str_list) => { + let len = str_list_array.len(); + let nulls = str_list_array.nulls().cloned(); + let zero = T::Native::from_usize(0).unwrap(); + + let values: Vec = (0..len) + .map(|i| { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + return zero; + } + let str_list = str_list_array.value(i); let position = str_list .split(',') .position(|s| s == string) .map_or(0, |idx| idx + 1); - builder.append_value(T::Native::from_usize(position).unwrap()); - } - None => builder.append_null(), - }); + T::Native::from_usize(position).unwrap() + }) + .collect(); - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(PrimitiveArray::::new(values.into(), nulls)) as ArrayRef) } fn find_in_set_right_literal<'a, T, V>( @@ -321,24 +323,27 @@ fn find_in_set_right_literal<'a, T, V>( where T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, - V: ArrayAccessor, + V: ArrayAccessor + Copy, { - let mut builder = PrimitiveArray::::builder(string_array.len()); - - let string_iter = ArrayIter::new(string_array); - - string_iter.for_each(|string_opt| match string_opt { - Some(string) => { + let len = string_array.len(); + let nulls = string_array.nulls().cloned(); + let zero = T::Native::from_usize(0).unwrap(); + + let values: Vec = (0..len) + .map(|i| { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + return zero; + } + let string = string_array.value(i); let position = str_list .iter() .position(|s| *s == string) .map_or(0, |idx| idx + 1); - builder.append_value(T::Native::from_usize(position).unwrap()); - } - None => builder.append_null(), - }); + T::Native::from_usize(position).unwrap() + }) + .collect(); - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(PrimitiveArray::::new(values.into(), nulls)) as ArrayRef) } #[cfg(test)]