Skip to content

Commit a9b2114

Browse files
committed
refactor: simplify scatter kernel with single match dispatch and deduplicated type lists
- Use DataType::primitive_width() instead of a manual byte_width mapping
- Derive has_fallback_columns from PartitionBuffer instead of duplicating the type list
- Replace 4 separate matches! checks with a single match on the ColumnBuffer variant
- Make auto-flush and fallback-flush mutually exclusive
- Remove dead clear() method
1 parent 7c75c01 commit a9b2114

2 files changed

Lines changed: 135 additions & 245 deletions

File tree

native/core/src/execution/shuffle/partitioners/multi_partition.rs

Lines changed: 109 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616
// under the License.
1717

1818
use crate::execution::shuffle::metrics::ShufflePartitionerMetrics;
19-
use crate::execution::shuffle::partitioners::partition_buffer::{self, PartitionBuffer};
19+
use crate::execution::shuffle::partitioners::partition_buffer::{ColumnBuffer, PartitionBuffer};
2020
use crate::execution::shuffle::partitioners::ShufflePartitioner;
2121
use crate::execution::shuffle::writers::{BufBatchWriter, PartitionWriter};
2222
use crate::execution::shuffle::{
2323
comet_partitioning, CometPartitioning, CompressionCodec, ShuffleBlockWriter,
2424
};
2525
use crate::execution::tracing::{with_trace, with_trace_async};
2626
use arrow::array::{Array, ArrayRef, BooleanArray, RecordBatch};
27-
use arrow::datatypes::{DataType, SchemaRef};
27+
use arrow::datatypes::SchemaRef;
2828
use datafusion::common::DataFusionError;
2929
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
3030
use datafusion::execution::runtime_env::RuntimeEnv;
@@ -105,7 +105,6 @@ pub(crate) struct MultiPartitionShuffleRepartitioner {
105105
output_data_file: String,
106106
output_index_file: String,
107107
partition_buffers: Vec<PartitionBuffer>,
108-
has_fallback_columns: bool,
109108
partition_writers: Vec<PartitionWriter>,
110109
shuffle_block_writer: ShuffleBlockWriter,
111110
/// Partitioning scheme to use
@@ -167,32 +166,6 @@ impl MultiPartitionShuffleRepartitioner {
167166
.map(|_| PartitionWriter::try_new(shuffle_block_writer.clone()))
168167
.collect::<datafusion::common::Result<Vec<_>>>()?;
169168

170-
let has_fallback_columns = schema.fields().iter().any(|f| {
171-
!matches!(
172-
f.data_type(),
173-
DataType::Boolean
174-
| DataType::Int8
175-
| DataType::Int16
176-
| DataType::Int32
177-
| DataType::Int64
178-
| DataType::UInt8
179-
| DataType::UInt16
180-
| DataType::UInt32
181-
| DataType::UInt64
182-
| DataType::Float16
183-
| DataType::Float32
184-
| DataType::Float64
185-
| DataType::Date32
186-
| DataType::Date64
187-
| DataType::Timestamp(_, _)
188-
| DataType::Duration(_)
189-
| DataType::Decimal128(_, _)
190-
| DataType::Utf8
191-
| DataType::Binary
192-
| DataType::LargeUtf8
193-
| DataType::LargeBinary
194-
)
195-
});
196169
let estimated_rows_per_partition = batch_size / num_output_partitions.max(1);
197170
let partition_buffers = (0..num_output_partitions)
198171
.map(|_| PartitionBuffer::new(Arc::clone(&schema), estimated_rows_per_partition))
@@ -206,7 +179,6 @@ impl MultiPartitionShuffleRepartitioner {
206179
output_data_file,
207180
output_index_file,
208181
partition_buffers,
209-
has_fallback_columns,
210182
partition_writers,
211183
shuffle_block_writer,
212184
partitioning,
@@ -421,128 +393,112 @@ impl MultiPartitionShuffleRepartitioner {
421393
// rows within that partition. This keeps writes to the same partition buffer
422394
// sequential for better cache locality.
423395
for (col_idx, column) in input.columns().iter().enumerate() {
424-
// Determine scatter path from first partition's column type
425-
// (all partitions have the same column types)
426-
let is_fixed = matches!(
427-
self.partition_buffers[0].columns[col_idx],
428-
partition_buffer::ColumnBuffer::Fixed { .. }
429-
);
430-
let is_variable = matches!(
431-
self.partition_buffers[0].columns[col_idx],
432-
partition_buffer::ColumnBuffer::Variable { .. }
433-
);
434-
let is_large_variable = matches!(
435-
self.partition_buffers[0].columns[col_idx],
436-
partition_buffer::ColumnBuffer::LargeVariable { .. }
437-
);
438-
let is_boolean = matches!(
439-
self.partition_buffers[0].columns[col_idx],
440-
partition_buffer::ColumnBuffer::Boolean { .. }
441-
);
442-
443396
let nulls = column.nulls();
444397

445-
if is_fixed {
446-
let byte_width = match &self.partition_buffers[0].columns[col_idx] {
447-
partition_buffer::ColumnBuffer::Fixed { byte_width, .. } => *byte_width,
448-
_ => unreachable!(),
449-
};
450-
let data = column.to_data();
451-
let values = data.buffers()[0].as_slice();
452-
for p in 0..num_partitions {
453-
let start = partition_starts[p] as usize;
454-
let end = partition_starts[p + 1] as usize;
455-
if start == end {
456-
continue;
457-
}
458-
let row_indices = &partition_row_indices[start..end];
459-
for &row_idx in row_indices {
460-
let row = row_idx as usize;
461-
let src_offset = row * byte_width;
462-
let is_valid = nulls.is_none_or(|n| n.is_valid(row));
463-
self.partition_buffers[p].append_fixed(
464-
col_idx,
465-
&values[src_offset..src_offset + byte_width],
466-
is_valid,
467-
);
398+
// Single match to determine scatter path from first partition's column type
399+
match &self.partition_buffers[0].columns[col_idx] {
400+
ColumnBuffer::Fixed { byte_width, .. } => {
401+
let byte_width = *byte_width;
402+
let data = column.to_data();
403+
let values = data.buffers()[0].as_slice();
404+
for p in 0..num_partitions {
405+
let start = partition_starts[p] as usize;
406+
let end = partition_starts[p + 1] as usize;
407+
if start == end {
408+
continue;
409+
}
410+
let row_indices = &partition_row_indices[start..end];
411+
for &row_idx in row_indices {
412+
let row = row_idx as usize;
413+
let src_offset = row * byte_width;
414+
let is_valid = nulls.is_none_or(|n| n.is_valid(row));
415+
self.partition_buffers[p].append_fixed(
416+
col_idx,
417+
&values[src_offset..src_offset + byte_width],
418+
is_valid,
419+
);
420+
}
468421
}
469422
}
470-
} else if is_variable {
471-
let data = column.to_data();
472-
let offsets_slice = data.buffers()[0].typed_data::<i32>();
473-
let values_slice = data.buffers()[1].as_slice();
474-
for p in 0..num_partitions {
475-
let start = partition_starts[p] as usize;
476-
let end = partition_starts[p + 1] as usize;
477-
if start == end {
478-
continue;
479-
}
480-
let row_indices = &partition_row_indices[start..end];
481-
for &row_idx in row_indices {
482-
let row = row_idx as usize;
483-
let val_start = offsets_slice[row] as usize;
484-
let val_end = offsets_slice[row + 1] as usize;
485-
let is_valid = nulls.is_none_or(|n| n.is_valid(row));
486-
self.partition_buffers[p].append_variable(
487-
col_idx,
488-
&values_slice[val_start..val_end],
489-
is_valid,
490-
);
423+
ColumnBuffer::Variable { .. } => {
424+
let data = column.to_data();
425+
let offsets_slice = data.buffers()[0].typed_data::<i32>();
426+
let values_slice = data.buffers()[1].as_slice();
427+
for p in 0..num_partitions {
428+
let start = partition_starts[p] as usize;
429+
let end = partition_starts[p + 1] as usize;
430+
if start == end {
431+
continue;
432+
}
433+
let row_indices = &partition_row_indices[start..end];
434+
for &row_idx in row_indices {
435+
let row = row_idx as usize;
436+
let val_start = offsets_slice[row] as usize;
437+
let val_end = offsets_slice[row + 1] as usize;
438+
let is_valid = nulls.is_none_or(|n| n.is_valid(row));
439+
self.partition_buffers[p].append_variable(
440+
col_idx,
441+
&values_slice[val_start..val_end],
442+
is_valid,
443+
);
444+
}
491445
}
492446
}
493-
} else if is_large_variable {
494-
let data = column.to_data();
495-
let offsets_slice = data.buffers()[0].typed_data::<i64>();
496-
let values_slice = data.buffers()[1].as_slice();
497-
for p in 0..num_partitions {
498-
let start = partition_starts[p] as usize;
499-
let end = partition_starts[p + 1] as usize;
500-
if start == end {
501-
continue;
502-
}
503-
let row_indices = &partition_row_indices[start..end];
504-
for &row_idx in row_indices {
505-
let row = row_idx as usize;
506-
let val_start = offsets_slice[row] as usize;
507-
let val_end = offsets_slice[row + 1] as usize;
508-
let is_valid = nulls.is_none_or(|n| n.is_valid(row));
509-
self.partition_buffers[p].append_large_variable(
510-
col_idx,
511-
&values_slice[val_start..val_end],
512-
is_valid,
513-
);
447+
ColumnBuffer::LargeVariable { .. } => {
448+
let data = column.to_data();
449+
let offsets_slice = data.buffers()[0].typed_data::<i64>();
450+
let values_slice = data.buffers()[1].as_slice();
451+
for p in 0..num_partitions {
452+
let start = partition_starts[p] as usize;
453+
let end = partition_starts[p + 1] as usize;
454+
if start == end {
455+
continue;
456+
}
457+
let row_indices = &partition_row_indices[start..end];
458+
for &row_idx in row_indices {
459+
let row = row_idx as usize;
460+
let val_start = offsets_slice[row] as usize;
461+
let val_end = offsets_slice[row + 1] as usize;
462+
let is_valid = nulls.is_none_or(|n| n.is_valid(row));
463+
self.partition_buffers[p].append_large_variable(
464+
col_idx,
465+
&values_slice[val_start..val_end],
466+
is_valid,
467+
);
468+
}
514469
}
515470
}
516-
} else if is_boolean {
517-
let bool_array = column.as_any().downcast_ref::<BooleanArray>().unwrap();
518-
for p in 0..num_partitions {
519-
let start = partition_starts[p] as usize;
520-
let end = partition_starts[p + 1] as usize;
521-
if start == end {
522-
continue;
523-
}
524-
let row_indices = &partition_row_indices[start..end];
525-
for &row_idx in row_indices {
526-
let row = row_idx as usize;
527-
let is_valid = nulls.is_none_or(|n| n.is_valid(row));
528-
self.partition_buffers[p].append_bool(
529-
col_idx,
530-
bool_array.value(row),
531-
is_valid,
532-
);
471+
ColumnBuffer::Boolean { .. } => {
472+
let bool_array = column.as_any().downcast_ref::<BooleanArray>().unwrap();
473+
for p in 0..num_partitions {
474+
let start = partition_starts[p] as usize;
475+
let end = partition_starts[p + 1] as usize;
476+
if start == end {
477+
continue;
478+
}
479+
let row_indices = &partition_row_indices[start..end];
480+
for &row_idx in row_indices {
481+
let row = row_idx as usize;
482+
let is_valid = nulls.is_none_or(|n| n.is_valid(row));
483+
self.partition_buffers[p].append_bool(
484+
col_idx,
485+
bool_array.value(row),
486+
is_valid,
487+
);
488+
}
533489
}
534490
}
535-
} else {
536-
// Fallback
537-
for p in 0..num_partitions {
538-
let start = partition_starts[p] as usize;
539-
let end = partition_starts[p + 1] as usize;
540-
if start == end {
541-
continue;
542-
}
543-
let row_indices = &partition_row_indices[start..end];
544-
for &row_idx in row_indices {
545-
self.partition_buffers[p].append_fallback_index(col_idx, row_idx);
491+
ColumnBuffer::Fallback { .. } => {
492+
for p in 0..num_partitions {
493+
let start = partition_starts[p] as usize;
494+
let end = partition_starts[p + 1] as usize;
495+
if start == end {
496+
continue;
497+
}
498+
let row_indices = &partition_row_indices[start..end];
499+
for &row_idx in row_indices {
500+
self.partition_buffers[p].append_fallback_index(col_idx, row_idx);
501+
}
546502
}
547503
}
548504
}
@@ -552,15 +508,23 @@ impl MultiPartitionShuffleRepartitioner {
552508
.scatter_time
553509
.add_duration(scatter_start.elapsed());
554510

555-
// Update row counts from partition_starts (O(num_partitions), not O(num_rows))
511+
// Update each partition's row count from partition_starts: O(num_partitions) rather than O(num_rows)
556512
for p in 0..num_partitions {
557513
let count = (partition_starts[p + 1] - partition_starts[p]) as usize;
558514
self.partition_buffers[p].row_count += count;
559515
}
560516

561-
// Auto-flush partitions that reached batch_size
517+
// Flush partitions. When fallback columns exist, flush ALL non-empty
518+
// partitions since fallback indices reference the current input batch.
519+
// Otherwise, only flush partitions that reached batch_size.
520+
let flush_all = self.partition_buffers[0].has_fallback_columns();
562521
for p in 0..num_partitions {
563-
if self.partition_buffers[p].row_count >= self.batch_size {
522+
let should_flush = if flush_all {
523+
self.partition_buffers[p].row_count > 0
524+
} else {
525+
self.partition_buffers[p].row_count >= self.batch_size
526+
};
527+
if should_flush {
564528
let batch = self.partition_buffers[p].flush(Some(input))?;
565529
self.partition_writers[p].spill(
566530
&[batch],
@@ -572,23 +536,6 @@ impl MultiPartitionShuffleRepartitioner {
572536
}
573537
}
574538

575-
// If schema has fallback columns, flush ALL non-empty partitions
576-
// since fallback indices reference the current input batch
577-
if self.has_fallback_columns {
578-
for p in 0..num_partitions {
579-
if self.partition_buffers[p].row_count > 0 {
580-
let batch = self.partition_buffers[p].flush(Some(input))?;
581-
self.partition_writers[p].spill(
582-
&[batch],
583-
&self.runtime,
584-
&self.metrics,
585-
self.write_buffer_size,
586-
self.batch_size,
587-
)?;
588-
}
589-
}
590-
}
591-
592539
// Precise memory tracking
593540
let mem_after: usize = self.partition_buffers.iter().map(|b| b.memory_size()).sum();
594541
let mem_growth = mem_after.saturating_sub(mem_before);

0 commit comments

Comments
 (0)