@@ -19,7 +19,7 @@ use std::any::Any;
1919use std:: str:: from_utf8_unchecked;
2020use std:: sync:: Arc ;
2121
22- use arrow:: array:: { Array , BinaryArray , Int64Array , StringArray , StringBuilder } ;
22+ use arrow:: array:: { Array , ArrayRef , StringBuilder } ;
2323use arrow:: datatypes:: DataType ;
2424use arrow:: {
2525 array:: { as_dictionary_array, as_largestring_array, as_string_array} ,
@@ -92,11 +92,13 @@ impl ScalarUDFImpl for SparkHex {
9292 & self . signature
9393 }
9494
95- fn return_type (
96- & self ,
97- _arg_types : & [ DataType ] ,
98- ) -> datafusion_common:: Result < DataType > {
99- Ok ( DataType :: Utf8 )
95+ fn return_type ( & self , arg_types : & [ DataType ] ) -> datafusion_common:: Result < DataType > {
96+ Ok ( match & arg_types[ 0 ] {
97+ DataType :: Dictionary ( key_type, _) => {
98+ DataType :: Dictionary ( key_type. clone ( ) , Box :: new ( DataType :: Utf8 ) )
99+ }
100+ _ => DataType :: Utf8 ,
101+ } )
100102 }
101103
102104 fn invoke_with_args (
@@ -136,7 +138,7 @@ fn hex_encode_bytes<'a, I, T>(
136138 iter : I ,
137139 lowercase : bool ,
138140 len : usize ,
139- ) -> Result < ColumnarValue , DataFusionError >
141+ ) -> Result < ArrayRef , DataFusionError >
140142where
141143 I : Iterator < Item = Option < T > > ,
142144 T : AsRef < [ u8 ] > + ' a ,
@@ -166,14 +168,14 @@ where
166168 }
167169 }
168170
169- Ok ( ColumnarValue :: Array ( Arc :: new ( builder. finish ( ) ) ) )
171+ Ok ( Arc :: new ( builder. finish ( ) ) )
170172}
171173
172174/// Generic hex encoding for int64 type
173- fn hex_encode_int64 < I > ( iter : I , len : usize ) -> Result < ColumnarValue , DataFusionError >
174- where
175- I : Iterator < Item = Option < i64 > > ,
176- {
175+ fn hex_encode_int64 (
176+ iter : impl Iterator < Item = Option < i64 > > ,
177+ len : usize ,
178+ ) -> Result < ArrayRef , DataFusionError > {
177179 let mut builder = StringBuilder :: with_capacity ( len, len * 16 ) ;
178180
179181 for v in iter {
@@ -189,7 +191,7 @@ where
189191 }
190192 }
191193
192- Ok ( ColumnarValue :: Array ( Arc :: new ( builder. finish ( ) ) ) )
194+ Ok ( Arc :: new ( builder. finish ( ) ) )
193195}
194196
195197/// Spark-compatible `hex` function
@@ -215,55 +217,109 @@ pub fn compute_hex(
215217 ColumnarValue :: Array ( array) => match array. data_type ( ) {
216218 DataType :: Int64 => {
217219 let array = as_int64_array ( array) ?;
218- hex_encode_int64 ( array. iter ( ) , array. len ( ) )
220+ Ok ( ColumnarValue :: Array ( hex_encode_int64 (
221+ array. iter ( ) ,
222+ array. len ( ) ,
223+ ) ?) )
219224 }
220225 DataType :: Utf8 => {
221226 let array = as_string_array ( array) ;
222- hex_encode_bytes ( array. iter ( ) , lowercase, array. len ( ) )
227+ Ok ( ColumnarValue :: Array ( hex_encode_bytes (
228+ array. iter ( ) ,
229+ lowercase,
230+ array. len ( ) ,
231+ ) ?) )
223232 }
224233 DataType :: Utf8View => {
225234 let array = as_string_view_array ( array) ?;
226- hex_encode_bytes ( array. iter ( ) , lowercase, array. len ( ) )
235+ Ok ( ColumnarValue :: Array ( hex_encode_bytes (
236+ array. iter ( ) ,
237+ lowercase,
238+ array. len ( ) ,
239+ ) ?) )
227240 }
228241 DataType :: LargeUtf8 => {
229242 let array = as_largestring_array ( array) ;
230- hex_encode_bytes ( array. iter ( ) , lowercase, array. len ( ) )
243+ Ok ( ColumnarValue :: Array ( hex_encode_bytes (
244+ array. iter ( ) ,
245+ lowercase,
246+ array. len ( ) ,
247+ ) ?) )
231248 }
232249 DataType :: Binary => {
233250 let array = as_binary_array ( array) ?;
234- hex_encode_bytes ( array. iter ( ) , lowercase, array. len ( ) )
251+ Ok ( ColumnarValue :: Array ( hex_encode_bytes (
252+ array. iter ( ) ,
253+ lowercase,
254+ array. len ( ) ,
255+ ) ?) )
235256 }
236257 DataType :: LargeBinary => {
237258 let array = as_large_binary_array ( array) ?;
238- hex_encode_bytes ( array. iter ( ) , lowercase, array. len ( ) )
259+ Ok ( ColumnarValue :: Array ( hex_encode_bytes (
260+ array. iter ( ) ,
261+ lowercase,
262+ array. len ( ) ,
263+ ) ?) )
239264 }
240265 DataType :: FixedSizeBinary ( _) => {
241266 let array = as_fixed_size_binary_array ( array) ?;
242- hex_encode_bytes ( array. iter ( ) , lowercase, array. len ( ) )
267+ Ok ( ColumnarValue :: Array ( hex_encode_bytes (
268+ array. iter ( ) ,
269+ lowercase,
270+ array. len ( ) ,
271+ ) ?) )
243272 }
244- DataType :: Dictionary ( _, value_type) => {
273+ DataType :: Dictionary ( key_type, _) => {
274+ if * * key_type != DataType :: Int32 {
275+ return exec_err ! (
276+ "hex only supports Int32 dictionary keys, get: {}" ,
277+ key_type
278+ ) ;
279+ }
280+
245281 let dict = as_dictionary_array :: < Int32Type > ( & array) ;
282+ let dict_values = dict. values ( ) ;
246283
247- match * * value_type {
284+ let encoded_values = match dict_values . data_type ( ) {
248285 DataType :: Int64 => {
249- let arr = dict . downcast_dict :: < Int64Array > ( ) . unwrap ( ) ;
250- hex_encode_int64 ( arr. into_iter ( ) , dict . len ( ) )
286+ let arr = as_int64_array ( dict_values ) ? ;
287+ hex_encode_int64 ( arr. iter ( ) , arr . len ( ) ) ?
251288 }
252289 DataType :: Utf8 => {
253- let arr = dict. downcast_dict :: < StringArray > ( ) . unwrap ( ) ;
254- hex_encode_bytes ( arr. into_iter ( ) , lowercase, dict. len ( ) )
290+ let arr = as_string_array ( dict_values) ;
291+ hex_encode_bytes ( arr. iter ( ) , lowercase, arr. len ( ) ) ?
292+ }
293+ DataType :: LargeUtf8 => {
294+ let arr = as_largestring_array ( dict_values) ;
295+ hex_encode_bytes ( arr. iter ( ) , lowercase, arr. len ( ) ) ?
296+ }
297+ DataType :: Utf8View => {
298+ let arr = as_string_view_array ( dict_values) ?;
299+ hex_encode_bytes ( arr. iter ( ) , lowercase, arr. len ( ) ) ?
255300 }
256301 DataType :: Binary => {
257- let arr = dict. downcast_dict :: < BinaryArray > ( ) . unwrap ( ) ;
258- hex_encode_bytes ( arr. into_iter ( ) , lowercase, dict. len ( ) )
302+ let arr = as_binary_array ( dict_values) ?;
303+ hex_encode_bytes ( arr. iter ( ) , lowercase, arr. len ( ) ) ?
304+ }
305+ DataType :: LargeBinary => {
306+ let arr = as_large_binary_array ( dict_values) ?;
307+ hex_encode_bytes ( arr. iter ( ) , lowercase, arr. len ( ) ) ?
308+ }
309+ DataType :: FixedSizeBinary ( _) => {
310+ let arr = as_fixed_size_binary_array ( dict_values) ?;
311+ hex_encode_bytes ( arr. iter ( ) , lowercase, arr. len ( ) ) ?
259312 }
260313 _ => {
261- exec_err ! (
314+ return exec_err ! (
262315 "hex got an unexpected argument type: {}" ,
263- array . data_type( )
264- )
316+ dict_values . data_type( )
317+ ) ;
265318 }
266- }
319+ } ;
320+
321+ let new_dict = dict. with_values ( encoded_values) ;
322+ Ok ( ColumnarValue :: Array ( Arc :: new ( new_dict) ) )
267323 }
268324 _ => exec_err ! ( "hex got an unexpected argument type: {}" , array. data_type( ) ) ,
269325 } ,
@@ -279,11 +335,12 @@ mod test {
279335 use arrow:: array:: { DictionaryArray , Int32Array , Int64Array , StringArray } ;
280336 use arrow:: {
281337 array:: {
282- BinaryDictionaryBuilder , PrimitiveDictionaryBuilder , StringBuilder ,
283- StringDictionaryBuilder , as_string_array,
338+ BinaryDictionaryBuilder , PrimitiveDictionaryBuilder , StringDictionaryBuilder ,
339+ as_string_array,
284340 } ,
285341 datatypes:: { Int32Type , Int64Type } ,
286342 } ;
343+ use datafusion_common:: cast:: as_dictionary_array;
287344 use datafusion_expr:: ColumnarValue ;
288345
289346 #[ test]
@@ -295,12 +352,12 @@ mod test {
295352 input_builder. append_value ( "rust" ) ;
296353 let input = input_builder. finish ( ) ;
297354
298- let mut string_builder = StringBuilder :: new ( ) ;
299- string_builder . append_value ( "6869" ) ;
300- string_builder . append_value ( "627965" ) ;
301- string_builder . append_null ( ) ;
302- string_builder . append_value ( "72757374" ) ;
303- let expected = string_builder . finish ( ) ;
355+ let mut expected_builder = StringDictionaryBuilder :: < Int32Type > :: new ( ) ;
356+ expected_builder . append_value ( "6869" ) ;
357+ expected_builder . append_value ( "627965" ) ;
358+ expected_builder . append_null ( ) ;
359+ expected_builder . append_value ( "72757374" ) ;
360+ let expected = expected_builder . finish ( ) ;
304361
305362 let columnar_value = ColumnarValue :: Array ( Arc :: new ( input) ) ;
306363 let result = super :: spark_hex ( & [ columnar_value] ) . unwrap ( ) ;
@@ -310,7 +367,7 @@ mod test {
310367 _ => panic ! ( "Expected array" ) ,
311368 } ;
312369
313- let result = as_string_array ( & result) ;
370+ let result = as_dictionary_array ( & result) . unwrap ( ) ;
314371
315372 assert_eq ! ( result, & expected) ;
316373 }
@@ -324,12 +381,12 @@ mod test {
324381 input_builder. append_value ( 3 ) ;
325382 let input = input_builder. finish ( ) ;
326383
327- let mut string_builder = StringBuilder :: new ( ) ;
328- string_builder . append_value ( "1" ) ;
329- string_builder . append_value ( "2" ) ;
330- string_builder . append_null ( ) ;
331- string_builder . append_value ( "3" ) ;
332- let expected = string_builder . finish ( ) ;
384+ let mut expected_builder = StringDictionaryBuilder :: < Int32Type > :: new ( ) ;
385+ expected_builder . append_value ( "1" ) ;
386+ expected_builder . append_value ( "2" ) ;
387+ expected_builder . append_null ( ) ;
388+ expected_builder . append_value ( "3" ) ;
389+ let expected = expected_builder . finish ( ) ;
333390
334391 let columnar_value = ColumnarValue :: Array ( Arc :: new ( input) ) ;
335392 let result = super :: spark_hex ( & [ columnar_value] ) . unwrap ( ) ;
@@ -339,7 +396,7 @@ mod test {
339396 _ => panic ! ( "Expected array" ) ,
340397 } ;
341398
342- let result = as_string_array ( & result) ;
399+ let result = as_dictionary_array ( & result) . unwrap ( ) ;
343400
344401 assert_eq ! ( result, & expected) ;
345402 }
@@ -353,7 +410,7 @@ mod test {
353410 input_builder. append_value ( "3" ) ;
354411 let input = input_builder. finish ( ) ;
355412
356- let mut expected_builder = StringBuilder :: new ( ) ;
413+ let mut expected_builder = StringDictionaryBuilder :: < Int32Type > :: new ( ) ;
357414 expected_builder. append_value ( "31" ) ;
358415 expected_builder. append_value ( "6A" ) ;
359416 expected_builder. append_null ( ) ;
@@ -368,7 +425,7 @@ mod test {
368425 _ => panic ! ( "Expected array" ) ,
369426 } ;
370427
371- let result = as_string_array ( & result) ;
428+ let result = as_dictionary_array ( & result) . unwrap ( ) ;
372429
373430 assert_eq ! ( result, & expected) ;
374431 }
@@ -425,8 +482,11 @@ mod test {
425482 _ => panic ! ( "Expected array" ) ,
426483 } ;
427484
428- let result = as_string_array ( & result) ;
429- let expected = StringArray :: from ( vec ! [ Some ( "20" ) , None , None ] ) ;
485+ let result = as_dictionary_array ( & result) . unwrap ( ) ;
486+
487+ let keys = Int32Array :: from ( vec ! [ Some ( 0 ) , None , Some ( 1 ) ] ) ;
488+ let vals = StringArray :: from ( vec ! [ Some ( "20" ) , None ] ) ;
489+ let expected = DictionaryArray :: new ( keys, Arc :: new ( vals) ) ;
430490
431491 assert_eq ! ( & expected, result) ;
432492 }
0 commit comments