@@ -20,11 +20,13 @@ use std::sync::Arc;
2020use arrow:: datatypes:: SchemaRef ;
2121use arrow:: record_batch:: RecordBatch ;
2222use datafusion_common:: { Result , internal_err} ;
23+ use datafusion_expr:: EmitTo ;
2324
2425use crate :: aggregates:: AggregateExec ;
2526
2627use super :: common:: {
27- AggregateHashTable , AggregateHashTableState , FinalMarker , emit_to_for_batch_size,
28+ AggregateHashTable , AggregateHashTableBuffer , AggregateHashTableState , FinalMarker ,
29+ MaterializedFinalOutput ,
2830} ;
2931
3032/// Methods specific to the aggregate hash table used in the final aggregation stage.
@@ -55,30 +57,19 @@ impl AggregateHashTable<FinalMarker> {
5557 ) -> Result < Option < RecordBatch > > {
5658 let output_schema = Arc :: clone ( & self . output_schema ) ;
5759 let batch_size = self . batch_size ;
58- match & mut self . state {
60+ // Take ownership of the output state. Note `emit_next_materialized_batch`
61+ // updates state after it emits a materialized slice.
62+ match std:: mem:: replace ( & mut self . state , AggregateHashTableState :: Done ) {
5963 AggregateHashTableState :: Outputting ( state) => {
6064 if state. group_values . is_empty ( ) {
61- self . state = AggregateHashTableState :: Done ;
6265 return Ok ( None ) ;
6366 }
6467
65- let emit_to =
66- emit_to_for_batch_size ( batch_size, state. group_values . len ( ) ) ;
67- let timer = self . group_by_metrics . emitting_time . timer ( ) ;
68- let mut output = state. group_values . emit ( emit_to) ?;
69-
70- for acc in state. accumulators . iter_mut ( ) {
71- output. push ( acc. evaluate ( emit_to) ?) ;
72- }
73- let done = state. group_values . is_empty ( ) ;
74- drop ( timer) ;
75-
76- let batch = RecordBatch :: try_new ( output_schema, output) ?;
77- debug_assert ! ( batch. num_rows( ) > 0 ) ;
78- if done {
79- self . state = AggregateHashTableState :: Done ;
80- }
81- Ok ( Some ( batch) )
68+ let output = self . materialize_final_output ( state, output_schema) ?;
69+ Ok ( self . emit_next_materialized_batch ( output, batch_size) )
70+ }
71+ AggregateHashTableState :: OutputtingMaterializedFinal ( output) => {
72+ Ok ( self . emit_next_materialized_batch ( output, batch_size) )
8273 }
8374 AggregateHashTableState :: Done => Ok ( None ) ,
8475 AggregateHashTableState :: Building ( _) => {
@@ -87,6 +78,41 @@ impl AggregateHashTable<FinalMarker> {
8778 }
8879 }
8980
81+ fn materialize_final_output (
82+ & self ,
83+ mut state : AggregateHashTableBuffer ,
84+ output_schema : SchemaRef ,
85+ ) -> Result < MaterializedFinalOutput > {
86+ // Final aggregate evaluation consumes accumulator state. Evaluate all
87+ // groups once, then slice the materialized batch on subsequent polls.
88+ let emit_to = EmitTo :: All ;
89+ let timer = self . group_by_metrics . emitting_time . timer ( ) ;
90+ let mut output = state. group_values . emit ( emit_to) ?;
91+
92+ for acc in state. accumulators . iter_mut ( ) {
93+ output. push ( acc. evaluate ( emit_to) ?) ;
94+ }
95+ drop ( timer) ;
96+
97+ let batch = RecordBatch :: try_new ( output_schema, output) ?;
98+ debug_assert ! ( batch. num_rows( ) > 0 ) ;
99+ Ok ( MaterializedFinalOutput :: new ( batch) )
100+ }
101+
102+ fn emit_next_materialized_batch (
103+ & mut self ,
104+ mut output : MaterializedFinalOutput ,
105+ batch_size : usize ,
106+ ) -> Option < RecordBatch > {
107+ let batch = output. next_batch ( batch_size) ;
108+ if output. is_exhausted ( ) {
109+ self . state = AggregateHashTableState :: Done ;
110+ } else {
111+ self . state = AggregateHashTableState :: OutputtingMaterializedFinal ( output) ;
112+ }
113+ batch
114+ }
115+
90116 pub ( in crate :: aggregates) fn aggregate_batch (
91117 & mut self ,
92118 batch : & RecordBatch ,
0 commit comments