1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: { Array , StringArray , as_largestring_array } ;
18+ use arrow:: array:: Array ;
1919use std:: any:: Any ;
2020use std:: sync:: Arc ;
2121
@@ -25,7 +25,9 @@ use crate::string::concat;
2525use crate :: string:: concat:: simplify_concat;
2626use crate :: string:: concat_ws;
2727use crate :: strings:: { ColumnarValueRef , StringArrayBuilder } ;
28- use datafusion_common:: cast:: { as_string_array, as_string_view_array} ;
28+ use datafusion_common:: cast:: {
29+ as_large_string_array, as_string_array, as_string_view_array,
30+ } ;
2931use datafusion_common:: { Result , ScalarValue , exec_err, internal_err, plan_err} ;
3032use datafusion_expr:: expr:: ScalarFunction ;
3133use datafusion_expr:: simplify:: { ExprSimplifyResult , SimplifyContext } ;
@@ -105,26 +107,21 @@ impl ScalarUDFImpl for ConcatWsFunc {
105107 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
106108 let ScalarFunctionArgs { args, .. } = args;
107109
108- // do not accept 0 arguments.
109110 if args. len ( ) < 2 {
110111 return exec_err ! (
111112 "concat_ws was called with {} arguments. It requires at least 2." ,
112113 args. len( )
113114 ) ;
114115 }
115116
116- let array_len = args
117- . iter ( )
118- . filter_map ( |x| match x {
119- ColumnarValue :: Array ( array) => Some ( array. len ( ) ) ,
120- _ => None ,
121- } )
122- . next ( ) ;
117+ let array_len = args. iter ( ) . find_map ( |x| match x {
118+ ColumnarValue :: Array ( array) => Some ( array. len ( ) ) ,
119+ _ => None ,
120+ } ) ;
123121
124122 // Scalar
125123 if array_len. is_none ( ) {
126124 let ColumnarValue :: Scalar ( scalar) = & args[ 0 ] else {
127- // loop above checks for all args being scalar
128125 unreachable ! ( )
129126 } ;
130127 let sep = match scalar. try_as_str ( ) {
@@ -139,7 +136,6 @@ impl ScalarUDFImpl for ConcatWsFunc {
139136 let mut values = Vec :: with_capacity ( args. len ( ) - 1 ) ;
140137 for arg in & args[ 1 ..] {
141138 let ColumnarValue :: Scalar ( scalar) = arg else {
142- // loop above checks for all args being scalar
143139 unreachable ! ( )
144140 } ;
145141
@@ -162,23 +158,53 @@ impl ScalarUDFImpl for ConcatWsFunc {
162158
163159 // parse sep
164160 let sep = match & args[ 0 ] {
165- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( s) ) ) => {
166- data_size += s. len ( ) * len * ( args. len ( ) - 2 ) ; // estimate
167- ColumnarValueRef :: Scalar ( s. as_bytes ( ) )
168- }
169- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) => {
170- return Ok ( ColumnarValue :: Array ( Arc :: new ( StringArray :: new_null ( len) ) ) ) ;
171- }
172- ColumnarValue :: Array ( array) => {
173- let string_array = as_string_array ( array) ?;
174- data_size += string_array. values ( ) . len ( ) * ( args. len ( ) - 2 ) ; // estimate
175- if array. is_nullable ( ) {
176- ColumnarValueRef :: NullableArray ( string_array)
177- } else {
178- ColumnarValueRef :: NonNullableArray ( string_array)
161+ ColumnarValue :: Scalar ( scalar) => match scalar. try_as_str ( ) {
162+ Some ( Some ( s) ) => {
163+ data_size += s. len ( ) * len * ( args. len ( ) - 2 ) ; // estimate
164+ ColumnarValueRef :: Scalar ( s. as_bytes ( ) )
179165 }
180- }
181- _ => unreachable ! ( "concat ws" ) ,
166+ Some ( None ) => {
167+ return Ok ( ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) ) ;
168+ }
169+ None => {
170+ return internal_err ! ( "Expected string separator, got {scalar:?}" ) ;
171+ }
172+ } ,
173+ ColumnarValue :: Array ( array) => match array. data_type ( ) {
174+ DataType :: Utf8 => {
175+ let string_array = as_string_array ( array) ?;
176+ data_size += string_array. values ( ) . len ( ) * ( args. len ( ) - 2 ) ;
177+ if array. is_nullable ( ) {
178+ ColumnarValueRef :: NullableArray ( string_array)
179+ } else {
180+ ColumnarValueRef :: NonNullableArray ( string_array)
181+ }
182+ }
183+ DataType :: LargeUtf8 => {
184+ let string_array = as_large_string_array ( array) ?;
185+ data_size += string_array. values ( ) . len ( ) * ( args. len ( ) - 2 ) ;
186+ if array. is_nullable ( ) {
187+ ColumnarValueRef :: NullableLargeStringArray ( string_array)
188+ } else {
189+ ColumnarValueRef :: NonNullableLargeStringArray ( string_array)
190+ }
191+ }
192+ DataType :: Utf8View => {
193+ let string_array = as_string_view_array ( array) ?;
194+ data_size +=
195+ string_array. total_buffer_bytes_used ( ) * ( args. len ( ) - 2 ) ;
196+ if array. is_nullable ( ) {
197+ ColumnarValueRef :: NullableStringViewArray ( string_array)
198+ } else {
199+ ColumnarValueRef :: NonNullableStringViewArray ( string_array)
200+ }
201+ }
202+ other => {
203+ return plan_err ! (
204+ "Input was {other} which is not a supported datatype for concat_ws separator"
205+ ) ;
206+ }
207+ } ,
182208 } ;
183209
184210 let mut columns = Vec :: with_capacity ( args. len ( ) - 1 ) ;
@@ -206,7 +232,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
206232 columns. push ( column) ;
207233 }
208234 DataType :: LargeUtf8 => {
209- let string_array = as_largestring_array ( array) ;
235+ let string_array = as_large_string_array ( array) ? ;
210236
211237 data_size += string_array. values ( ) . len ( ) ;
212238 let column = if array. is_nullable ( ) {
@@ -221,11 +247,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
221247 DataType :: Utf8View => {
222248 let string_array = as_string_view_array ( array) ?;
223249
224- data_size += string_array
225- . data_buffers ( )
226- . iter ( )
227- . map ( |buf| buf. len ( ) )
228- . sum :: < usize > ( ) ;
250+ data_size += string_array. total_buffer_bytes_used ( ) ;
229251 let column = if array. is_nullable ( ) {
230252 ColumnarValueRef :: NullableStringViewArray ( string_array)
231253 } else {
@@ -251,18 +273,14 @@ impl ScalarUDFImpl for ConcatWsFunc {
251273 continue ;
252274 }
253275
254- let mut iter = columns . iter ( ) ;
255- for column in iter . by_ref ( ) {
276+ let mut first = true ;
277+ for column in & columns {
256278 if column. is_valid ( i) {
279+ if !first {
280+ builder. write :: < false > ( & sep, i) ;
281+ }
257282 builder. write :: < false > ( column, i) ;
258- break ;
259- }
260- }
261-
262- for column in iter {
263- if column. is_valid ( i) {
264- builder. write :: < false > ( & sep, i) ;
265- builder. write :: < false > ( column, i) ;
283+ first = false ;
266284 }
267285 }
268286
@@ -546,4 +564,78 @@ mod tests {
546564
547565 Ok ( ( ) )
548566 }
567+
568+ #[ test]
569+ fn concat_ws_utf8view_scalar_separator ( ) -> Result < ( ) > {
570+ let c0 = ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( "," . to_string ( ) ) ) ) ;
571+ let c1 =
572+ ColumnarValue :: Array ( Arc :: new ( StringArray :: from ( vec ! [ "foo" , "bar" , "baz" ] ) ) ) ;
573+ let c2 = ColumnarValue :: Array ( Arc :: new ( StringArray :: from ( vec ! [
574+ Some ( "x" ) ,
575+ None ,
576+ Some ( "z" ) ,
577+ ] ) ) ) ;
578+
579+ let arg_fields = vec ! [
580+ Field :: new( "a" , Utf8 , true ) . into( ) ,
581+ Field :: new( "a" , Utf8 , true ) . into( ) ,
582+ Field :: new( "a" , Utf8 , true ) . into( ) ,
583+ ] ;
584+ let args = ScalarFunctionArgs {
585+ args : vec ! [ c0, c1, c2] ,
586+ arg_fields,
587+ number_rows : 3 ,
588+ return_field : Field :: new ( "f" , Utf8 , true ) . into ( ) ,
589+ config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
590+ } ;
591+
592+ let result = ConcatWsFunc :: new ( ) . invoke_with_args ( args) ?;
593+ let expected =
594+ Arc :: new ( StringArray :: from ( vec ! [ "foo,x" , "bar" , "baz,z" ] ) ) as ArrayRef ;
595+ match & result {
596+ ColumnarValue :: Array ( array) => {
597+ assert_eq ! ( & expected, array) ;
598+ }
599+ _ => panic ! ( "Expected array result" ) ,
600+ }
601+
602+ Ok ( ( ) )
603+ }
604+
605+ #[ test]
606+ fn concat_ws_largeutf8_scalar_separator ( ) -> Result < ( ) > {
607+ let c0 = ColumnarValue :: Scalar ( ScalarValue :: LargeUtf8 ( Some ( "," . to_string ( ) ) ) ) ;
608+ let c1 =
609+ ColumnarValue :: Array ( Arc :: new ( StringArray :: from ( vec ! [ "foo" , "bar" , "baz" ] ) ) ) ;
610+ let c2 = ColumnarValue :: Array ( Arc :: new ( StringArray :: from ( vec ! [
611+ Some ( "x" ) ,
612+ None ,
613+ Some ( "z" ) ,
614+ ] ) ) ) ;
615+
616+ let arg_fields = vec ! [
617+ Field :: new( "a" , Utf8 , true ) . into( ) ,
618+ Field :: new( "a" , Utf8 , true ) . into( ) ,
619+ Field :: new( "a" , Utf8 , true ) . into( ) ,
620+ ] ;
621+ let args = ScalarFunctionArgs {
622+ args : vec ! [ c0, c1, c2] ,
623+ arg_fields,
624+ number_rows : 3 ,
625+ return_field : Field :: new ( "f" , Utf8 , true ) . into ( ) ,
626+ config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
627+ } ;
628+
629+ let result = ConcatWsFunc :: new ( ) . invoke_with_args ( args) ?;
630+ let expected =
631+ Arc :: new ( StringArray :: from ( vec ! [ "foo,x" , "bar" , "baz,z" ] ) ) as ArrayRef ;
632+ match & result {
633+ ColumnarValue :: Array ( array) => {
634+ assert_eq ! ( & expected, array) ;
635+ }
636+ _ => panic ! ( "Expected array result" ) ,
637+ }
638+
639+ Ok ( ( ) )
640+ }
549641}
0 commit comments