Skip to content

Commit 19ab0e7

Browse files
authored
fix: add EmptySchemaShufflePartitioner and test from #3858 (#3893)
1 parent 98adfbd commit 19ab0e7

4 files changed

Lines changed: 337 additions & 1 deletion

File tree

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::metrics::ShufflePartitionerMetrics;
19+
use crate::partitioners::ShufflePartitioner;
20+
use crate::ShuffleBlockWriter;
21+
use arrow::array::RecordBatch;
22+
use arrow::datatypes::SchemaRef;
23+
use datafusion::common::DataFusionError;
24+
use std::fs::OpenOptions;
25+
use std::io::{BufWriter, Seek, Write};
26+
use tokio::time::Instant;
27+
28+
/// A partitioner for zero-column schemas (e.g. queries where ColumnPruning removes all columns).
/// This handles shuffles for operations like COUNT(*) that produce empty-schema record batches
/// but contain a valid row count. Accumulates the total row count and writes a single
/// zero-column IPC batch to partition 0. All other partitions get empty entries in the index file.
pub(crate) struct EmptySchemaShufflePartitioner {
    /// Path of the shuffle data file written by `shuffle_write`.
    output_data_file: String,
    /// Path of the index file holding per-partition byte offsets.
    output_index_file: String,
    /// The (zero-column) schema of the incoming batches.
    schema: SchemaRef,
    /// Encodes the single output batch in the shuffle block format.
    shuffle_block_writer: ShuffleBlockWriter,
    /// Number of shuffle output partitions; determines how many index entries are written.
    num_output_partitions: usize,
    /// Running total of rows accumulated across all inserted batches.
    total_rows: usize,
    /// Metrics updated during insert and write (baseline, input batches, encode/write time).
    metrics: ShufflePartitionerMetrics,
}
41+
42+
impl EmptySchemaShufflePartitioner {
43+
pub(crate) fn try_new(
44+
output_data_file: String,
45+
output_index_file: String,
46+
schema: SchemaRef,
47+
num_output_partitions: usize,
48+
metrics: ShufflePartitionerMetrics,
49+
codec: crate::CompressionCodec,
50+
) -> datafusion::common::Result<Self> {
51+
debug_assert!(
52+
schema.fields().is_empty(),
53+
"EmptySchemaShufflePartitioner requires a zero-column schema"
54+
);
55+
let shuffle_block_writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec)?;
56+
Ok(Self {
57+
output_data_file,
58+
output_index_file,
59+
schema,
60+
shuffle_block_writer,
61+
num_output_partitions,
62+
total_rows: 0,
63+
metrics,
64+
})
65+
}
66+
}
67+
68+
#[async_trait::async_trait]
69+
impl ShufflePartitioner for EmptySchemaShufflePartitioner {
70+
async fn insert_batch(&mut self, batch: RecordBatch) -> datafusion::common::Result<()> {
71+
let start_time = Instant::now();
72+
let num_rows = batch.num_rows();
73+
if num_rows > 0 {
74+
self.total_rows += num_rows;
75+
self.metrics.baseline.record_output(num_rows);
76+
}
77+
self.metrics.input_batches.add(1);
78+
self.metrics
79+
.baseline
80+
.elapsed_compute()
81+
.add_duration(start_time.elapsed());
82+
Ok(())
83+
}
84+
85+
fn shuffle_write(&mut self) -> datafusion::common::Result<()> {
86+
let start_time = Instant::now();
87+
88+
let output_data = OpenOptions::new()
89+
.write(true)
90+
.create(true)
91+
.truncate(true)
92+
.open(&self.output_data_file)
93+
.map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?;
94+
let mut output_data = BufWriter::new(output_data);
95+
96+
// Write a single zero-column batch with the accumulated row count to partition 0
97+
if self.total_rows > 0 {
98+
let batch = RecordBatch::try_new_with_options(
99+
self.schema.clone(),
100+
vec![],
101+
&arrow::array::RecordBatchOptions::new().with_row_count(Some(self.total_rows)),
102+
)?;
103+
self.shuffle_block_writer.write_batch(
104+
&batch,
105+
&mut output_data,
106+
&self.metrics.encode_time,
107+
)?;
108+
}
109+
110+
let mut write_timer = self.metrics.write_time.timer();
111+
output_data.flush()?;
112+
let data_file_length = output_data.stream_position()?;
113+
114+
// Write index file: partition 0 spans [0, data_file_length), all others are empty
115+
let index_file = OpenOptions::new()
116+
.write(true)
117+
.create(true)
118+
.truncate(true)
119+
.open(&self.output_index_file)
120+
.map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?;
121+
let mut index_writer = BufWriter::new(index_file);
122+
index_writer.write_all(&0i64.to_le_bytes())?;
123+
for _ in 0..self.num_output_partitions {
124+
index_writer.write_all(&(data_file_length as i64).to_le_bytes())?;
125+
}
126+
index_writer.flush()?;
127+
write_timer.stop();
128+
129+
self.metrics
130+
.baseline
131+
.elapsed_compute()
132+
.add_duration(start_time.elapsed());
133+
Ok(())
134+
}
135+
}

