1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use crate :: strings:: {
19+ BulkNullStringArrayBuilder , GenericStringArrayBuilder , StringViewArrayBuilder ,
20+ } ;
1821use crate :: utils:: utf8_to_str_type;
1922use arrow:: array:: {
20- Array , ArrayRef , AsArray , ByteView , GenericStringBuilder , Int64Array ,
21- StringArrayType , StringLikeArrayBuilder , StringViewArray , StringViewBuilder ,
23+ Array , ArrayRef , AsArray , ByteView , Int64Array , StringArrayType , StringViewArray ,
2224 make_view, new_null_array,
2325} ;
24- use arrow:: buffer:: ScalarBuffer ;
26+ use arrow:: buffer:: { NullBuffer , ScalarBuffer } ;
2527use arrow:: datatypes:: DataType ;
2628use datafusion_common:: ScalarValue ;
2729use datafusion_common:: cast:: as_int64_array;
@@ -167,7 +169,7 @@ impl ScalarUDFImpl for SplitPartFunc {
167169 let result = match args[ 0 ] . data_type ( ) {
168170 DataType :: Utf8View => split_part_for_delimiter_type ! (
169171 & args[ 0 ] . as_string_view( ) ,
170- StringViewBuilder :: with_capacity( inferred_length)
172+ StringViewArrayBuilder :: with_capacity( inferred_length)
171173 ) ,
172174 DataType :: Utf8 => {
173175 let str_arr = & args[ 0 ] . as_string :: < i32 > ( ) ;
@@ -176,7 +178,7 @@ impl ScalarUDFImpl for SplitPartFunc {
176178 // pre-allocating the full input data size.
177179 split_part_for_delimiter_type ! (
178180 str_arr,
179- GenericStringBuilder :: <i32 >:: with_capacity(
181+ GenericStringArrayBuilder :: <i32 >:: with_capacity(
180182 inferred_length,
181183 inferred_length,
182184 )
@@ -187,7 +189,7 @@ impl ScalarUDFImpl for SplitPartFunc {
187189 // Conservative under-estimate; see Utf8 comment above.
188190 split_part_for_delimiter_type ! (
189191 str_arr,
190- GenericStringBuilder :: <i64 >:: with_capacity(
192+ GenericStringArrayBuilder :: <i64 >:: with_capacity(
191193 inferred_length,
192194 inferred_length,
193195 )
@@ -293,7 +295,7 @@ fn split_part_scalar(
293295 arr,
294296 delimiter,
295297 position,
296- GenericStringBuilder :: < i32 > :: with_capacity ( arr. len ( ) , arr. len ( ) ) ,
298+ GenericStringArrayBuilder :: < i32 > :: with_capacity ( arr. len ( ) , arr. len ( ) ) ,
297299 )
298300 }
299301 DataType :: LargeUtf8 => {
@@ -303,7 +305,7 @@ fn split_part_scalar(
303305 arr,
304306 delimiter,
305307 position,
306- GenericStringBuilder :: < i64 > :: with_capacity ( arr. len ( ) , arr. len ( ) ) ,
308+ GenericStringArrayBuilder :: < i64 > :: with_capacity ( arr. len ( ) , arr. len ( ) ) ,
307309 )
308310 }
309311 other => exec_err ! ( "Unsupported string type {other:?} for split_part" ) ,
@@ -323,7 +325,7 @@ fn split_part_scalar_impl<'a, S, B>(
323325) -> Result < ArrayRef >
324326where
325327 S : StringArrayType < ' a > + Copy ,
326- B : StringLikeArrayBuilder ,
328+ B : BulkNullStringArrayBuilder ,
327329{
328330 if delimiter. is_empty ( ) {
329331 // PostgreSQL: empty delimiter treats input as a single field,
@@ -367,16 +369,31 @@ where
367369fn map_strings < ' a , S , B , F > ( string_array : S , mut builder : B , f : F ) -> Result < ArrayRef >
368370where
369371 S : StringArrayType < ' a > + Copy ,
370- B : StringLikeArrayBuilder ,
372+ B : BulkNullStringArrayBuilder ,
371373 F : Fn ( & ' a str ) -> Option < & ' a str > ,
372374{
373- for string in string_array. iter ( ) {
374- match string {
375- Some ( s) => builder. append_value ( f ( s) . unwrap_or ( "" ) ) ,
376- None => builder. append_null ( ) ,
375+ let item_len = string_array. len ( ) ;
376+ let nulls = string_array. nulls ( ) . cloned ( ) ;
377+
378+ if let Some ( ref n) = nulls {
379+ for i in 0 ..item_len {
380+ if n. is_null ( i) {
381+ builder. append_placeholder ( ) ;
382+ } else {
383+ // SAFETY: `n.is_null(i)` was false in the branch above.
384+ let s = unsafe { string_array. value_unchecked ( i) } ;
385+ builder. append_value ( f ( s) . unwrap_or ( "" ) ) ;
386+ }
387+ }
388+ } else {
389+ for i in 0 ..item_len {
390+ // SAFETY: no null buffer means every index is valid.
391+ let s = unsafe { string_array. value_unchecked ( i) } ;
392+ builder. append_value ( f ( s) . unwrap_or ( "" ) ) ;
377393 }
378394 }
379- Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
395+
396+ builder. finish ( nulls)
380397}
381398
382399/// Finds the `n`th (0-based) split part using a pre-built `memmem::Finder`.
@@ -543,58 +560,82 @@ fn split_part_impl<'a, StringArrType, DelimiterArrType, B>(
543560where
544561 StringArrType : StringArrayType < ' a > ,
545562 DelimiterArrType : StringArrayType < ' a > ,
546- B : StringLikeArrayBuilder ,
563+ B : BulkNullStringArrayBuilder ,
547564{
548- for ( ( string, delimiter) , n) in string_array
549- . iter ( )
550- . zip ( delimiter_array. iter ( ) )
551- . zip ( n_array. iter ( ) )
552- {
553- match ( string, delimiter, n) {
554- ( Some ( string) , Some ( delimiter) , Some ( n) ) => {
555- let result = match n. cmp ( & 0 ) {
556- std:: cmp:: Ordering :: Greater => {
557- let idx: usize = ( n - 1 ) . try_into ( ) . map_err ( |_| {
558- exec_datafusion_err ! (
559- "split_part index {n} exceeds maximum supported value"
560- )
561- } ) ?;
562- if delimiter. is_empty ( ) {
563- // Match PostgreSQL's behavior: empty delimiter
564- // treats input as a single field, so only position
565- // 1 returns data.
566- ( n == 1 ) . then_some ( string)
567- } else {
568- split_nth ( string, delimiter, idx)
569- }
570- }
571- std:: cmp:: Ordering :: Less => {
572- let idx: usize =
573- ( n. unsigned_abs ( ) - 1 ) . try_into ( ) . map_err ( |_| {
574- exec_datafusion_err ! (
575- "split_part index {n} exceeds minimum supported value"
576- )
577- } ) ?;
578- if delimiter. is_empty ( ) {
579- // Match PostgreSQL's behavior: empty delimiter
580- // treats input as a single field, so only position
581- // -1 returns data.
582- ( n == -1 ) . then_some ( string)
583- } else {
584- rsplit_nth ( string, delimiter, idx)
585- }
586- }
587- std:: cmp:: Ordering :: Equal => {
588- return exec_err ! ( "field position must not be zero" ) ;
589- }
590- } ;
591- builder. append_value ( result. unwrap_or ( "" ) ) ;
565+ let nulls = NullBuffer :: union_many ( [
566+ string_array. nulls ( ) ,
567+ delimiter_array. nulls ( ) ,
568+ n_array. nulls ( ) ,
569+ ] ) ;
570+
571+ if let Some ( ref n) = nulls {
572+ for i in 0 ..string_array. len ( ) {
573+ if n. is_null ( i) {
574+ builder. append_placeholder ( ) ;
575+ continue ;
592576 }
593- _ => builder. append_null ( ) ,
577+
578+ // SAFETY: the union null buffer is valid at `i`, so each input is valid.
579+ let string = unsafe { string_array. value_unchecked ( i) } ;
580+ let delimiter = unsafe { delimiter_array. value_unchecked ( i) } ;
581+ let position = unsafe { n_array. value_unchecked ( i) } ;
582+ append_split_part ( string, delimiter, position, & mut builder) ?;
583+ }
584+ } else {
585+ for i in 0 ..string_array. len ( ) {
586+ // SAFETY: no input has a null buffer, so every index is valid.
587+ let string = unsafe { string_array. value_unchecked ( i) } ;
588+ let delimiter = unsafe { delimiter_array. value_unchecked ( i) } ;
589+ let position = unsafe { n_array. value_unchecked ( i) } ;
590+ append_split_part ( string, delimiter, position, & mut builder) ?;
594591 }
595592 }
596593
597- Ok ( Arc :: new ( builder. finish ( ) ) as ArrayRef )
594+ builder. finish ( nulls)
595+ }
596+
597+ #[ inline]
598+ fn append_split_part < B : BulkNullStringArrayBuilder > (
599+ string : & str ,
600+ delimiter : & str ,
601+ n : i64 ,
602+ builder : & mut B ,
603+ ) -> Result < ( ) > {
604+ let result = match n. cmp ( & 0 ) {
605+ std:: cmp:: Ordering :: Greater => {
606+ let idx: usize = ( n - 1 ) . try_into ( ) . map_err ( |_| {
607+ exec_datafusion_err ! (
608+ "split_part index {n} exceeds maximum supported value"
609+ )
610+ } ) ?;
611+ if delimiter. is_empty ( ) {
612+ // Match PostgreSQL's behavior: empty delimiter treats input
613+ // as a single field, so only position 1 returns data.
614+ ( n == 1 ) . then_some ( string)
615+ } else {
616+ split_nth ( string, delimiter, idx)
617+ }
618+ }
619+ std:: cmp:: Ordering :: Less => {
620+ let idx: usize = ( n. unsigned_abs ( ) - 1 ) . try_into ( ) . map_err ( |_| {
621+ exec_datafusion_err ! (
622+ "split_part index {n} exceeds minimum supported value"
623+ )
624+ } ) ?;
625+ if delimiter. is_empty ( ) {
626+ // Match PostgreSQL's behavior: empty delimiter treats input
627+ // as a single field, so only position -1 returns data.
628+ ( n == -1 ) . then_some ( string)
629+ } else {
630+ rsplit_nth ( string, delimiter, idx)
631+ }
632+ }
633+ std:: cmp:: Ordering :: Equal => {
634+ return exec_err ! ( "field position must not be zero" ) ;
635+ }
636+ } ;
637+ builder. append_value ( result. unwrap_or ( "" ) ) ;
638+ Ok ( ( ) )
598639}
599640
600641#[ cfg( test) ]
0 commit comments