Skip to content

Commit 62c91dd

Browse files
committed
feat: split oversized blocks in BlockPartitionStream and add unit tests
Prevent BlockPartitionStream from emitting blocks larger than rows_threshold by splitting them after build. Also adds comprehensive unit tests covering partition, finalize_partition, and take_partitions split behavior.
1 parent 801c2c1 commit 62c91dd

6 files changed

Lines changed: 190 additions & 16 deletions

File tree

src/query/expression/src/kernels/stream_partition.rs

Lines changed: 181 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,10 @@ impl BlockPartitionStream {
164164
}
165165

166166
partition.num_rows = 0;
167-
ready_blocks.push((id, DataBlock::new(columns, rows)));
167+
let block = DataBlock::new(columns, rows);
168+
for sub_block in split_block_if_needed(block, self.rows_threshold) {
169+
ready_blocks.push((id, sub_block));
170+
}
168171
}
169172
}
170173

@@ -214,23 +217,26 @@ impl BlockPartitionStream {
214217

215218
let num_rows = partition.num_rows;
216219
partition.num_rows = 0;
217-
take_blocks.push((id, DataBlock::new(columns, num_rows)));
220+
let block = DataBlock::new(columns, num_rows);
221+
for sub_block in split_block_if_needed(block, self.rows_threshold) {
222+
take_blocks.push((id, sub_block));
223+
}
218224
}
219225

220226
take_blocks
221227
}
222228

223-
pub fn finalize_partition(&mut self, partition_id: usize) -> Option<DataBlock> {
229+
pub fn finalize_partition(&mut self, partition_id: usize) -> Vec<DataBlock> {
224230
if !self.initialize {
225-
return None;
231+
return vec![];
226232
}
227233

228234
let partition = &mut self.partitions[partition_id];
229235

230236
let num_rows = partition.num_rows;
231237

232238
if num_rows == 0 {
233-
return None;
239+
return vec![];
234240
}
235241

236242
let mut columns = Vec::with_capacity(partition.columns_builder.len());
@@ -245,7 +251,20 @@ impl BlockPartitionStream {
245251
}
246252

247253
partition.num_rows = 0;
248-
Some(DataBlock::new(columns, num_rows))
254+
let block = DataBlock::new(columns, num_rows);
255+
self.split_block_if_needed(block)
256+
}
257+
258+
/// Splits `block` according to this stream's own `rows_threshold`.
///
/// Convenience wrapper that forwards to the free function
/// `split_block_if_needed`, so callers holding `&self` do not need to pass
/// the threshold explicitly.
fn split_block_if_needed(&self, block: DataBlock) -> Vec<DataBlock> {
    let rows_threshold = self.rows_threshold;
    split_block_if_needed(block, rows_threshold)
}
261+
}
262+
263+
/// Returns `block` as one or more blocks bounded by `rows_threshold`.
///
/// The block is handed back untouched, as a single-element vector, when
/// `rows_threshold` is `usize::MAX` or when the block holds at most
/// `rows_threshold` rows. Otherwise the block is cut up with
/// `DataBlock::split_by_rows_no_tail(rows_threshold)`.
///
/// NOTE(review): the unit tests in this file expect a 25-row block with a
/// threshold of 10 to come back as 10 + 10 + 5 rows; confirm against
/// `split_by_rows_no_tail` if that helper changes.
fn split_block_if_needed(block: DataBlock, rows_threshold: usize) -> Vec<DataBlock> {
    let unlimited = rows_threshold == usize::MAX;
    if unlimited || block.num_rows() <= rows_threshold {
        return vec![block];
    }

    block.split_by_rows_no_tail(rows_threshold)
}
251270

