1818use std:: sync:: Arc ;
1919
2020use arrow:: array:: {
21- Array , BooleanArray , Capacities , MutableArrayData , Scalar , make_array,
21+ Array , BooleanArray , Capacities , MutableArrayData , Scalar , cast :: AsArray , make_array,
2222 make_comparator,
2323} ;
2424use arrow:: compute:: SortOptions ;
@@ -27,7 +27,7 @@ use arrow_buffer::NullBuffer;
2727
2828use datafusion_common:: cast:: { as_map_array, as_struct_array} ;
2929use datafusion_common:: {
30- Result , ScalarValue , exec_err, internal_err, plan_datafusion_err,
30+ Result , ScalarValue , exec_datafusion_err , exec_err, internal_err, plan_datafusion_err,
3131} ;
3232use datafusion_expr:: expr:: ScalarFunction ;
3333use datafusion_expr:: simplify:: ExprSimplifyResult ;
@@ -198,6 +198,24 @@ fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<Column
198198 let string_value = name. try_as_str ( ) . flatten ( ) . map ( |s| s. to_string ( ) ) ;
199199
200200 match ( array. data_type ( ) , name, string_value) {
201+ // Dictionary-encoded struct: extract the field from the dictionary's
202+ // values (the deduplicated struct array) and rebuild a dictionary with
203+ // the same keys. This preserves dictionary encoding without expanding.
204+ ( DataType :: Dictionary ( _, value_type) , _, Some ( field_name) )
205+ if matches ! ( value_type. as_ref( ) , DataType :: Struct ( _) ) =>
206+ {
207+ let dict = array. as_any_dictionary ( ) ;
208+ let values_struct = dict. values ( ) . as_struct ( ) ;
209+ let field_col =
210+ values_struct. column_by_name ( & field_name) . ok_or_else ( || {
211+ exec_datafusion_err ! (
212+ "Field {field_name} not found in dictionary struct"
213+ )
214+ } ) ?;
215+ Ok ( ColumnarValue :: Array (
216+ dict. with_values ( Arc :: clone ( field_col) ) ,
217+ ) )
218+ }
201219 ( DataType :: Map ( _, _) , ScalarValue :: List ( arr) , _) => {
202220 let key_array: Arc < dyn Array > = arr;
203221 process_map_array ( & array, key_array)
@@ -333,6 +351,42 @@ impl ScalarUDFImpl for GetFieldFunc {
333351 }
334352 }
335353 }
354+ // Dictionary-encoded struct: resolve the child field from
355+ // the underlying struct, then wrap the result back in the
356+ // same Dictionary type so the promised type matches execution.
357+ DataType :: Dictionary ( key_type, value_type)
358+ if matches ! ( value_type. as_ref( ) , DataType :: Struct ( _) ) =>
359+ {
360+ let DataType :: Struct ( fields) = value_type. as_ref ( ) else {
361+ unreachable ! ( )
362+ } ;
363+ let field_name = sv
364+ . as_ref ( )
365+ . and_then ( |sv| {
366+ sv. try_as_str ( ) . flatten ( ) . filter ( |s| !s. is_empty ( ) )
367+ } )
368+ . ok_or_else ( || {
369+ exec_datafusion_err ! ( "Field name must be a non-empty string" )
370+ } ) ?;
371+
372+ let child_field = fields
373+ . iter ( )
374+ . find ( |f| f. name ( ) == field_name)
375+ . ok_or_else ( || {
376+ plan_datafusion_err ! ( "Field {field_name} not found in struct" )
377+ } ) ?;
378+
379+ let dict_type = DataType :: Dictionary (
380+ key_type. clone ( ) ,
381+ Box :: new ( child_field. data_type ( ) . clone ( ) ) ,
382+ ) ;
383+ let mut new_field =
384+ child_field. as_ref ( ) . clone ( ) . with_data_type ( dict_type) ;
385+ if current_field. is_nullable ( ) {
386+ new_field = new_field. with_nullable ( true ) ;
387+ }
388+ current_field = Arc :: new ( new_field) ;
389+ }
336390 DataType :: Struct ( fields) => {
337391 let field_name = sv
338392 . as_ref ( )
@@ -560,6 +614,133 @@ mod tests {
560614 Ok ( ( ) )
561615 }
562616
617+ #[ test]
618+ fn test_get_field_dict_encoded_struct ( ) -> Result < ( ) > {
619+ use arrow:: array:: { DictionaryArray , StringArray , UInt32Array } ;
620+ use arrow:: datatypes:: UInt32Type ;
621+
622+ let names = Arc :: new ( StringArray :: from ( vec ! [ "main" , "foo" , "bar" ] ) ) as ArrayRef ;
623+ let ids = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 3 ] ) ) as ArrayRef ;
624+
625+ let struct_fields: Fields = vec ! [
626+ Field :: new( "name" , DataType :: Utf8 , false ) ,
627+ Field :: new( "id" , DataType :: Int32 , false ) ,
628+ ]
629+ . into ( ) ;
630+
631+ let values_struct =
632+ Arc :: new ( StructArray :: new ( struct_fields, vec ! [ names, ids] , None ) ) as ArrayRef ;
633+
634+ let keys = UInt32Array :: from ( vec ! [ 0u32 , 1 , 2 , 0 , 1 ] ) ;
635+ let dict = DictionaryArray :: < UInt32Type > :: try_new ( keys, values_struct) ?;
636+
637+ let base = ColumnarValue :: Array ( Arc :: new ( dict) ) ;
638+ let key = ScalarValue :: Utf8 ( Some ( "name" . to_string ( ) ) ) ;
639+
640+ let result = extract_single_field ( base, key) ?;
641+ let result_array = result. into_array ( 5 ) ?;
642+
643+ assert ! (
644+ matches!( result_array. data_type( ) , DataType :: Dictionary ( _, _) ) ,
645+ "expected dictionary output, got {:?}" ,
646+ result_array. data_type( )
647+ ) ;
648+
649+ let result_dict = result_array
650+ . as_any ( )
651+ . downcast_ref :: < DictionaryArray < UInt32Type > > ( )
652+ . unwrap ( ) ;
653+ assert_eq ! ( result_dict. values( ) . len( ) , 3 ) ;
654+ assert_eq ! ( result_dict. len( ) , 5 ) ;
655+
656+ let resolved = arrow:: compute:: cast ( & result_array, & DataType :: Utf8 ) ?;
657+ let string_arr = resolved. as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
658+ assert_eq ! ( string_arr. value( 0 ) , "main" ) ;
659+ assert_eq ! ( string_arr. value( 1 ) , "foo" ) ;
660+ assert_eq ! ( string_arr. value( 2 ) , "bar" ) ;
661+ assert_eq ! ( string_arr. value( 3 ) , "main" ) ;
662+ assert_eq ! ( string_arr. value( 4 ) , "foo" ) ;
663+
664+ Ok ( ( ) )
665+ }
666+
667+ #[ test]
668+ fn test_get_field_nested_dict_struct ( ) -> Result < ( ) > {
669+ use arrow:: array:: { DictionaryArray , StringArray , UInt32Array } ;
670+ use arrow:: datatypes:: UInt32Type ;
671+
672+ let func_names = Arc :: new ( StringArray :: from ( vec ! [ "main" , "foo" ] ) ) as ArrayRef ;
673+ let func_files = Arc :: new ( StringArray :: from ( vec ! [ "main.c" , "foo.c" ] ) ) as ArrayRef ;
674+ let func_fields: Fields = vec ! [
675+ Field :: new( "name" , DataType :: Utf8 , false ) ,
676+ Field :: new( "file" , DataType :: Utf8 , false ) ,
677+ ]
678+ . into ( ) ;
679+ let func_struct = Arc :: new ( StructArray :: new (
680+ func_fields. clone ( ) ,
681+ vec ! [ func_names, func_files] ,
682+ None ,
683+ ) ) as ArrayRef ;
684+ let func_dict = Arc :: new ( DictionaryArray :: < UInt32Type > :: try_new (
685+ UInt32Array :: from ( vec ! [ 0u32 , 1 , 0 ] ) ,
686+ func_struct,
687+ ) ?) as ArrayRef ;
688+
689+ let line_nums = Arc :: new ( Int32Array :: from ( vec ! [ 10 , 20 , 30 ] ) ) as ArrayRef ;
690+ let line_fields: Fields = vec ! [
691+ Field :: new( "num" , DataType :: Int32 , false ) ,
692+ Field :: new(
693+ "function" ,
694+ DataType :: Dictionary (
695+ Box :: new( DataType :: UInt32 ) ,
696+ Box :: new( DataType :: Struct ( func_fields) ) ,
697+ ) ,
698+ false ,
699+ ) ,
700+ ]
701+ . into ( ) ;
702+ let line_struct = StructArray :: new ( line_fields, vec ! [ line_nums, func_dict] , None ) ;
703+
704+ let base = ColumnarValue :: Array ( Arc :: new ( line_struct) ) ;
705+
706+ let func_result =
707+ extract_single_field ( base, ScalarValue :: Utf8 ( Some ( "function" . to_string ( ) ) ) ) ?;
708+
709+ let func_array = func_result. into_array ( 3 ) ?;
710+ assert ! (
711+ matches!( func_array. data_type( ) , DataType :: Dictionary ( _, _) ) ,
712+ "expected dictionary for function, got {:?}" ,
713+ func_array. data_type( )
714+ ) ;
715+
716+ let name_result = extract_single_field (
717+ ColumnarValue :: Array ( func_array) ,
718+ ScalarValue :: Utf8 ( Some ( "name" . to_string ( ) ) ) ,
719+ ) ?;
720+ let name_array = name_result. into_array ( 3 ) ?;
721+
722+ assert ! (
723+ matches!( name_array. data_type( ) , DataType :: Dictionary ( _, _) ) ,
724+ "expected dictionary for name, got {:?}" ,
725+ name_array. data_type( )
726+ ) ;
727+
728+ let name_dict = name_array
729+ . as_any ( )
730+ . downcast_ref :: < DictionaryArray < UInt32Type > > ( )
731+ . unwrap ( ) ;
732+ assert_eq ! ( name_dict. values( ) . len( ) , 2 ) ;
733+ assert_eq ! ( name_dict. len( ) , 3 ) ;
734+
735+ let resolved = arrow:: compute:: cast ( & name_array, & DataType :: Utf8 ) ?;
736+ let strings = resolved. as_any ( ) . downcast_ref :: < StringArray > ( ) . unwrap ( ) ;
737+ assert_eq ! ( strings. value( 0 ) , "main" ) ;
738+ assert_eq ! ( strings. value( 1 ) , "foo" ) ;
739+ assert_eq ! ( strings. value( 2 ) , "main" ) ;
740+
741+ Ok ( ( ) )
742+ }
743+
563744 #[ test]
564745 fn test_placement_literal_key ( ) {
565746 let func = GetFieldFunc :: new ( ) ;
0 commit comments