Skip to content

Commit 7b61b30

Browse files
authored
chore: Extract some tied down logic (#3374)
1 parent 7f7ad74 commit 7b61b30

3 files changed

Lines changed: 127 additions & 114 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::physical_plan::metrics::{
19+
BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time,
20+
};
21+
22+
pub(super) struct ShufflePartitionerMetrics {
23+
/// metrics
24+
pub(super) baseline: BaselineMetrics,
25+
26+
/// Time to perform repartitioning
27+
pub(super) repart_time: Time,
28+
29+
/// Time encoding batches to IPC format
30+
pub(super) encode_time: Time,
31+
32+
/// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL Metrics.
33+
pub(super) write_time: Time,
34+
35+
/// Number of input batches
36+
pub(super) input_batches: Count,
37+
38+
/// count of spills during the execution of the operator
39+
pub(super) spill_count: Count,
40+
41+
/// total spilled bytes during the execution of the operator
42+
pub(super) spilled_bytes: Count,
43+
44+
/// The original size of spilled data. Different to `spilled_bytes` because of compression.
45+
pub(super) data_size: Count,
46+
}
47+
48+
impl ShufflePartitionerMetrics {
49+
pub(super) fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
50+
Self {
51+
baseline: BaselineMetrics::new(metrics, partition),
52+
repart_time: MetricBuilder::new(metrics).subset_time("repart_time", partition),
53+
encode_time: MetricBuilder::new(metrics).subset_time("encode_time", partition),
54+
write_time: MetricBuilder::new(metrics).subset_time("write_time", partition),
55+
input_batches: MetricBuilder::new(metrics).counter("input_batches", partition),
56+
spill_count: MetricBuilder::new(metrics).spill_count(partition),
57+
spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition),
58+
data_size: MetricBuilder::new(metrics).counter("data_size", partition),
59+
}
60+
}
61+
}

native/core/src/execution/shuffle/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
pub(crate) mod codec;
1919
mod comet_partitioning;
20+
mod metrics;
2021
mod shuffle_writer;
2122
pub mod spark_unsafe;
2223

native/core/src/execution/shuffle/shuffle_writer.rs