@@ -532,8 +551,163 @@ fn copy_array<I: Index>(
532551
from: &ArrayColumn<AnyType>,
533552
indices: &[I],
534553
) {
535-
// TODO:
536554
for index in indices {
537555
unsafe { to.push(from.index_unchecked(index.to_usize())) }
538556
}
539557
}
558+
559+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FromData;
    use crate::types::Int32Type;
    use crate::types::NumberColumn;

    /// Builds a single-column Int32 block holding `values`.
    fn make_block(values: Vec<i32>) -> DataBlock {
        let column = Int32Type::from_data(values);
        DataBlock::new_from_columns(vec![column])
    }

    /// Copies the first column of `block` out as plain `i32`s.
    ///
    /// Panics when that column is not an Int32 number column.
    fn collect_column_values(block: &DataBlock) -> Vec<i32> {
        let column = block.columns()[0].to_column();
        let NumberColumn::Int32(buffer) = column.as_number().unwrap() else {
            panic!("expected Int32 column");
        };
        buffer.to_vec()
    }

    /// Row counts of `blocks`, in iteration order.
    fn row_counts<'a>(blocks: impl Iterator<Item = &'a DataBlock>) -> Vec<usize> {
        blocks.map(|block| block.num_rows()).collect()
    }

    #[test]
    fn test_partition_no_split_under_threshold() {
        let mut stream = BlockPartitionStream::create(100, 0, 2);

        // Every row is routed to partition 0; 50 rows stay below the
        // 100-row threshold, so nothing is emitted.
        let emitted = stream.partition(vec![0u64; 50], make_block((0..50).collect()), true);

        assert!(emitted.is_empty());
    }

    #[test]
    fn test_partition_emit_at_threshold() {
        let mut stream = BlockPartitionStream::create(10, 0, 1);

        let emitted = stream.partition(vec![0u64; 10], make_block((0..10).collect()), true);

        assert_eq!(emitted.len(), 1);
        let (partition_id, block) = &emitted[0];
        assert_eq!(*partition_id, 0);
        assert_eq!(block.num_rows(), 10);
    }

    #[test]
    fn test_partition_splits_large_block() {
        let mut stream = BlockPartitionStream::create(10, 0, 1);

        // 25 rows land in partition 0 and must come out as 10 + 10 + 5.
        let emitted = stream.partition(vec![0u64; 25], make_block((0..25).collect()), true);

        assert_eq!(row_counts(emitted.iter().map(|(_, block)| block)), vec![
            10, 10, 5
        ]);
        // Every piece keeps the id of the partition it was cut from.
        assert!(emitted.iter().all(|(id, _)| *id == 0));

        // The pieces, concatenated in order, reproduce the input rows.
        let mut seen = Vec::new();
        for (_, block) in &emitted {
            seen.extend(collect_column_values(block));
        }
        assert_eq!(seen, (0..25).collect::<Vec<i32>>());
    }

    #[test]
    fn test_partition_multiple_partitions_split() {
        let mut stream = BlockPartitionStream::create(5, 0, 2);

        // 8 rows for partition 0 followed by 7 rows for partition 1.
        let indices = [vec![0u64; 8], vec![1u64; 7]].concat();
        let emitted = stream.partition(indices, make_block((0..15).collect()), true);

        let first = row_counts(
            emitted
                .iter()
                .filter(|(id, _)| *id == 0)
                .map(|(_, block)| block),
        );
        let second = row_counts(
            emitted
                .iter()
                .filter(|(id, _)| *id == 1)
                .map(|(_, block)| block),
        );

        // partition 0: 8 rows -> 5 + 3
        assert_eq!(first, vec![5, 3]);
        // partition 1: 7 rows -> 5 + 2
        assert_eq!(second, vec![5, 2]);
    }

    #[test]
    fn test_finalize_partition_splits() {
        let mut stream = BlockPartitionStream::create(10, 0, 1);

        // With out_ready=false the 25 rows are buffered rather than emitted.
        let emitted = stream.partition(vec![0u64; 25], make_block((0..25).collect()), false);
        assert!(emitted.is_empty());

        // Finalizing flushes the buffer and applies the same split.
        let flushed = stream.finalize_partition(0);
        assert_eq!(row_counts(flushed.iter()), vec![10, 10, 5]);
    }

    #[test]
    fn test_finalize_empty_partition() {
        let mut stream = BlockPartitionStream::create(10, 0, 2);

        // Initialize the stream by buffering a few rows in partition 0 only.
        let _ = stream.partition(vec![0u64; 5], make_block(vec![1, 2, 3, 4, 5]), false);

        // Partition 1 never received data, so finalizing it yields nothing.
        let flushed = stream.finalize_partition(1);
        assert!(flushed.is_empty());
    }

    #[test]
    fn test_take_partitions_splits() {
        let mut stream = BlockPartitionStream::create(5, 0, 3);

        // 12 rows for partition 0, 8 for partition 1, 3 for partition 2.
        let indices = [vec![0u64; 12], vec![1u64; 8], vec![2u64; 3]].concat();
        let _ = stream.partition(indices, make_block((0..23).collect()), false);

        // Take everything except partition 2.
        let excluded: HashSet<usize> = std::iter::once(2).collect();
        let taken = stream.take_partitions(&excluded);

        let first = row_counts(
            taken
                .iter()
                .filter(|(id, _)| *id == 0)
                .map(|(_, block)| block),
        );
        let second = row_counts(
            taken
                .iter()
                .filter(|(id, _)| *id == 1)
                .map(|(_, block)| block),
        );
        let third = row_counts(
            taken
                .iter()
                .filter(|(id, _)| *id == 2)
                .map(|(_, block)| block),
        );

        // partition 0: 12 rows -> 5 + 5 + 2
        assert_eq!(first, vec![5, 5, 2]);
        // partition 1: 8 rows -> 5 + 3
        assert_eq!(second, vec![5, 3]);
        // partition 2 was excluded, so none of its rows are returned
        assert!(third.is_empty());
    }

    #[test]
    fn test_no_split_when_no_row_threshold() {
        // rows_threshold=0 means usize::MAX, no splitting
        let mut stream = BlockPartitionStream::create(0, 1, 1);

        // bytes_threshold=1 triggers the emit, but the rows are not split.
        let emitted = stream.partition(vec![0u64; 100], make_block((0..100).collect()), true);

        assert_eq!(emitted.len(), 1);
        assert_eq!(emitted[0].1.num_rows(), 100);
    }
}

