Skip to content

Commit ab2219e

Browse files
committed
fix: produce batch_size batches instead of slicing one large batch
Related PRs: #18906 #19562 #15591
1 parent 65d51e8 commit ab2219e

3 files changed

Lines changed: 56 additions & 41 deletions

File tree

datafusion/physical-plan/src/aggregates/order/full.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ impl GroupOrderingFull {
106106
assert!(*current >= n);
107107
*current -= n;
108108
}
109-
State::Complete => panic!("invalid state: complete"),
109+
State::Complete => {}
110110
}
111111
}
112112

datafusion/physical-plan/src/aggregates/order/partial.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ impl GroupOrderingPartial {
174174
assert!(*current_sort >= n);
175175
*current_sort -= n;
176176
}
177-
State::Complete => panic!("invalid state: complete"),
177+
State::Complete => {}
178178
}
179179
}
180180

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ pub(crate) enum ExecutionState {
6363
ReadingInput,
6464
/// When producing output, the remaining rows to output are stored
6565
/// here and are sliced off as needed in batch_size chunks
66-
ProducingOutput(RecordBatch),
66+
ProducingOutput(EmitTo),
6767
/// Produce intermediate aggregate state for each input row without
6868
/// aggregation.
6969
///
@@ -753,10 +753,8 @@ impl Stream for GroupedHashAggregateStream {
753753
&& let Some(to_emit) = self.group_ordering.emit_to()
754754
{
755755
timer.done();
756-
if let Some(batch) = self.emit(to_emit, false)? {
757-
self.exec_state =
758-
ExecutionState::ProducingOutput(batch);
759-
};
756+
self.exec_state =
757+
ExecutionState::ProducingOutput(to_emit);
760758
// make sure the exec_state just set is not overwritten below
761759
break 'reading_input;
762760
}
@@ -837,33 +835,58 @@ impl Stream for GroupedHashAggregateStream {
837835
}
838836
}
839837

840-
ExecutionState::ProducingOutput(batch) => {
838+
ExecutionState::ProducingOutput(to_emit) => {
841839
// slice off a part of the batch, if needed
842-
let output_batch;
843840
let size = self.batch_size;
844-
(self.exec_state, output_batch) = if batch.num_rows() <= size {
845-
(
846-
if self.input_done {
847-
ExecutionState::Done
848-
}
849-
// In Partial aggregation, we also need to check
850-
// if we should trigger partial skipping
851-
else if self.mode == AggregateMode::Partial
852-
&& self.should_skip_aggregation()
853-
{
854-
ExecutionState::SkippingAggregation
841+
let (batch, remaining) = match to_emit {
842+
EmitTo::All => {
843+
let to_produce = std::cmp::min(size, self.group_values.len());
844+
(
845+
if to_produce > 0 {
846+
self.emit(EmitTo::First(to_produce), false)?
847+
} else {
848+
None
849+
},
850+
EmitTo::All,
851+
)
852+
}
853+
&EmitTo::First(n) => {
854+
let to_emit = std::cmp::min(n, size);
855+
if to_emit > 0 {
856+
(
857+
self.emit(EmitTo::First(to_emit), false)?,
858+
EmitTo::First(n.saturating_sub(to_emit)),
859+
)
855860
} else {
856-
ExecutionState::ReadingInput
857-
},
858-
batch.clone(),
859-
)
861+
(None, EmitTo::First(0))
862+
}
863+
}
864+
};
865+
866+
let num_rows = batch.as_ref().map(|b| b.num_rows()).unwrap_or(0);
867+
868+
self.exec_state = if num_rows < size {
869+
if self.input_done {
870+
ExecutionState::Done
871+
}
872+
// In Partial aggregation, we also need to check
873+
// if we should trigger partial skipping
874+
else if self.mode == AggregateMode::Partial
875+
&& self.should_skip_aggregation()
876+
{
877+
ExecutionState::SkippingAggregation
878+
} else {
879+
ExecutionState::ReadingInput
880+
}
860881
} else {
861882
// output first batch_size rows
862-
let size = self.batch_size;
863-
let num_remaining = batch.num_rows() - size;
864-
let remaining = batch.slice(size, num_remaining);
865-
let output = batch.slice(0, size);
866-
(ExecutionState::ProducingOutput(remaining), output)
883+
ExecutionState::ProducingOutput(remaining)
884+
};
885+
886+
let output_batch = match batch {
887+
// it could be that no batch was emitted
888+
None => continue,
889+
Some(b) => b,
867890
};
868891

869892
if let Some(reduction_factor) = self.reduction_factor.as_ref() {
@@ -1047,10 +1070,8 @@ impl GroupedHashAggregateStream {
10471070
},
10481071
};
10491072

1050-
if n > 0
1051-
&& let Some(batch) = self.emit(EmitTo::First(n), false)?
1052-
{
1053-
Ok(Some(ExecutionState::ProducingOutput(batch)))
1073+
if n > 0 {
1074+
Ok(Some(ExecutionState::ProducingOutput(EmitTo::First(n))))
10541075
} else {
10551076
Err(oom)
10561077
}
@@ -1230,12 +1251,7 @@ impl GroupedHashAggregateStream {
12301251
let timer = elapsed_compute.timer();
12311252
self.exec_state = if self.spill_state.spills.is_empty() {
12321253
// Input has been entirely processed without spilling to disk.
1233-
1234-
// Flush any remaining group values.
1235-
let batch = self.emit(EmitTo::All, false)?;
1236-
1237-
// If there are none, we're done; otherwise switch to emitting them
1238-
batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
1254+
ExecutionState::ProducingOutput(EmitTo::All)
12391255
} else {
12401256
// Spill any remaining data to disk. There is some performance overhead in
12411257
// writing out this last chunk of data and reading it back. The benefit of
@@ -1312,9 +1328,8 @@ impl GroupedHashAggregateStream {
13121328
fn switch_to_skip_aggregation(&mut self) -> Result<Option<ExecutionState>> {
13131329
if let Some(probe) = self.skip_aggregation_probe.as_mut()
13141330
&& probe.should_skip()
1315-
&& let Some(batch) = self.emit(EmitTo::All, false)?
13161331
{
1317-
return Ok(Some(ExecutionState::ProducingOutput(batch)));
1332+
return Ok(Some(ExecutionState::ProducingOutput(EmitTo::All)));
13181333
};
13191334

13201335
Ok(None)

0 commit comments

Comments
 (0)