@@ -62,6 +62,7 @@ pub(crate) enum ExecutionState {
6262 /// When producing output, the remaining rows to output are stored
6363 /// here and are sliced off as needed in batch_size chunks
6464 ProducingOutput ( RecordBatch ) ,
65+ ProducingBlocks ( Option < usize > ) ,
6566 /// Produce intermediate aggregate state for each input row without
6667 /// aggregation.
6768 ///
@@ -387,6 +388,10 @@ pub(crate) struct GroupedHashAggregateStream {
387388 /// Optional probe for skipping data aggregation, if supported by
388389 /// current stream.
389390 skip_aggregation_probe : Option < SkipAggregationProbe > ,
391+
392+ enable_blocked_group_states : bool ,
393+
394+ block_size : usize ,
390395}
391396
392397impl GroupedHashAggregateStream {
@@ -676,6 +681,43 @@ impl Stream for GroupedHashAggregateStream {
676681 ) ) ) ;
677682 }
678683
684+ ExecutionState :: ProducingBlocks ( blocks) => {
685+ if let Some ( blk) = blocks {
686+ if blk > 0 {
687+ self . exec_state = ExecutionState :: ProducingBlocks ( Some ( * blk - 1 ) ) ;
688+ } else {
689+ self . exec_state = if self . input_done {
690+ ExecutionState :: Done
691+ } else if self . should_skip_aggregation ( ) {
692+ ExecutionState :: SkippingAggregation
693+ } else {
694+ ExecutionState :: ReadingInput
695+ } ;
696+ continue ;
697+ }
698+ }
699+
700+ let emit_result = self . emit ( EmitTo :: CurrentBlock ( true ) , false ) ;
701+ if emit_result. is_err ( ) {
702+ return Poll :: Ready ( Some ( emit_result) ) ;
703+ }
704+
705+ let emit_batch = emit_result. unwrap ( ) ;
706+ if emit_batch. num_rows ( ) == 0 {
707+ self . exec_state = if self . input_done {
708+ ExecutionState :: Done
709+ } else if self . should_skip_aggregation ( ) {
710+ ExecutionState :: SkippingAggregation
711+ } else {
712+ ExecutionState :: ReadingInput
713+ } ;
714+ }
715+
716+ return Poll :: Ready ( Some ( Ok (
717+ emit_batch. record_output ( & self . baseline_metrics )
718+ ) ) ) ;
719+ }
720+
679721 ExecutionState :: Done => {
680722 // release the memory reservation since sending back output batch itself needs
681723 // some memory reservation, so make some room for it.
@@ -900,10 +942,15 @@ impl GroupedHashAggregateStream {
900942 && matches ! ( self . group_ordering, GroupOrdering :: None )
901943 && matches ! ( self . mode, AggregateMode :: Partial )
902944 && self . update_memory_reservation ( ) . is_err ( )
903- {
904- let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
905- let batch = self . emit ( EmitTo :: First ( n) , false ) ?;
906- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
945+ {
946+ if self . enable_blocked_group_states {
947+ let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
948+ let batch = self . emit ( EmitTo :: First ( n) , false ) ?;
949+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
950+ } else {
951+ let blocks = self . group_values . len ( ) / self . block_size ;
952+ self . exec_state = ExecutionState :: ProducingBlocks ( Some ( blocks) ) ;
953+ }
907954 }
908955 Ok ( ( ) )
909956 }
@@ -961,8 +1008,12 @@ impl GroupedHashAggregateStream {
9611008 let elapsed_compute = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
9621009 let timer = elapsed_compute. timer ( ) ;
9631010 self . exec_state = if self . spill_state . spills . is_empty ( ) {
964- let batch = self . emit ( EmitTo :: All , false ) ?;
965- ExecutionState :: ProducingOutput ( batch)
1011+ if !self . enable_blocked_group_states {
1012+ let batch = self . emit ( EmitTo :: All , false ) ?;
1013+ ExecutionState :: ProducingOutput ( batch)
1014+ } else {
1015+ ExecutionState :: ProducingBlocks ( None )
1016+ }
9661017 } else {
9671018 // If spill files exist, stream-merge them.
9681019 self . update_merged_stream ( ) ?;
@@ -994,8 +1045,12 @@ impl GroupedHashAggregateStream {
9941045 fn switch_to_skip_aggregation ( & mut self ) -> Result < ( ) > {
9951046 if let Some ( probe) = self . skip_aggregation_probe . as_mut ( ) {
9961047 if probe. should_skip ( ) {
997- let batch = self . emit ( EmitTo :: All , false ) ?;
998- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1048+ if !self . enable_blocked_group_states {
1049+ let batch = self . emit ( EmitTo :: All , false ) ?;
1050+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1051+ } else {
1052+ self . exec_state = ExecutionState :: ProducingBlocks ( None ) ;
1053+ }
9991054 }
10001055 }
10011056
0 commit comments