src/query/service/src/pipelines/processors/transforms/aggregator/new_aggregate/new_aggregate_spiller.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ impl PartitionStream for LocalPartitionStream {
239239
let ids = self.partition_stream.partition_ids();
240240
let mut pending_blocks = Vec::with_capacity(ids.len());
241241
for id in ids {
242-
if let Some(block) = self.partition_stream.finalize_partition(id) {
242+
for block in self.partition_stream.finalize_partition(id) {
243243
pending_blocks.push((id, block));
244244
}
245245
}
@@ -291,7 +291,7 @@ impl SharedPartitionStream {
291291
let mut pending_blocks = Vec::with_capacity(ids.len());
292292

293293
for id in ids {
294-
if let Some(block) = inner.partition_stream.finalize_partition(id) {
294+
for block in inner.partition_stream.finalize_partition(id) {
295295
pending_blocks.push((id, block));
296296
}
297297
}

src/query/service/src/pipelines/processors/transforms/hash_join/hash_join_spiller.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -291,8 +291,8 @@ impl HashJoinSpiller {
291291
if let Some(buffer_blocks) = self.restore_cross_buffer(partition_id)? {
292292
data_blocks.extend(buffer_blocks);
293293
}
294-
} else if let Some(buffer_block) = self.restore_buffer(partition_id) {
295-
data_blocks.push(buffer_block);
294+
} else {
295+
data_blocks.extend(self.restore_buffer(partition_id));
296296
}
297297
}
298298

@@ -324,7 +324,7 @@ impl HashJoinSpiller {
324324
Ok(data_blocks)
325325
}
326326

327-
fn restore_buffer(&mut self, partition_id: usize) -> Option<DataBlock> {
327+
fn restore_buffer(&mut self, partition_id: usize) -> Vec<DataBlock> {
328328
self.block_partition_stream.finalize_partition(partition_id)
329329
}
330330

src/query/service/src/pipelines/processors/transforms/new_hash_join/grace/grace_join.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -334,7 +334,7 @@ impl<T: GraceMemoryJoin> GraceHashJoin<T> {
334334
let ready_partitions_id = self.build_partition_stream.partition_ids();
335335
let mut ready_partitions = Vec::with_capacity(ready_partitions_id.len());
336336
for id in ready_partitions_id {
337-
if let Some(data) = self.build_partition_stream.finalize_partition(id) {
337+
for data in self.build_partition_stream.finalize_partition(id) {
338338
ready_partitions.push((id, data));
339339
}
340340
}
@@ -345,7 +345,7 @@ impl<T: GraceMemoryJoin> GraceHashJoin<T> {
345345
let ready_partitions_id = self.probe_partition_stream.partition_ids();
346346

347347
for id in ready_partitions_id {
348-
if let Some(data_block) = self.probe_partition_stream.finalize_partition(id) {
348+
for data_block in self.probe_partition_stream.finalize_partition(id) {
349349
self.partitions[id].writer.write(data_block)?;
350350
self.partitions[id].writer.flush()?;
351351
}

src/query/service/src/servers/flight/v1/exchange/hash_send_sink.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ impl Processor for HashSendSink {
129129
let mut futures = Vec::new();
130130

131131
for partition_id in 0..self.channels.len() {
132-
if let Some(block) = self.partition_stream.finalize_partition(partition_id) {
132+
for block in self.partition_stream.finalize_partition(partition_id) {
133133
if block.is_empty() {
134134
continue;
135135
}

src/query/service/src/servers/flight/v1/exchange/hash_send_transform.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -162,7 +162,7 @@ impl Processor for HashSendTransform {
162162
let mut futures = Vec::new();
163163

164164
for partition_id in 0..self.channels.len() {
165-
if let Some(block) = self.partition_stream.finalize_partition(partition_id) {
165+
for block in self.partition_stream.finalize_partition(partition_id) {
166166
if block.is_empty() {
167167
continue;
168168
}

0 commit comments

Comments
 (0)