1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use arrow:: array:: { Array , ArrayRef , Int64Builder } ;
18+ use arrow:: array:: { Array , ArrayData , ArrayRef , Int64Builder , ListArray } ;
1919use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2020use datafusion_common:: cast:: { as_int64_array, as_list_array} ;
2121use datafusion_common:: utils:: ListCoercion ;
22- use datafusion_common:: {
23- Result , ScalarValue , exec_err, internal_err, utils:: take_function_args,
24- } ;
22+ use datafusion_common:: { Result , exec_err, internal_err, utils:: take_function_args} ;
2523use datafusion_expr:: {
2624 ArrayFunctionArgument , ArrayFunctionSignature , ColumnarValue , ReturnFieldArgs ,
2725 ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature , Volatility ,
@@ -85,21 +83,26 @@ impl ScalarUDFImpl for SparkSlice {
8583 fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < FieldRef > {
8684 let nullable = args. arg_fields . iter ( ) . any ( |f| f. is_nullable ( ) ) ;
8785
88- Ok ( Arc :: new ( Field :: new (
89- "slice" ,
90- args. arg_fields [ 0 ] . data_type ( ) . clone ( ) ,
91- nullable,
92- ) ) )
86+ let data_type = match args. arg_fields [ 0 ] . data_type ( ) {
87+ DataType :: Null => {
88+ DataType :: List ( Arc :: new ( Field :: new_list_field ( DataType :: Null , true ) ) )
89+ }
90+ dt => dt. clone ( ) ,
91+ } ;
92+
93+ Ok ( Arc :: new ( Field :: new ( "slice" , data_type, nullable) ) )
9394 }
9495
9596 fn invoke_with_args (
9697 & self ,
9798 mut func_args : ScalarFunctionArgs ,
9899 ) -> Result < ColumnarValue > {
99- if func_args. args [ 0 ] . data_type ( ) == DataType :: Null
100- && let Some ( result) = check_null_types ( & func_args. args [ 0 ] )
101- {
102- return Ok ( result) ;
100+ if func_args. args [ 0 ] . data_type ( ) == DataType :: Null {
101+ let len = match & func_args. args [ 0 ] {
102+ ColumnarValue :: Array ( a) => a. len ( ) ,
103+ ColumnarValue :: Scalar ( _) => func_args. number_rows ,
104+ } ;
105+ return Ok ( ColumnarValue :: Array ( list_null_array ( len) ) ) ;
103106 }
104107
105108 let array_len = func_args
@@ -136,14 +139,9 @@ impl ScalarUDFImpl for SparkSlice {
136139 }
137140}
138141
139- fn check_null_types ( cv : & ColumnarValue ) -> Option < ColumnarValue > {
140- match cv {
141- ColumnarValue :: Scalar ( ScalarValue :: Null ) => {
142- Some ( ColumnarValue :: create_null_array ( 1 ) )
143- }
144- ColumnarValue :: Array ( _) => Some ( cv. clone ( ) ) ,
145- _ => None ,
146- }
142+ fn list_null_array ( len : usize ) -> ArrayRef {
143+ let list_type = DataType :: List ( Arc :: new ( Field :: new_list_field ( DataType :: Null , true ) ) ) ;
144+ Arc :: new ( ListArray :: from ( ArrayData :: new_null ( & list_type, len) ) )
147145}
148146
149147fn calculate_start_end ( args : & [ ArrayRef ] ) -> Result < ( ArrayRef , ArrayRef ) > {
@@ -193,9 +191,30 @@ fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
193191mod tests {
194192 use super :: * ;
195193 use arrow:: array:: NullArray ;
196- use arrow:: datatypes:: DataType :: List ;
197194 use arrow:: datatypes:: Field ;
198195 use datafusion_common:: ScalarValue ;
196+ use datafusion_common:: cast:: as_list_array;
197+ use datafusion_expr:: ReturnFieldArgs ;
198+
199+ #[ test]
200+ fn test_spark_slice_function_when_input_is_null ( ) {
201+ let slice = SparkSlice :: new ( ) ;
202+ let arg_fields: Vec < Arc < Field > > = vec ! [
203+ Arc :: new( Field :: new( "a" , DataType :: Null , true ) ) ,
204+ Arc :: new( Field :: new( "s" , DataType :: Int64 , true ) ) ,
205+ Arc :: new( Field :: new( "l" , DataType :: Int64 , true ) ) ,
206+ ] ;
207+ let out = slice
208+ . return_field_from_args ( ReturnFieldArgs {
209+ arg_fields : & arg_fields,
210+ scalar_arguments : & [ ] ,
211+ } )
212+ . unwrap ( ) ;
213+ assert_eq ! (
214+ out. data_type( ) ,
215+ & DataType :: List ( Arc :: new( Field :: new_list_field( DataType :: Null , true ) ) )
216+ ) ;
217+ }
199218
200219 #[ test]
201220 fn test_spark_slice_function_when_input_array_is_null ( ) {
@@ -207,21 +226,23 @@ mod tests {
207226
208227 let args = ScalarFunctionArgs {
209228 args : input_args,
210- arg_fields : vec ! [ Arc :: new( Field :: new(
211- "item" ,
212- List ( FieldRef :: new( Field :: new( "f" , DataType :: Int64 , true ) ) ) ,
213- false ,
214- ) ) ] ,
229+ arg_fields : vec ! [ Arc :: new( Field :: new( "item" , DataType :: Null , true ) ) ] ,
215230 number_rows : 1 ,
216231 return_field : Arc :: new ( Field :: new (
217- "item " ,
218- List ( FieldRef :: new ( Field :: new_list_field ( DataType :: Int64 , true ) ) ) ,
219- false ,
232+ "slice " ,
233+ DataType :: List ( Arc :: new ( Field :: new_list_field ( DataType :: Null , true ) ) ) ,
234+ true ,
220235 ) ) ,
221236 config_options : Arc :: new ( Default :: default ( ) ) ,
222237 } ;
223238 let slice = SparkSlice :: new ( ) ;
224239 let result = slice. invoke_with_args ( args) . unwrap ( ) ;
225- assert_eq ! ( * result. to_array( 1 ) . unwrap( ) , * Arc :: new( NullArray :: new( 1 ) ) ) ;
240+ let arr = result. to_array ( 1 ) . unwrap ( ) ;
241+ let list = as_list_array ( & arr) . unwrap ( ) ;
242+ assert_eq ! (
243+ arr. data_type( ) ,
244+ & DataType :: List ( Arc :: new( Field :: new_list_field( DataType :: Null , true ) ) )
245+ ) ;
246+ assert ! ( list. is_null( 0 ) ) ;
226247 }
227248}
0 commit comments