1717
1818//! `MAP_AGG` aggregate implementation: [`MapAgg`]
1919
20- use std:: collections:: VecDeque ;
20+ use std:: collections:: { HashSet , VecDeque } ;
2121use std:: mem:: { size_of, size_of_val, take} ;
2222use std:: sync:: Arc ;
2323
@@ -26,7 +26,7 @@ use arrow::buffer::{OffsetBuffer, ScalarBuffer};
2626use arrow:: compute:: SortOptions ;
2727use arrow:: datatypes:: { DataType , Field , FieldRef , Fields } ;
2828
29- use datafusion_common:: utils:: { compare_rows, get_row_at_idx} ;
29+ use datafusion_common:: utils:: { SingleRowListArrayBuilder , compare_rows, get_row_at_idx} ;
3030use datafusion_common:: { Result , ScalarValue , exec_err} ;
3131use datafusion_expr:: function:: { AccumulatorArgs , StateFieldsArgs } ;
3232use datafusion_expr:: utils:: format_state_name;
@@ -150,13 +150,11 @@ impl AggregateUDFImpl for MapAgg {
150150}
151151
152152fn map_type ( key_type : & DataType , value_type : & DataType ) -> DataType {
153- let key_field = Arc :: new ( Field :: new ( "key" , key_type. clone ( ) , false ) ) ;
154- let value_field = Arc :: new ( Field :: new ( "value" , value_type. clone ( ) , true ) ) ;
155- let entries_field = Arc :: new ( Field :: new (
156- "entries" ,
157- DataType :: Struct ( Fields :: from ( vec ! [ key_field, value_field] ) ) ,
158- false ,
159- ) ) ;
153+ let fields = Fields :: from ( vec ! [
154+ Field :: new( "key" , key_type. clone( ) , false ) ,
155+ Field :: new( "value" , value_type. clone( ) , true ) ,
156+ ] ) ;
157+ let entries_field = Arc :: new ( Field :: new ( "entries" , DataType :: Struct ( fields) , false ) ) ;
160158 DataType :: Map ( entries_field, false )
161159}
162160
@@ -168,14 +166,13 @@ fn build_single_map(
168166) -> Result < ArrayRef > {
169167 debug_assert_eq ! ( keys. len( ) , values. len( ) ) ;
170168
171- let key_field = Arc :: new ( Field :: new ( "key" , key_type. clone ( ) , false ) ) ;
172- let value_field = Arc :: new ( Field :: new ( "value" , value_type. clone ( ) , true ) ) ;
169+ let fields = Fields :: from ( vec ! [
170+ Field :: new( "key" , key_type. clone( ) , false ) ,
171+ Field :: new( "value" , value_type. clone( ) , true ) ,
172+ ] ) ;
173173 let entries_field = Arc :: new ( Field :: new (
174174 "entries" ,
175- DataType :: Struct ( Fields :: from ( vec ! [
176- Arc :: clone( & key_field) ,
177- Arc :: clone( & value_field) ,
178- ] ) ) ,
175+ DataType :: Struct ( fields. clone ( ) ) ,
179176 false ,
180177 ) ) ;
181178
@@ -191,11 +188,7 @@ fn build_single_map(
191188 ScalarValue :: iter_to_array ( values) ?
192189 } ;
193190
194- let entries = StructArray :: try_new (
195- Fields :: from ( vec ! [ key_field, value_field] ) ,
196- vec ! [ key_array, value_array] ,
197- None ,
198- ) ?;
191+ let entries = StructArray :: try_new ( fields, vec ! [ key_array, value_array] , None ) ?;
199192
200193 let offsets = OffsetBuffer :: new ( ScalarBuffer :: from ( vec ! [ 0i32 , len as i32 ] ) ) ;
201194 Ok ( Arc :: new ( MapArray :: try_new (
@@ -207,27 +200,26 @@ fn build_single_map(
207200 ) ?) )
208201}
209202
210- /// De-duplicates parallel key/value vectors keeping the first value seen for
211- /// each key.
212203fn dedup_first_wins (
213204 keys : Vec < ScalarValue > ,
214205 values : Vec < ScalarValue > ,
215206) -> ( Vec < ScalarValue > , Vec < ScalarValue > ) {
216- use std:: collections:: HashSet ;
217-
218- let mut seen: HashSet < ScalarValue > = HashSet :: with_capacity ( keys. len ( ) ) ;
219- let mut out_keys: Vec < ScalarValue > = Vec :: with_capacity ( keys. len ( ) ) ;
220- let mut out_vals: Vec < ScalarValue > = Vec :: with_capacity ( keys. len ( ) ) ;
221-
222- for ( k, v) in keys. into_iter ( ) . zip ( values) {
223- // Keep only the first occurrence of each key; later ones are dropped.
224- if seen. insert ( k. clone ( ) ) {
225- out_keys. push ( k) ;
226- out_vals. push ( v) ;
227- }
228- }
229-
230- ( out_keys, out_vals)
207+ // First pass: mark each position that is the first occurrence of its key.
208+ let mut seen = HashSet :: with_capacity ( keys. len ( ) ) ;
209+ let keep: Vec < bool > = keys. iter ( ) . map ( |k| seen. insert ( k) ) . collect ( ) ;
210+
211+ // Second pass: keep only the first-occurrence positions.
212+ let out_keys = keys
213+ . into_iter ( )
214+ . zip ( & keep)
215+ . filter_map ( |( k, & keep) | keep. then_some ( k) )
216+ . collect ( ) ;
217+ let out_values = values
218+ . into_iter ( )
219+ . zip ( & keep)
220+ . filter_map ( |( v, & keep) | keep. then_some ( v) )
221+ . collect ( ) ;
222+ ( out_keys, out_values)
231223}
232224
233225/// Plain accumulator used when there is no `ORDER BY`.
@@ -388,12 +380,7 @@ impl OrderSensitiveMapAggAccumulator {
388380 }
389381
390382 let struct_array = StructArray :: try_new ( struct_field, column_wise, None ) ?;
391- Ok (
392- datafusion_common:: utils:: SingleRowListArrayBuilder :: new ( Arc :: new (
393- struct_array,
394- ) )
395- . build_list_scalar ( ) ,
396- )
383+ Ok ( SingleRowListArrayBuilder :: new ( Arc :: new ( struct_array) ) . build_list_scalar ( ) )
397384 }
398385}
399386
0 commit comments