@@ -382,8 +382,11 @@ impl MultiPartitionShuffleRepartitioner {
382382 // The initial values are not used.
383383 let scratch = ScratchSpace {
384384 hashes_buf : match partitioning {
385- // Only allocate the hashes_buf if hash partitioning.
386- CometPartitioning :: Hash ( _, _) => vec ! [ 0 ; batch_size] ,
385+ // Allocate hashes_buf for hash and round robin partitioning.
386+ // Round robin hashes all columns to achieve even, deterministic distribution.
387+ CometPartitioning :: Hash ( _, _) | CometPartitioning :: RoundRobin ( _, _) => {
388+ vec ! [ 0 ; batch_size]
389+ }
387390 _ => vec ! [ ] ,
388391 } ,
389392 partition_ids : vec ! [ 0 ; batch_size] ,
@@ -598,6 +601,68 @@ impl MultiPartitionShuffleRepartitioner {
598601 . await ?;
599602 self . scratch = scratch;
600603 }
604+ CometPartitioning :: RoundRobin ( num_output_partitions, max_hash_columns) => {
605+ // Comet implements "round robin" as hash partitioning on columns.
606+ // This achieves the same goal as Spark's round robin (even distribution
607+ // without semantic grouping) while being deterministic for fault tolerance.
608+ //
609+ // Note: This produces different partition assignments than Spark's round robin,
610+ // which sorts by UnsafeRow binary representation before assigning partitions.
611+ // However, both approaches provide even distribution and determinism.
612+ let mut scratch = std:: mem:: take ( & mut self . scratch ) ;
613+ let ( partition_starts, partition_row_indices) : ( & Vec < u32 > , & Vec < u32 > ) = {
614+ let mut timer = self . metrics . repart_time . timer ( ) ;
615+
616+ let num_rows = input. num_rows ( ) ;
617+
618+ // Collect columns for hashing, respecting max_hash_columns limit
619+ // max_hash_columns of 0 means no limit (hash all columns)
620+ // Negative values are normalized to 0 in the planner
621+ let num_columns_to_hash = if * max_hash_columns == 0 {
622+ input. num_columns ( )
623+ } else {
624+ ( * max_hash_columns) . min ( input. num_columns ( ) )
625+ } ;
626+ let columns_to_hash: Vec < ArrayRef > = ( 0 ..num_columns_to_hash)
627+ . map ( |i| Arc :: clone ( input. column ( i) ) )
628+ . collect ( ) ;
629+
630+ // Use identical seed as Spark hash partitioning.
631+ let hashes_buf = & mut scratch. hashes_buf [ ..num_rows] ;
632+ hashes_buf. fill ( 42_u32 ) ;
633+
634+ // Compute hash for selected columns
635+ create_murmur3_hashes ( & columns_to_hash, hashes_buf) ?;
636+
637+ // Assign partition IDs based on hash (same as hash partitioning)
638+ let partition_ids = & mut scratch. partition_ids [ ..num_rows] ;
639+ hashes_buf. iter ( ) . enumerate ( ) . for_each ( |( idx, hash) | {
640+ partition_ids[ idx] = pmod ( * hash, * num_output_partitions) as u32 ;
641+ } ) ;
642+
643+ // We now have partition ids for every input row, map that to partition starts
644+ // and partition indices to eventually write these rows to partition buffers.
645+ map_partition_ids_to_starts_and_indices (
646+ & mut scratch,
647+ * num_output_partitions,
648+ num_rows,
649+ ) ;
650+
651+ timer. stop ( ) ;
652+ Ok :: < ( & Vec < u32 > , & Vec < u32 > ) , DataFusionError > ( (
653+ & scratch. partition_starts ,
654+ & scratch. partition_row_indices ,
655+ ) )
656+ } ?;
657+
658+ self . buffer_partitioned_batch_may_spill (
659+ input,
660+ partition_row_indices,
661+ partition_starts,
662+ )
663+ . await ?;
664+ self . scratch = scratch;
665+ }
601666 other => {
602667 // this should be unreachable as long as the validation logic
603668 // in the constructor is kept up-to-date
@@ -1431,6 +1496,7 @@ mod test {
14311496 Arc :: new ( row_converter) ,
14321497 owned_rows,
14331498 ) ,
1499+ CometPartitioning :: RoundRobin ( num_partitions, 0 ) ,
14341500 ] {
14351501 let batches = ( 0 ..num_batches) . map ( |_| batch. clone ( ) ) . collect :: < Vec < _ > > ( ) ;
14361502
@@ -1483,4 +1549,95 @@ mod test {
14831549 let expected = vec ! [ 69 , 5 , 193 , 171 , 115 ] ;
14841550 assert_eq ! ( result, expected) ;
14851551 }
1552+
/// Verifies that round robin partitioning is deterministic.
///
/// The same input batches are shuffled twice with
/// `CometPartitioning::RoundRobin(num_partitions, 0)` and the resulting data
/// and index files must be byte-for-byte identical across the two runs.
#[test]
#[cfg_attr(miri, ignore)]
fn test_round_robin_deterministic() {
    use std::fs;

    let batch_size = 1000;
    let num_batches = 10;
    let num_partitions = 8;

    let batch = create_batch(batch_size);
    let batches = (0..num_batches).map(|_| batch.clone()).collect::<Vec<_>>();

    // Write into the platform temp directory and include the process id in the
    // file names so that concurrent test processes cannot clobber each other
    // and a stale file from an earlier run is never picked up by mistake.
    let tmp_dir = std::env::temp_dir();
    let pid = std::process::id();

    // (data bytes, index bytes) produced by each run.
    let mut outputs: Vec<(Vec<u8>, Vec<u8>)> = Vec::with_capacity(2);

    for run in 0..2 {
        let data_file = tmp_dir
            .join(format!("rr_data_{}_{}.out", pid, run))
            .to_str()
            .expect("temp dir path should be valid UTF-8")
            .to_owned();
        let index_file = tmp_dir
            .join(format!("rr_index_{}_{}.out", pid, run))
            .to_str()
            .expect("temp dir path should be valid UTF-8")
            .to_owned();

        let partitions = std::slice::from_ref(&batches);
        let exec = ShuffleWriterExec::try_new(
            Arc::new(DataSourceExec::new(Arc::new(
                MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(),
            ))),
            CometPartitioning::RoundRobin(num_partitions, 0),
            CompressionCodec::Zstd(1),
            data_file.clone(),
            index_file.clone(),
            false,
            1024 * 1024,
        )
        .unwrap();

        let config = SessionConfig::new();
        let runtime_env = Arc::new(
            RuntimeEnvBuilder::new()
                .with_memory_limit(10 * 1024 * 1024, 1.0)
                .build()
                .unwrap(),
        );
        let session_ctx = Arc::new(SessionContext::new_with_config_rt(config, runtime_env));
        let task_ctx = Arc::new(TaskContext::from(session_ctx.as_ref()));

        // Execute the shuffle. Every stream item is checked so that a failure
        // inside the writer fails the test instead of being silently dropped.
        futures::executor::block_on(async {
            let mut stream = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
            while let Some(result) = stream.next().await {
                result.expect("shuffle writer stream returned an error");
            }
        });

        let data = fs::read(&data_file).unwrap();
        let index = fs::read(&index_file).unwrap();

        // Clean up before asserting so a failed comparison does not leak files.
        let _ = fs::remove_file(&data_file);
        let _ = fs::remove_file(&index_file);

        outputs.push((data, index));
    }

    assert_eq!(
        outputs[0].0, outputs[1].0,
        "Round robin shuffle data should be identical across runs"
    );
    assert_eq!(
        outputs[0].1, outputs[1].1,
        "Round robin shuffle index should be identical across runs"
    );
}
14861643}
0 commit comments