@@ -63,7 +63,7 @@ pub(crate) enum ExecutionState {
6363 ReadingInput ,
6464 /// When producing output, the groups that remain to be emitted are
6565 /// described here and are emitted as needed in batch_size chunks
66- ProducingOutput ( RecordBatch ) ,
66+ ProducingOutput ( EmitTo ) ,
6767 /// Produce intermediate aggregate state for each input row without
6868 /// aggregation.
6969 ///
@@ -753,10 +753,8 @@ impl Stream for GroupedHashAggregateStream {
753753 && let Some ( to_emit) = self . group_ordering . emit_to ( )
754754 {
755755 timer. done ( ) ;
756- if let Some ( batch) = self . emit ( to_emit, false ) ? {
757- self . exec_state =
758- ExecutionState :: ProducingOutput ( batch) ;
759- } ;
756+ self . exec_state =
757+ ExecutionState :: ProducingOutput ( to_emit) ;
760758 // make sure the exec_state just set is not overwritten below
761759 break ' reading_input;
762760 }
@@ -837,33 +835,58 @@ impl Stream for GroupedHashAggregateStream {
837835 }
838836 }
839837
840- ExecutionState :: ProducingOutput ( batch ) => {
838+ ExecutionState :: ProducingOutput ( to_emit ) => {
841839 // emit at most batch_size groups and keep track of what remains
842- let output_batch;
843840 let size = self . batch_size ;
844- ( self . exec_state , output_batch) = if batch. num_rows ( ) <= size {
845- (
846- if self . input_done {
847- ExecutionState :: Done
848- }
849- // In Partial aggregation, we also need to check
850- // if we should trigger partial skipping
851- else if self . mode == AggregateMode :: Partial
852- && self . should_skip_aggregation ( )
853- {
854- ExecutionState :: SkippingAggregation
841+ let ( batch, remaining) = match to_emit {
842+ EmitTo :: All => {
843+ let to_produce = std:: cmp:: min ( size, self . group_values . len ( ) ) ;
844+ (
845+ if to_produce > 0 {
846+ self . emit ( EmitTo :: First ( to_produce) , false ) ?
847+ } else {
848+ None
849+ } ,
850+ EmitTo :: All ,
851+ )
852+ }
853+ & EmitTo :: First ( n) => {
854+ let to_emit = std:: cmp:: min ( n, size) ;
855+ if to_emit > 0 {
856+ (
857+ self . emit ( EmitTo :: First ( to_emit) , false ) ?,
858+ EmitTo :: First ( n. saturating_sub ( to_emit) ) ,
859+ )
855860 } else {
856- ExecutionState :: ReadingInput
857- } ,
858- batch. clone ( ) ,
859- )
861+ ( None , EmitTo :: First ( 0 ) )
862+ }
863+ }
864+ } ;
865+
866+ let num_rows = batch. as_ref ( ) . map ( |b| b. num_rows ( ) ) . unwrap_or ( 0 ) ;
867+
868+ self . exec_state = if num_rows < size {
869+ if self . input_done {
870+ ExecutionState :: Done
871+ }
872+ // In Partial aggregation, we also need to check
873+ // if we should trigger partial skipping
874+ else if self . mode == AggregateMode :: Partial
875+ && self . should_skip_aggregation ( )
876+ {
877+ ExecutionState :: SkippingAggregation
878+ } else {
879+ ExecutionState :: ReadingInput
880+ }
860881 } else {
861882 // a full batch was emitted, so more groups may remain to be output
862- let size = self . batch_size ;
863- let num_remaining = batch. num_rows ( ) - size;
864- let remaining = batch. slice ( size, num_remaining) ;
865- let output = batch. slice ( 0 , size) ;
866- ( ExecutionState :: ProducingOutput ( remaining) , output)
883+ ExecutionState :: ProducingOutput ( remaining)
884+ } ;
885+
886+ let output_batch = match batch {
887+ // it could be that no batch was emitted
888+ None => continue ,
889+ Some ( b) => b,
867890 } ;
868891
869892 if let Some ( reduction_factor) = self . reduction_factor . as_ref ( ) {
@@ -1047,10 +1070,8 @@ impl GroupedHashAggregateStream {
10471070 } ,
10481071 } ;
10491072
1050- if n > 0
1051- && let Some ( batch) = self . emit ( EmitTo :: First ( n) , false ) ?
1052- {
1053- Ok ( Some ( ExecutionState :: ProducingOutput ( batch) ) )
1073+ if n > 0 {
1074+ Ok ( Some ( ExecutionState :: ProducingOutput ( EmitTo :: First ( n) ) ) )
10541075 } else {
10551076 Err ( oom)
10561077 }
@@ -1230,12 +1251,7 @@ impl GroupedHashAggregateStream {
12301251 let timer = elapsed_compute. timer ( ) ;
12311252 self . exec_state = if self . spill_state . spills . is_empty ( ) {
12321253 // Input has been entirely processed without spilling to disk.
1233-
1234- // Flush any remaining group values.
1235- let batch = self . emit ( EmitTo :: All , false ) ?;
1236-
1237- // If there are none, we're done; otherwise switch to emitting them
1238- batch. map_or ( ExecutionState :: Done , ExecutionState :: ProducingOutput )
1254+ ExecutionState :: ProducingOutput ( EmitTo :: All )
12391255 } else {
12401256 // Spill any remaining data to disk. There is some performance overhead in
12411257 // writing out this last chunk of data and reading it back. The benefit of
@@ -1312,9 +1328,8 @@ impl GroupedHashAggregateStream {
13121328 fn switch_to_skip_aggregation ( & mut self ) -> Result < Option < ExecutionState > > {
13131329 if let Some ( probe) = self . skip_aggregation_probe . as_mut ( )
13141330 && probe. should_skip ( )
1315- && let Some ( batch) = self . emit ( EmitTo :: All , false ) ?
13161331 {
1317- return Ok ( Some ( ExecutionState :: ProducingOutput ( batch ) ) ) ;
1332+ return Ok ( Some ( ExecutionState :: ProducingOutput ( EmitTo :: All ) ) ) ;
13181333 } ;
13191334
13201335 Ok ( None )
0 commit comments