Lines changed: 65 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! Defines the External shuffle repartition plan.
1919
20+
use crate::execution::shuffle::metrics::ShufflePartitionerMetrics;
2021
use crate::execution::shuffle::{CometPartitioning, CompressionCodec, ShuffleBlockWriter};
2122
use crate::execution::tracing::{with_trace, with_trace_async};
2223
use arrow::compute::interleave_record_batch;
@@ -35,9 +36,7 @@ use datafusion::{
3536
runtime_env::RuntimeEnv,
3637
},
3738
physical_plan::{
38-
metrics::{
39-
BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, Time,
40-
},
39+
metrics::{ExecutionPlanMetricsSet, MetricsSet, Time},
4140
stream::RecordBatchStreamAdapter,
4241
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
4342
Statistics,
@@ -185,7 +184,7 @@ impl ExecutionPlan for ShuffleWriterExec {
185184
context: Arc<TaskContext>,
186185
) -> Result<SendableRecordBatchStream> {
187186
let input = self.input.execute(partition, Arc::clone(&context))?;
188-
let metrics = ShuffleRepartitionerMetrics::new(&self.metrics, 0);
187+
let metrics = ShufflePartitionerMetrics::new(&self.metrics, 0);
189188

190189
Ok(Box::pin(RecordBatchStreamAdapter::new(
191190
self.schema(),
@@ -216,7 +215,7 @@ async fn external_shuffle(
216215
output_data_file: String,
217216
output_index_file: String,
218217
partitioning: CometPartitioning,
219-
metrics: ShuffleRepartitionerMetrics,
218+
metrics: ShufflePartitionerMetrics,
220219
context: Arc<TaskContext>,
221220
codec: CompressionCodec,
222221
tracing_enabled: bool,
@@ -268,47 +267,6 @@ async fn external_shuffle(
268267
.await
269268
}
270269

271-
struct ShuffleRepartitionerMetrics {
272-
/// metrics
273-
baseline: BaselineMetrics,
274-
275-
/// Time to perform repartitioning
276-
repart_time: Time,
277-
278-
/// Time encoding batches to IPC format
279-
encode_time: Time,
280-
281-
/// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL Metrics.
282-
write_time: Time,
283-
284-
/// Number of input batches
285-
input_batches: Count,
286-
287-
/// count of spills during the execution of the operator
288-
spill_count: Count,
289-
290-
/// total spilled bytes during the execution of the operator
291-
spilled_bytes: Count,
292-
293-
/// The original size of spilled data. Different to `spilled_bytes` because of compression.
294-
data_size: Count,
295-
}
296-
297-
impl ShuffleRepartitionerMetrics {
298-
fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
299-
Self {
300-
baseline: BaselineMetrics::new(metrics, partition),
301-
repart_time: MetricBuilder::new(metrics).subset_time("repart_time", partition),
302-
encode_time: MetricBuilder::new(metrics).subset_time("encode_time", partition),
303-
write_time: MetricBuilder::new(metrics).subset_time("write_time", partition),
304-
input_batches: MetricBuilder::new(metrics).counter("input_batches", partition),
305-
spill_count: MetricBuilder::new(metrics).spill_count(partition),
306-
spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition),
307-
data_size: MetricBuilder::new(metrics).counter("data_size", partition),
308-
}
309-
}
310-
}
311-
312270
#[async_trait::async_trait]
313271
trait ShufflePartitioner: Send + Sync {
314272
/// Insert a batch into the partitioner
@@ -328,7 +286,7 @@ struct MultiPartitionShuffleRepartitioner {
328286
/// Partitioning scheme to use
329287
partitioning: CometPartitioning,
330288
runtime: Arc<RuntimeEnv>,
331-
metrics: ShuffleRepartitionerMetrics,
289+
metrics: ShufflePartitionerMetrics,
332290
/// Reused scratch space for computing partition indices
333291
scratch: ScratchSpace,
334292
/// The configured batch size
@@ -356,6 +314,54 @@ struct ScratchSpace {
356314
partition_starts: Vec<u32>,
357315
}
358316

317+
impl ScratchSpace {
318+
fn map_partition_ids_to_starts_and_indices(
319+
&mut self,
320+
num_output_partitions: usize,
321+
num_rows: usize,
322+
) {
323+
let partition_ids = &mut self.partition_ids[..num_rows];
324+
325+
// count each partition size, while leaving the last extra element as 0
326+
let partition_counters = &mut self.partition_starts;
327+
partition_counters.resize(num_output_partitions + 1, 0);
328+
partition_counters.fill(0);
329+
partition_ids
330+
.iter()
331+
.for_each(|partition_id| partition_counters[*partition_id as usize] += 1);
332+
333+
// accumulate partition counters into partition ends
334+
// e.g. partition counter: [1, 3, 2, 1, 0] => [1, 4, 6, 7, 7]
335+
let partition_ends = partition_counters;
336+
let mut accum = 0;
337+
partition_ends.iter_mut().for_each(|v| {
338+
*v += accum;
339+
accum = *v;
340+
});
341+
342+
// calculate partition row indices and partition starts
343+
// e.g. partition ids: [3, 1, 1, 1, 2, 2, 0] will produce the following partition_row_indices
344+
// and partition_starts arrays:
345+
//
346+
// partition_row_indices: [6, 1, 2, 3, 4, 5, 0]
347+
// partition_starts: [0, 1, 4, 6, 7]
348+
//
349+
// partition_starts conceptually splits partition_row_indices into smaller slices.
350+
// Each slice partition_row_indices[partition_starts[K]..partition_starts[K + 1]] contains the
351+
// row indices of the input batch that are partitioned into partition K. For example,
352+
// first partition 0 has one row index [6], partition 1 has row indices [1, 2, 3], etc.
353+
let partition_row_indices = &mut self.partition_row_indices;
354+
partition_row_indices.resize(num_rows, 0);
355+
for (index, partition_id) in partition_ids.iter().enumerate().rev() {
356+
partition_ends[*partition_id as usize] -= 1;
357+
let end = partition_ends[*partition_id as usize];
358+
partition_row_indices[end as usize] = index as u32;
359+
}
360+
361+
// after calculating, partition ends become partition starts
362+
}
363+
}
364+
359365
impl MultiPartitionShuffleRepartitioner {
360366
#[allow(clippy::too_many_arguments)]
361367
pub fn try_new(
@@ -364,7 +370,7 @@ impl MultiPartitionShuffleRepartitioner {
364370
output_index_file: String,
365371
schema: SchemaRef,
366372
partitioning: CometPartitioning,
367-
metrics: ShuffleRepartitionerMetrics,
373+
metrics: ShufflePartitionerMetrics,
368374
runtime: Arc<RuntimeEnv>,
369375
batch_size: usize,
370376
codec: CompressionCodec,
@@ -432,52 +438,6 @@ impl MultiPartitionShuffleRepartitioner {
432438
return Ok(());
433439
}
434440

435-
fn map_partition_ids_to_starts_and_indices(
436-
scratch: &mut ScratchSpace,
437-
num_output_partitions: usize,
438-
num_rows: usize,
439-
) {
440-
let partition_ids = &mut scratch.partition_ids[..num_rows];
441-
442-
// count each partition size, while leaving the last extra element as 0
443-
let partition_counters = &mut scratch.partition_starts;
444-
partition_counters.resize(num_output_partitions + 1, 0);
445-
partition_counters.fill(0);
446-
partition_ids
447-
.iter()
448-
.for_each(|partition_id| partition_counters[*partition_id as usize] += 1);
449-
450-
// accumulate partition counters into partition ends
451-
// e.g. partition counter: [1, 3, 2, 1, 0] => [1, 4, 6, 7, 7]
452-
let partition_ends = partition_counters;
453-
let mut accum = 0;
454-
partition_ends.iter_mut().for_each(|v| {
455-
*v += accum;
456-
accum = *v;
457-
});
458-
459-
// calculate partition row indices and partition starts
460-
// e.g. partition ids: [3, 1, 1, 1, 2, 2, 0] will produce the following partition_row_indices
461-
// and partition_starts arrays:
462-
//
463-
// partition_row_indices: [6, 1, 2, 3, 4, 5, 0]
464-
// partition_starts: [0, 1, 4, 6, 7]
465-
//
466-
// partition_starts conceptually splits partition_row_indices into smaller slices.
467-
// Each slice partition_row_indices[partition_starts[K]..partition_starts[K + 1]] contains the
468-
// row indices of the input batch that are partitioned into partition K. For example,
469-
// first partition 0 has one row index [6], partition 1 has row indices [1, 2, 3], etc.
470-
let partition_row_indices = &mut scratch.partition_row_indices;
471-
partition_row_indices.resize(num_rows, 0);
472-
for (index, partition_id) in partition_ids.iter().enumerate().rev() {
473-
partition_ends[*partition_id as usize] -= 1;
474-
let end = partition_ends[*partition_id as usize];
475-
partition_row_indices[end as usize] = index as u32;
476-
}
477-
478-
// after calculating, partition ends become partition starts
479-
}
480-
481441
if input.num_rows() > self.batch_size {
482442
return Err(DataFusionError::Internal(
483443
"Input batch size exceeds configured batch size. Call `insert_batch` instead."
@@ -524,11 +484,8 @@ impl MultiPartitionShuffleRepartitioner {
524484

525485
// We now have partition ids for every input row, map that to partition starts
526486
// and partition indices to eventually right these rows to partition buffers.
527-
map_partition_ids_to_starts_and_indices(
528-
&mut scratch,
529-
*num_output_partitions,
530-
num_rows,
531-
);
487+
scratch
488+
.map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows);
532489

533490
timer.stop();
534491
Ok::<(&Vec<u32>, &Vec<u32>), DataFusionError>((
@@ -580,11 +537,8 @@ impl MultiPartitionShuffleRepartitioner {
580537

581538
// We now have partition ids for every input row, map that to partition starts
582539
// and partition indices to eventually right these rows to partition buffers.
583-
map_partition_ids_to_starts_and_indices(
584-
&mut scratch,
585-
*num_output_partitions,
586-
num_rows,
587-
);
540+
scratch
541+
.map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows);
588542

589543
timer.stop();
590544
Ok::<(&Vec<u32>, &Vec<u32>), DataFusionError>((
@@ -642,11 +596,8 @@ impl MultiPartitionShuffleRepartitioner {
642596

643597
// We now have partition ids for every input row, map that to partition starts
644598
// and partition indices to eventually write these rows to partition buffers.
645-
map_partition_ids_to_starts_and_indices(
646-
&mut scratch,
647-
*num_output_partitions,
648-
num_rows,
649-
);
599+
scratch
600+
.map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows);
650601

651602
timer.stop();
652603
Ok::<(&Vec<u32>, &Vec<u32>), DataFusionError>((
@@ -923,7 +874,7 @@ struct SinglePartitionShufflePartitioner {
923874
/// Number of rows in the concatenating batches
924875
num_buffered_rows: usize,
925876
/// Metrics for the repartitioner
926-
metrics: ShuffleRepartitionerMetrics,
877+
metrics: ShufflePartitionerMetrics,
927878
/// The configured batch size
928879
batch_size: usize,
929880
}
@@ -933,7 +884,7 @@ impl SinglePartitionShufflePartitioner {
933884
output_data_path: String,
934885
output_index_path: String,
935886
schema: SchemaRef,
936-
metrics: ShuffleRepartitionerMetrics,
887+
metrics: ShufflePartitionerMetrics,
937888
batch_size: usize,
938889
codec: CompressionCodec,
939890
write_buffer_size: usize,
@@ -1200,7 +1151,7 @@ impl PartitionWriter {
12001151
&mut self,
12011152
iter: &mut PartitionedBatchIterator,
12021153
runtime: &RuntimeEnv,
1203-
metrics: &ShuffleRepartitionerMetrics,
1154+
metrics: &ShufflePartitionerMetrics,
12041155
write_buffer_size: usize,
12051156
) -> Result<usize> {
12061157
if let Some(batch) = iter.next() {
@@ -1393,7 +1344,7 @@ mod test {
13931344
}
13941345

13951346
#[tokio::test]
1396-
async fn shuffle_repartitioner_memory() {
1347+
async fn shuffle_partitioner_memory() {
13971348
let batch = create_batch(900);
13981349
assert_eq!(8316, batch.get_array_memory_size()); // Not stable across Arrow versions
13991350

@@ -1407,7 +1358,7 @@ mod test {
14071358
"/tmp/index.out".to_string(),
14081359
batch.schema(),
14091360
CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
1410-
ShuffleRepartitionerMetrics::new(&metrics_set, 0),
1361+
ShufflePartitionerMetrics::new(&metrics_set, 0),
14111362
runtime_env,
14121363
1024,
14131364
CompressionCodec::Lz4Frame,

0 commit comments

Comments
 (0)