1818//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
2020use std:: cmp:: Ordering ;
21- use std:: collections:: { HashSet , VecDeque } ;
21+ use std:: collections:: { HashMap , VecDeque } ;
2222use std:: mem:: { size_of, size_of_val, take} ;
2323use std:: sync:: Arc ;
2424
@@ -34,7 +34,9 @@ use datafusion_common::cast::as_list_array;
3434use datafusion_common:: utils:: {
3535 SingleRowListArrayBuilder , compare_rows, get_row_at_idx, take_function_args,
3636} ;
37- use datafusion_common:: { Result , ScalarValue , assert_eq_or_internal_err, exec_err} ;
37+ use datafusion_common:: {
38+ Result , ScalarValue , assert_eq_or_internal_err, exec_err, internal_err,
39+ } ;
3840use datafusion_expr:: function:: { AccumulatorArgs , StateFieldsArgs } ;
3941use datafusion_expr:: utils:: format_state_name;
4042use datafusion_expr:: {
@@ -814,7 +816,10 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator {
814816
815817#[ derive( Debug ) ]
816818pub struct DistinctArrayAggAccumulator {
817- values : HashSet < ScalarValue > ,
819+ // Value → live refcount. Multiset state lets `retract_batch` correctly
820+ // drop a duplicate occurrence while keeping the key alive if other
821+ // copies remain in the current window frame.
822+ values : HashMap < ScalarValue , u64 > ,
818823 datatype : DataType ,
819824 sort_options : Option < SortOptions > ,
820825 ignore_nulls : bool ,
@@ -827,7 +832,7 @@ impl DistinctArrayAggAccumulator {
827832 ignore_nulls : bool ,
828833 ) -> Result < Self > {
829834 Ok ( Self {
830- values : HashSet :: new ( ) ,
835+ values : HashMap :: new ( ) ,
831836 datatype : datatype. clone ( ) ,
832837 sort_options,
833838 ignore_nulls,
@@ -856,8 +861,8 @@ impl Accumulator for DistinctArrayAggAccumulator {
856861 if nulls. is_none_or ( |nulls| nulls. null_count ( ) < val. len ( ) ) {
857862 for i in 0 ..val. len ( ) {
858863 if nulls. is_none_or ( |nulls| nulls. is_valid ( i) ) {
859- self . values
860- . insert ( ScalarValue :: try_from_array ( val , i ) ? . compacted ( ) ) ;
864+ let key = ScalarValue :: try_from_array ( val , i ) ? . compacted ( ) ;
865+ * self . values . entry ( key ) . or_insert ( 0 ) += 1 ;
861866 }
862867 }
863868 }
@@ -872,6 +877,12 @@ impl Accumulator for DistinctArrayAggAccumulator {
872877
873878 assert_eq_or_internal_err ! ( states. len( ) , 1 , "expects single state" ) ;
874879
880+ // The DISTINCT state schema is `List<value>` — partial accumulators
881+ // ship the set of values they saw, not multiplicities. Re-ingesting
882+ // each element here makes the merged counts represent "partitions
883+ // that emitted this value," which is fine because `evaluate` only
884+ // reads keys. Refcount semantics for retract are only valid within
885+ // a single accumulator instance (window execution).
875886 states[ 0 ]
876887 . as_list :: < i32 > ( )
877888 . iter ( )
@@ -880,7 +891,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
880891 }
881892
882893 fn evaluate ( & mut self ) -> Result < ScalarValue > {
883- let mut values: Vec < ScalarValue > = self . values . iter ( ) . cloned ( ) . collect ( ) ;
894+ let mut values: Vec < ScalarValue > = self . values . keys ( ) . cloned ( ) . collect ( ) ;
884895 if values. is_empty ( ) {
885896 return Ok ( ScalarValue :: new_null_list ( self . datatype . clone ( ) , true , 1 ) ) ;
886897 }
@@ -916,8 +927,50 @@ impl Accumulator for DistinctArrayAggAccumulator {
916927 Ok ( ScalarValue :: List ( arr) )
917928 }
918929
930+ fn retract_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
931+ if values. is_empty ( ) {
932+ return Ok ( ( ) ) ;
933+ }
934+
935+ assert_eq_or_internal_err ! ( values. len( ) , 1 , "expects single batch" ) ;
936+
937+ let val = & values[ 0 ] ;
938+ let nulls = if self . ignore_nulls {
939+ val. logical_nulls ( )
940+ } else {
941+ None
942+ } ;
943+ let nulls = nulls. as_ref ( ) ;
944+
945+ for i in 0 ..val. len ( ) {
946+ if nulls. is_some_and ( |nulls| !nulls. is_valid ( i) ) {
947+ continue ;
948+ }
949+ let key = ScalarValue :: try_from_array ( val, i) ?. compacted ( ) ;
950+ match self . values . get_mut ( & key) {
951+ Some ( count) => {
952+ * count -= 1 ;
953+ if * count == 0 {
954+ self . values . remove ( & key) ;
955+ }
956+ }
957+ None => {
958+ return internal_err ! (
959+ "DistinctArrayAggAccumulator::retract_batch: value not present in state"
960+ ) ;
961+ }
962+ }
963+ }
964+
965+ Ok ( ( ) )
966+ }
967+
968+ fn supports_retract_batch ( & self ) -> bool {
969+ true
970+ }
971+
919972 fn size ( & self ) -> usize {
920- size_of_val ( self ) + ScalarValue :: size_of_hashset ( & self . values )
973+ size_of_val ( self ) + ScalarValue :: size_of_hashmap ( & self . values )
921974 - size_of_val ( & self . values )
922975 + self . datatype . size ( )
923976 - size_of_val ( & self . datatype )
@@ -1494,8 +1547,8 @@ mod tests {
14941547 acc2. update_batch ( & [ string_list_data ( [ vec ! [ "e" , "f" , "g" ] ] ) ] ) ?;
14951548 acc1 = merge ( acc1, acc2) ?;
14961549
1497- // without compaction, the size is 16660
1498- assert_eq ! ( acc1. size( ) , 1660 ) ;
1550+ // without compaction, the size is 16684
1551+ assert_eq ! ( acc1. size( ) , 1684 ) ;
14991552
15001553 Ok ( ( ) )
15011554 }
@@ -2415,4 +2468,126 @@ mod tests {
24152468
24162469 Ok ( ( ) )
24172470 }
2471+
2472+ // ---- DistinctArrayAggAccumulator retract_batch tests ----
2473+
2474+ // Build a DISTINCT accumulator with ascending sort so evaluate output is
2475+ // deterministic regardless of HashMap iteration order.
2476+ fn distinct_acc ( ignore_nulls : bool ) -> Result < DistinctArrayAggAccumulator > {
2477+ DistinctArrayAggAccumulator :: try_new (
2478+ & DataType :: Utf8 ,
2479+ Some ( SortOptions :: default ( ) ) ,
2480+ ignore_nulls,
2481+ )
2482+ }
2483+
#[test]
fn distinct_retract_duplicate_remains() -> Result<()> {
    // Regression for the set-based state that could not retract: a value
    // present several times in the frame must outlive the retraction of
    // just one of its occurrences.
    let mut accumulator = distinct_acc(false)?;

    // [A, A, B] arrives in two batches so state spans multiple updates.
    accumulator.update_batch(&[data(["A", "A"])])?;
    accumulator.update_batch(&[data(["B"])])?;
    let initial = print_nulls(str_arr(accumulator.evaluate()?)?);
    assert_eq!(initial, vec!["A", "B"]);

    // First A leaves the frame; the second one keeps the key alive.
    accumulator.retract_batch(&[data(["A"])])?;
    let after_first = print_nulls(str_arr(accumulator.evaluate()?)?);
    assert_eq!(after_first, vec!["A", "B"]);

    // Last A leaves the frame; B is all that remains.
    accumulator.retract_batch(&[data(["A"])])?;
    let after_second = print_nulls(str_arr(accumulator.evaluate()?)?);
    assert_eq!(after_second, vec!["B"]);

    Ok(())
}
2506+
#[test]
fn distinct_retract_full_removal() -> Result<()> {
    // Retracting everything that was added must leave a null list behind.
    let mut accumulator = distinct_acc(false)?;

    accumulator.update_batch(&[data(["A", "B"])])?;
    accumulator.retract_batch(&[data(["A", "B"])])?;

    let result = accumulator.evaluate()?;
    let is_null_list = match &result {
        ScalarValue::List(arr) => arr.is_null(0),
        _ => false,
    };
    assert!(
        is_null_list,
        "expected null list after full retract, got {result:?}"
    );

    Ok(())
}
2522+
#[test]
fn distinct_retract_ignore_nulls_skips() -> Result<()> {
    // With ignore_nulls=true a NULL is never stored by update, so retract
    // has to skip it as well instead of failing on the absent key.
    let mut accumulator = distinct_acc(true)?;

    accumulator.update_batch(&[data([Some("A"), None, Some("B")])])?;
    let before = print_nulls(str_arr(accumulator.evaluate()?)?);
    assert_eq!(before, vec!["A", "B"]);

    // Retracting [A, NULL] removes A only; the NULL row is ignored.
    accumulator.retract_batch(&[data([Some("A"), None])])?;
    let after = print_nulls(str_arr(accumulator.evaluate()?)?);
    assert_eq!(after, vec!["B"]);

    Ok(())
}
2538+
#[test]
fn distinct_retract_null_tracked() -> Result<()> {
    // With ignore_nulls=false a NULL is refcounted like any other value and
    // must retract symmetrically; once its count hits zero the key has to
    // go, otherwise evaluate would keep emitting a NULL element.
    let mut accumulator = distinct_acc(false)?;

    accumulator.update_batch(&[data([Some("A"), None, None])])?;
    // Default SortOptions use nulls_first=true, so NULL precedes A.
    let with_two_nulls = print_nulls(str_arr(accumulator.evaluate()?)?);
    assert_eq!(with_two_nulls, vec!["NULL", "A"]);

    // One NULL retracted: the count falls to 1 and the key stays.
    accumulator.retract_batch(&[data::<Option<&str>, 1>([None])])?;
    let with_one_null = print_nulls(str_arr(accumulator.evaluate()?)?);
    assert_eq!(with_one_null, vec!["NULL", "A"]);

    // Last NULL retracted: the key disappears.
    accumulator.retract_batch(&[data::<Option<&str>, 1>([None])])?;
    let without_null = print_nulls(str_arr(accumulator.evaluate()?)?);
    assert_eq!(without_null, vec!["A"]);

    Ok(())
}
2560+
#[test]
fn distinct_supports_retract_batch() -> Result<()> {
    // Retraction must be advertised for both ignore_nulls settings.
    for ignore_nulls in [false, true] {
        let accumulator = distinct_acc(ignore_nulls)?;
        assert!(accumulator.supports_retract_batch());
    }

    Ok(())
}
2571+
#[test]
fn distinct_merge_then_evaluate_regression() -> Result<()> {
    // Non-window path: state -> merge_batch -> evaluate still has to yield
    // the union of the distinct values seen by both partitions.
    let mut target = distinct_acc(false)?;
    let mut partial = distinct_acc(false)?;

    target.update_batch(&[data(["A", "A", "B"])])?;
    partial.update_batch(&[data(["A", "C"])])?;

    // Turn each state scalar of the partial accumulator into a one-row array.
    let mut state_arrs: Vec<ArrayRef> = Vec::new();
    for state_value in partial.state()? {
        state_arrs.push(state_value.to_array_of_size(1)?);
    }
    target.merge_batch(&state_arrs)?;

    let merged = print_nulls(str_arr(target.evaluate()?)?);
    assert_eq!(merged, vec!["A", "B", "C"]);

    Ok(())
}
24182593}
0 commit comments