1616// under the License.
1717
1818use crate :: execution:: shuffle:: metrics:: ShufflePartitionerMetrics ;
19- use crate :: execution:: shuffle:: partitioners:: partition_buffer:: { self , PartitionBuffer } ;
19+ use crate :: execution:: shuffle:: partitioners:: partition_buffer:: { ColumnBuffer , PartitionBuffer } ;
2020use crate :: execution:: shuffle:: partitioners:: ShufflePartitioner ;
2121use crate :: execution:: shuffle:: writers:: { BufBatchWriter , PartitionWriter } ;
2222use crate :: execution:: shuffle:: {
2323 comet_partitioning, CometPartitioning , CompressionCodec , ShuffleBlockWriter ,
2424} ;
2525use crate :: execution:: tracing:: { with_trace, with_trace_async} ;
2626use arrow:: array:: { Array , ArrayRef , BooleanArray , RecordBatch } ;
27- use arrow:: datatypes:: { DataType , SchemaRef } ;
27+ use arrow:: datatypes:: SchemaRef ;
2828use datafusion:: common:: DataFusionError ;
2929use datafusion:: execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
3030use datafusion:: execution:: runtime_env:: RuntimeEnv ;
@@ -105,7 +105,6 @@ pub(crate) struct MultiPartitionShuffleRepartitioner {
105105 output_data_file : String ,
106106 output_index_file : String ,
107107 partition_buffers : Vec < PartitionBuffer > ,
108- has_fallback_columns : bool ,
109108 partition_writers : Vec < PartitionWriter > ,
110109 shuffle_block_writer : ShuffleBlockWriter ,
111110 /// Partitioning scheme to use
@@ -167,32 +166,6 @@ impl MultiPartitionShuffleRepartitioner {
167166 . map ( |_| PartitionWriter :: try_new ( shuffle_block_writer. clone ( ) ) )
168167 . collect :: < datafusion:: common:: Result < Vec < _ > > > ( ) ?;
169168
170- let has_fallback_columns = schema. fields ( ) . iter ( ) . any ( |f| {
171- !matches ! (
172- f. data_type( ) ,
173- DataType :: Boolean
174- | DataType :: Int8
175- | DataType :: Int16
176- | DataType :: Int32
177- | DataType :: Int64
178- | DataType :: UInt8
179- | DataType :: UInt16
180- | DataType :: UInt32
181- | DataType :: UInt64
182- | DataType :: Float16
183- | DataType :: Float32
184- | DataType :: Float64
185- | DataType :: Date32
186- | DataType :: Date64
187- | DataType :: Timestamp ( _, _)
188- | DataType :: Duration ( _)
189- | DataType :: Decimal128 ( _, _)
190- | DataType :: Utf8
191- | DataType :: Binary
192- | DataType :: LargeUtf8
193- | DataType :: LargeBinary
194- )
195- } ) ;
196169 let estimated_rows_per_partition = batch_size / num_output_partitions. max ( 1 ) ;
197170 let partition_buffers = ( 0 ..num_output_partitions)
198171 . map ( |_| PartitionBuffer :: new ( Arc :: clone ( & schema) , estimated_rows_per_partition) )
@@ -206,7 +179,6 @@ impl MultiPartitionShuffleRepartitioner {
206179 output_data_file,
207180 output_index_file,
208181 partition_buffers,
209- has_fallback_columns,
210182 partition_writers,
211183 shuffle_block_writer,
212184 partitioning,
@@ -421,128 +393,112 @@ impl MultiPartitionShuffleRepartitioner {
421393 // rows within that partition. This keeps writes to the same partition buffer
422394 // sequential for better cache locality.
423395 for ( col_idx, column) in input. columns ( ) . iter ( ) . enumerate ( ) {
424- // Determine scatter path from first partition's column type
425- // (all partitions have the same column types)
426- let is_fixed = matches ! (
427- self . partition_buffers[ 0 ] . columns[ col_idx] ,
428- partition_buffer:: ColumnBuffer :: Fixed { .. }
429- ) ;
430- let is_variable = matches ! (
431- self . partition_buffers[ 0 ] . columns[ col_idx] ,
432- partition_buffer:: ColumnBuffer :: Variable { .. }
433- ) ;
434- let is_large_variable = matches ! (
435- self . partition_buffers[ 0 ] . columns[ col_idx] ,
436- partition_buffer:: ColumnBuffer :: LargeVariable { .. }
437- ) ;
438- let is_boolean = matches ! (
439- self . partition_buffers[ 0 ] . columns[ col_idx] ,
440- partition_buffer:: ColumnBuffer :: Boolean { .. }
441- ) ;
442-
443396 let nulls = column. nulls ( ) ;
444397
445- if is_fixed {
446- let byte_width = match & self . partition_buffers [ 0 ] . columns [ col_idx] {
447- partition_buffer :: ColumnBuffer :: Fixed { byte_width, .. } => * byte_width ,
448- _ => unreachable ! ( ) ,
449- } ;
450- let data = column . to_data ( ) ;
451- let values = data . buffers ( ) [ 0 ] . as_slice ( ) ;
452- for p in 0 ..num_partitions {
453- let start = partition_starts[ p] as usize ;
454- let end = partition_starts [ p + 1 ] as usize ;
455- if start == end {
456- continue ;
457- }
458- let row_indices = & partition_row_indices [ start..end ] ;
459- for & row_idx in row_indices {
460- let row = row_idx as usize ;
461- let src_offset = row * byte_width ;
462- let is_valid = nulls . is_none_or ( |n| n . is_valid ( row ) ) ;
463- self . partition_buffers [ p ] . append_fixed (
464- col_idx ,
465- & values [ src_offset..src_offset + byte_width ] ,
466- is_valid ,
467- ) ;
398+ // Single match to determine scatter path from first partition's column type
399+ match & self . partition_buffers [ 0 ] . columns [ col_idx] {
400+ ColumnBuffer :: Fixed { byte_width, .. } => {
401+ let byte_width = * byte_width ;
402+ let data = column . to_data ( ) ;
403+ let values = data . buffers ( ) [ 0 ] . as_slice ( ) ;
404+ for p in 0 ..num_partitions {
405+ let start = partition_starts [ p ] as usize ;
406+ let end = partition_starts[ p + 1 ] as usize ;
407+ if start == end {
408+ continue ;
409+ }
410+ let row_indices = & partition_row_indices [ start..end ] ;
411+ for & row_idx in row_indices {
412+ let row = row_idx as usize ;
413+ let src_offset = row * byte_width ;
414+ let is_valid = nulls . is_none_or ( |n| n . is_valid ( row ) ) ;
415+ self . partition_buffers [ p ] . append_fixed (
416+ col_idx ,
417+ & values [ src_offset..src_offset + byte_width ] ,
418+ is_valid ,
419+ ) ;
420+ }
468421 }
469422 }
470- } else if is_variable {
471- let data = column. to_data ( ) ;
472- let offsets_slice = data. buffers ( ) [ 0 ] . typed_data :: < i32 > ( ) ;
473- let values_slice = data. buffers ( ) [ 1 ] . as_slice ( ) ;
474- for p in 0 ..num_partitions {
475- let start = partition_starts[ p] as usize ;
476- let end = partition_starts[ p + 1 ] as usize ;
477- if start == end {
478- continue ;
479- }
480- let row_indices = & partition_row_indices[ start..end] ;
481- for & row_idx in row_indices {
482- let row = row_idx as usize ;
483- let val_start = offsets_slice[ row] as usize ;
484- let val_end = offsets_slice[ row + 1 ] as usize ;
485- let is_valid = nulls. is_none_or ( |n| n. is_valid ( row) ) ;
486- self . partition_buffers [ p] . append_variable (
487- col_idx,
488- & values_slice[ val_start..val_end] ,
489- is_valid,
490- ) ;
423+ ColumnBuffer :: Variable { .. } => {
424+ let data = column. to_data ( ) ;
425+ let offsets_slice = data. buffers ( ) [ 0 ] . typed_data :: < i32 > ( ) ;
426+ let values_slice = data. buffers ( ) [ 1 ] . as_slice ( ) ;
427+ for p in 0 ..num_partitions {
428+ let start = partition_starts[ p] as usize ;
429+ let end = partition_starts[ p + 1 ] as usize ;
430+ if start == end {
431+ continue ;
432+ }
433+ let row_indices = & partition_row_indices[ start..end] ;
434+ for & row_idx in row_indices {
435+ let row = row_idx as usize ;
436+ let val_start = offsets_slice[ row] as usize ;
437+ let val_end = offsets_slice[ row + 1 ] as usize ;
438+ let is_valid = nulls. is_none_or ( |n| n. is_valid ( row) ) ;
439+ self . partition_buffers [ p] . append_variable (
440+ col_idx,
441+ & values_slice[ val_start..val_end] ,
442+ is_valid,
443+ ) ;
444+ }
491445 }
492446 }
493- } else if is_large_variable {
494- let data = column. to_data ( ) ;
495- let offsets_slice = data. buffers ( ) [ 0 ] . typed_data :: < i64 > ( ) ;
496- let values_slice = data. buffers ( ) [ 1 ] . as_slice ( ) ;
497- for p in 0 ..num_partitions {
498- let start = partition_starts[ p] as usize ;
499- let end = partition_starts[ p + 1 ] as usize ;
500- if start == end {
501- continue ;
502- }
503- let row_indices = & partition_row_indices[ start..end] ;
504- for & row_idx in row_indices {
505- let row = row_idx as usize ;
506- let val_start = offsets_slice[ row] as usize ;
507- let val_end = offsets_slice[ row + 1 ] as usize ;
508- let is_valid = nulls. is_none_or ( |n| n. is_valid ( row) ) ;
509- self . partition_buffers [ p] . append_large_variable (
510- col_idx,
511- & values_slice[ val_start..val_end] ,
512- is_valid,
513- ) ;
447+ ColumnBuffer :: LargeVariable { .. } => {
448+ let data = column. to_data ( ) ;
449+ let offsets_slice = data. buffers ( ) [ 0 ] . typed_data :: < i64 > ( ) ;
450+ let values_slice = data. buffers ( ) [ 1 ] . as_slice ( ) ;
451+ for p in 0 ..num_partitions {
452+ let start = partition_starts[ p] as usize ;
453+ let end = partition_starts[ p + 1 ] as usize ;
454+ if start == end {
455+ continue ;
456+ }
457+ let row_indices = & partition_row_indices[ start..end] ;
458+ for & row_idx in row_indices {
459+ let row = row_idx as usize ;
460+ let val_start = offsets_slice[ row] as usize ;
461+ let val_end = offsets_slice[ row + 1 ] as usize ;
462+ let is_valid = nulls. is_none_or ( |n| n. is_valid ( row) ) ;
463+ self . partition_buffers [ p] . append_large_variable (
464+ col_idx,
465+ & values_slice[ val_start..val_end] ,
466+ is_valid,
467+ ) ;
468+ }
514469 }
515470 }
516- } else if is_boolean {
517- let bool_array = column. as_any ( ) . downcast_ref :: < BooleanArray > ( ) . unwrap ( ) ;
518- for p in 0 ..num_partitions {
519- let start = partition_starts[ p] as usize ;
520- let end = partition_starts[ p + 1 ] as usize ;
521- if start == end {
522- continue ;
523- }
524- let row_indices = & partition_row_indices[ start..end] ;
525- for & row_idx in row_indices {
526- let row = row_idx as usize ;
527- let is_valid = nulls. is_none_or ( |n| n. is_valid ( row) ) ;
528- self . partition_buffers [ p] . append_bool (
529- col_idx,
530- bool_array. value ( row) ,
531- is_valid,
532- ) ;
471+ ColumnBuffer :: Boolean { .. } => {
472+ let bool_array = column. as_any ( ) . downcast_ref :: < BooleanArray > ( ) . unwrap ( ) ;
473+ for p in 0 ..num_partitions {
474+ let start = partition_starts[ p] as usize ;
475+ let end = partition_starts[ p + 1 ] as usize ;
476+ if start == end {
477+ continue ;
478+ }
479+ let row_indices = & partition_row_indices[ start..end] ;
480+ for & row_idx in row_indices {
481+ let row = row_idx as usize ;
482+ let is_valid = nulls. is_none_or ( |n| n. is_valid ( row) ) ;
483+ self . partition_buffers [ p] . append_bool (
484+ col_idx,
485+ bool_array. value ( row) ,
486+ is_valid,
487+ ) ;
488+ }
533489 }
534490 }
535- } else {
536- // Fallback
537- for p in 0 ..num_partitions {
538- let start = partition_starts[ p] as usize ;
539- let end = partition_starts [ p + 1 ] as usize ;
540- if start == end {
541- continue ;
542- }
543- let row_indices = & partition_row_indices [ start..end ] ;
544- for & row_idx in row_indices {
545- self . partition_buffers [ p ] . append_fallback_index ( col_idx , row_idx ) ;
491+ ColumnBuffer :: Fallback { .. } => {
492+ for p in 0 ..num_partitions {
493+ let start = partition_starts [ p ] as usize ;
494+ let end = partition_starts[ p + 1 ] as usize ;
495+ if start == end {
496+ continue ;
497+ }
498+ let row_indices = & partition_row_indices [ start..end ] ;
499+ for & row_idx in row_indices {
500+ self . partition_buffers [ p ] . append_fallback_index ( col_idx , row_idx ) ;
501+ }
546502 }
547503 }
548504 }
@@ -552,15 +508,23 @@ impl MultiPartitionShuffleRepartitioner {
552508 . scatter_time
553509 . add_duration ( scatter_start. elapsed ( ) ) ;
554510
555- // Update row counts from partition_starts ( O(num_partitions), not O(num_rows) )
511+ // Update each partition's row count from partition_starts: O(num_partitions) rather than O(num_rows)
556512 for p in 0 ..num_partitions {
557513 let count = ( partition_starts[ p + 1 ] - partition_starts[ p] ) as usize ;
558514 self . partition_buffers [ p] . row_count += count;
559515 }
560516
561- // Auto-flush partitions that reached batch_size
517+ // Flush partitions. When fallback columns exist, flush ALL non-empty
518+ // partitions since fallback indices reference the current input batch.
519+ // Otherwise, only flush partitions that reached batch_size.
520+ let flush_all = self . partition_buffers [ 0 ] . has_fallback_columns ( ) ;
562521 for p in 0 ..num_partitions {
563- if self . partition_buffers [ p] . row_count >= self . batch_size {
522+ let should_flush = if flush_all {
523+ self . partition_buffers [ p] . row_count > 0
524+ } else {
525+ self . partition_buffers [ p] . row_count >= self . batch_size
526+ } ;
527+ if should_flush {
564528 let batch = self . partition_buffers [ p] . flush ( Some ( input) ) ?;
565529 self . partition_writers [ p] . spill (
566530 & [ batch] ,
@@ -572,23 +536,6 @@ impl MultiPartitionShuffleRepartitioner {
572536 }
573537 }
574538
575- // If schema has fallback columns, flush ALL non-empty partitions
576- // since fallback indices reference the current input batch
577- if self . has_fallback_columns {
578- for p in 0 ..num_partitions {
579- if self . partition_buffers [ p] . row_count > 0 {
580- let batch = self . partition_buffers [ p] . flush ( Some ( input) ) ?;
581- self . partition_writers [ p] . spill (
582- & [ batch] ,
583- & self . runtime ,
584- & self . metrics ,
585- self . write_buffer_size ,
586- self . batch_size ,
587- ) ?;
588- }
589- }
590- }
591-
592539 // Precise memory tracking
593540 let mem_after: usize = self . partition_buffers . iter ( ) . map ( |b| b. memory_size ( ) ) . sum ( ) ;
594541 let mem_growth = mem_after. saturating_sub ( mem_before) ;
0 commit comments