native/shuffle/src/partitioners/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod empty_schema;
1819
mod multi_partition;
1920
mod partitioned_batch_iterator;
2021
mod single_partition;
2122
mod traits;
2223

24+
pub(crate) use empty_schema::EmptySchemaShufflePartitioner;
2325
pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner;
2426
pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator;
2527
pub(crate) use single_partition::SinglePartitionShufflePartitioner;

native/shuffle/src/shuffle_writer.rs

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
2020
use crate::metrics::ShufflePartitionerMetrics;
2121
use crate::partitioners::{
22-
MultiPartitionShuffleRepartitioner, ShufflePartitioner, SinglePartitionShufflePartitioner,
22+
EmptySchemaShufflePartitioner, MultiPartitionShuffleRepartitioner, ShufflePartitioner,
23+
SinglePartitionShufflePartitioner,
2324
};
2425
use crate::{CometPartitioning, CompressionCodec};
2526
use async_trait::async_trait;
@@ -210,6 +211,17 @@ async fn external_shuffle(
210211
let schema = input.schema();
211212

212213
let mut repartitioner: Box<dyn ShufflePartitioner> = match &partitioning {
214+
_ if schema.fields().is_empty() => {
215+
log::debug!("found empty schema, overriding {partitioning:?} partitioning with EmptySchemaShufflePartitioner");
216+
Box::new(EmptySchemaShufflePartitioner::try_new(
217+
output_data_file,
218+
output_index_file,
219+
Arc::clone(&schema),
220+
partitioning.partition_count(),
221+
metrics,
222+
codec,
223+
)?)
224+
}
213225
any if any.partition_count() == 1 => {
214226
Box::new(SinglePartitionShufflePartitioner::try_new(
215227
output_data_file,
@@ -688,4 +700,160 @@ mod test {
688700
}
689701
total_rows
690702
}
703+
704+
#[test]
#[cfg_attr(miri, ignore)]
fn test_empty_schema_shuffle_writer() {
    use std::fs;

    let num_rows = 1000;
    let num_batches = 5;
    let num_partitions = 10;

    // Build zero-column batches that carry only a row count.
    let schema = Arc::new(Schema::new(Vec::<Field>::new()));
    let batch = RecordBatch::try_new_with_options(
        Arc::clone(&schema),
        vec![],
        &arrow::array::RecordBatchOptions::new().with_row_count(Some(num_rows)),
    )
    .unwrap();
    let partitions = &[vec![batch; num_batches]];

    let dir = tempfile::tempdir().unwrap();
    let data_file = dir.path().join("data.out");
    let index_file = dir.path().join("index.out");

    let exec = ShuffleWriterExec::try_new(
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap(),
        ))),
        CometPartitioning::RoundRobin(num_partitions, 0),
        CompressionCodec::Zstd(1),
        data_file.to_str().unwrap().to_string(),
        index_file.to_str().unwrap().to_string(),
        false,
        1024 * 1024,
    )
    .unwrap();

    let runtime_env = Arc::new(RuntimeEnvBuilder::new().build().unwrap());
    let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), runtime_env);
    let stream = exec.execute(0, ctx.task_ctx()).unwrap();
    Runtime::new().unwrap().block_on(collect(stream)).unwrap();

    // The data file must hold the IPC block carrying the accumulated rows.
    let data = fs::read(&data_file).unwrap();
    assert!(!data.is_empty(), "Data file should contain IPC data");

    let total_rows = read_all_ipc_blocks(&data);
    assert_eq!(
        total_rows,
        num_rows * num_batches,
        "Row count should survive roundtrip"
    );

    // The index file holds num_partitions + 1 little-endian i64 offsets.
    let index_data = fs::read(&index_file).unwrap();
    assert_eq!(index_data.len(), (num_partitions + 1) * 8);

    let offset_at =
        |i: usize| i64::from_le_bytes(index_data[i * 8..(i + 1) * 8].try_into().unwrap());

    // Partition 0 starts at 0 and spans the whole file; every later boundary
    // equals the data length, making partitions 1..n empty.
    assert_eq!(offset_at(0), 0);
    let data_len = data.len() as i64;
    assert_eq!(offset_at(1), data_len);
    for i in 2..=num_partitions {
        assert_eq!(
            offset_at(i),
            data_len,
            "Partition {i} offset should equal data length"
        );
    }
}
793+
794+
#[test]
#[cfg_attr(miri, ignore)]
fn test_empty_schema_shuffle_writer_zero_rows() {
    use std::fs;

    let num_partitions = 4;

    // A zero-column batch with zero rows: nothing should reach the data file.
    let schema = Arc::new(Schema::new(Vec::<Field>::new()));
    let batch = RecordBatch::try_new_with_options(
        Arc::clone(&schema),
        vec![],
        &arrow::array::RecordBatchOptions::new().with_row_count(Some(0)),
    )
    .unwrap();
    let partitions = &[vec![batch]];

    let dir = tempfile::tempdir().unwrap();
    let data_file = dir.path().join("data.out");
    let index_file = dir.path().join("index.out");

    let exec = ShuffleWriterExec::try_new(
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap(),
        ))),
        CometPartitioning::RoundRobin(num_partitions, 0),
        CompressionCodec::Zstd(1),
        data_file.to_str().unwrap().to_string(),
        index_file.to_str().unwrap().to_string(),
        false,
        1024 * 1024,
    )
    .unwrap();

    let runtime_env = Arc::new(RuntimeEnvBuilder::new().build().unwrap());
    let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), runtime_env);
    let stream = exec.execute(0, ctx.task_ctx()).unwrap();
    Runtime::new().unwrap().block_on(collect(stream)).unwrap();

    // With zero rows no batch is written, so the data file stays empty.
    let data = fs::read(&data_file).unwrap();
    assert!(data.is_empty(), "Data file should be empty with zero rows");

    // Every one of the num_partitions + 1 index offsets must be zero.
    let index_data = fs::read(&index_file).unwrap();
    assert_eq!(index_data.len(), (num_partitions + 1) * 8);
    for i in 0..=num_partitions {
        let offset = i64::from_le_bytes(index_data[i * 8..(i + 1) * 8].try_into().unwrap());
        assert_eq!(offset, 0, "All offsets should be 0 with zero rows");
    }
}
691859
}

spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,4 +468,35 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
468468
}
469469
}
470470
}
471+
472+
// Regression test for https://github.com/apache/datafusion-comet/issues/3846
// Shuffling a DataFrame whose columns are all pruned away (e.g. for count())
// previously crashed the native shuffle writer.
test("repartition count") {
  withTempPath { dir =>
    // Produce the parquet fixture with Comet disabled so the input is written
    // by vanilla Spark.
    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
      spark
        .range(1000)
        .selectExpr("id", "concat('name_', id) as name")
        .repartition(100)
        .write
        .parquet(dir.toString)
    }
    withSQLConf(
      CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION,
      CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_ENABLED.key -> "true") {
      val testDF = spark.read.parquet(dir.toString).repartition(10)
      // Verify CometShuffleExchangeExec is in the plan
      assert(
        find(testDF.queryExecution.executedPlan) {
          case _: CometShuffleExchangeExec => true
          case _ => false
        }.isDefined,
        "Expected CometShuffleExchangeExec in the plan")
      // Actual validation, no crash
      val count = testDF.count()
      assert(count == 1000)
      // Ensure test df evaluated by Comet
      checkSparkAnswerAndOperator(testDF)
    }
  }
}
471502
}

0 commit comments

Comments
 (0)