Skip to content

Commit 79b83d8

Browse files
authored
chore: Refactor JVM shuffle: Move SpillSorter to top level class and add tests (#3081)
1 parent 33f514a commit 79b83d8

6 files changed

Lines changed: 638 additions & 247 deletions

File tree

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ jobs:
122122
org.apache.comet.exec.CometAsyncShuffleSuite
123123
org.apache.comet.exec.DisableAQECometShuffleSuite
124124
org.apache.comet.exec.DisableAQECometAsyncShuffleSuite
125+
org.apache.spark.shuffle.sort.SpillSorterSuite
125126
- name: "parquet"
126127
value: |
127128
org.apache.comet.parquet.CometParquetWriterSuite

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ jobs:
8585
org.apache.comet.exec.CometAsyncShuffleSuite
8686
org.apache.comet.exec.DisableAQECometShuffleSuite
8787
org.apache.comet.exec.DisableAQECometAsyncShuffleSuite
88+
org.apache.spark.shuffle.sort.SpillSorterSuite
8889
- name: "parquet"
8990
value: |
9091
org.apache.comet.parquet.CometParquetWriterSuite

dev/ensure-jars-have-correct-contents.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ allowed_expr+="|^org/apache/spark/shuffle/$"
8686
allowed_expr+="|^org/apache/spark/shuffle/sort/$"
8787
allowed_expr+="|^org/apache/spark/shuffle/sort/CometShuffleExternalSorter.*$"
8888
allowed_expr+="|^org/apache/spark/shuffle/sort/RowPartition.class$"
89+
allowed_expr+="|^org/apache/spark/shuffle/sort/SpillSorter.*$"
8990
allowed_expr+="|^org/apache/spark/shuffle/comet/.*$"
9091
allowed_expr+="|^org/apache/spark/sql/$"
9192
# allow ExplainPlanGenerator trait since it may not be available in older Spark versions

spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorter.java

Lines changed: 21 additions & 247 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import java.io.IOException;
2424
import java.util.LinkedList;
2525
import java.util.concurrent.*;
26-
import javax.annotation.Nullable;
2726

2827
import scala.Tuple2;
2928

@@ -32,7 +31,6 @@
3231

3332
import org.apache.spark.SparkConf;
3433
import org.apache.spark.TaskContext;
35-
import org.apache.spark.executor.ShuffleWriteMetrics;
3634
import org.apache.spark.memory.SparkOutOfMemoryError;
3735
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
3836
import org.apache.spark.shuffle.comet.CometShuffleChecksumSupport;
@@ -41,17 +39,14 @@
4139
import org.apache.spark.sql.comet.execution.shuffle.CometUnsafeShuffleWriter;
4240
import org.apache.spark.sql.comet.execution.shuffle.ShuffleThreadPool;
4341
import org.apache.spark.sql.comet.execution.shuffle.SpillInfo;
44-
import org.apache.spark.sql.comet.execution.shuffle.SpillWriter;
4542
import org.apache.spark.sql.types.StructType;
4643
import org.apache.spark.storage.BlockManager;
4744
import org.apache.spark.storage.TempShuffleBlockId;
48-
import org.apache.spark.unsafe.Platform;
4945
import org.apache.spark.unsafe.UnsafeAlignedOffset;
5046
import org.apache.spark.unsafe.array.LongArray;
5147
import org.apache.spark.util.Utils;
5248

5349
import org.apache.comet.CometConf$;
54-
import org.apache.comet.Native;
5550

5651
/**
5752
* An external sorter that is specialized for sort-based shuffle.
@@ -169,10 +164,28 @@ public CometShuffleExternalSorter(
169164
this.threadPool = null;
170165
}
171166

172-
this.activeSpillSorter = new SpillSorter();
173-
174167
this.preferDictionaryRatio =
175168
(double) CometConf$.MODULE$.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO().get();
169+
170+
this.activeSpillSorter = createSpillSorter();
171+
}
172+
173+
/** Creates a new SpillSorter with all required dependencies. */
174+
private SpillSorter createSpillSorter() {
175+
return new SpillSorter(
176+
allocator,
177+
initialSize,
178+
schema,
179+
uaoSize,
180+
preferDictionaryRatio,
181+
compressionCodec,
182+
compressionLevel,
183+
checksumAlgorithm,
184+
partitionChecksums,
185+
writeMetrics,
186+
taskContext,
187+
spills,
188+
this::spill);
176189
}
177190

178191
public long[] getChecksums() {
@@ -237,7 +250,7 @@ public void spill() throws IOException {
237250
}
238251
}
239252

240-
activeSpillSorter = new SpillSorter();
253+
activeSpillSorter = createSpillSorter();
241254
} else {
242255
activeSpillSorter.writeSortedFileNative(false, tracingEnabled);
243256
final long spillSize = activeSpillSorter.freeMemory();
@@ -410,243 +423,4 @@ public SpillInfo[] closeAndGetSpills() throws IOException {
410423

411424
return spills.toArray(new SpillInfo[spills.size()]);
412425
}
413-
414-
class SpillSorter extends SpillWriter {
415-
private boolean freed = false;
416-
417-
private SpillInfo spillInfo;
418-
419-
// These variables are reset after spilling:
420-
@Nullable private ShuffleInMemorySorter inMemSorter;
421-
422-
// This external sorter can call native code to sort partition ids and record pointers of rows.
423-
// In order to do that, we need to pass the address of the internal array in the sorter to native.
424-
// But we cannot access it as it is a private member in the Spark sorter. Instead, we allocate
425-
// the array and assign the pointer array in the sorter.
426-
private LongArray sorterArray;
427-
428-
SpillSorter() {
429-
this.spillInfo = null;
430-
431-
this.allocator = CometShuffleExternalSorter.this.allocator;
432-
433-
// Allocate array for in-memory sorter.
434-
// We cannot access the address of the internal array in the sorter, so we need to
435-
// allocate the array manually and expand the pointer array in the sorter.
436-
// We don't want in-memory sorter to allocate memory but the initial size cannot be zero.
437-
try {
438-
this.inMemSorter = new ShuffleInMemorySorter(allocator, 1, true);
439-
} catch (java.lang.IllegalAccessError e) {
440-
throw new java.lang.RuntimeException(
441-
"Error loading in-memory sorter check class path -- see "
442-
+ "https://github.com/apache/arrow-datafusion-comet?tab=readme-ov-file#enable-comet-shuffle",
443-
e);
444-
}
445-
sorterArray = allocator.allocateArray(initialSize);
446-
this.inMemSorter.expandPointerArray(sorterArray);
447-
448-
this.allocatedPages = new LinkedList<>();
449-
450-
this.nativeLib = new Native();
451-
this.dataTypes = serializeSchema(schema);
452-
}
453-
454-
/** Frees allocated memory pages of this writer */
455-
@Override
456-
public long freeMemory() {
457-
// We need to synchronize here because we may get the memory usage by calling
458-
// this method in the task thread.
459-
synchronized (this) {
460-
return super.freeMemory();
461-
}
462-
}
463-
464-
@Override
465-
public long getMemoryUsage() {
466-
// We need to synchronize here because we may free the memory pages in another thread,
467-
// i.e. when spilling, but this method may be called in the task thread.
468-
synchronized (this) {
469-
long totalPageSize = super.getMemoryUsage();
470-
471-
if (freed) {
472-
return totalPageSize;
473-
} else {
474-
return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
475-
}
476-
}
477-
}
478-
479-
@Override
480-
protected void spill(int required) throws IOException {
481-
CometShuffleExternalSorter.this.spill();
482-
}
483-
484-
/** Free the pointer array held by this sorter. */
485-
public void freeArray() {
486-
synchronized (this) {
487-
inMemSorter.free();
488-
freed = true;
489-
}
490-
}
491-
492-
/**
493-
* Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
494-
* records.
495-
*/
496-
public void reset() {
497-
// We allocate pointer array outside the sorter.
498-
// So we can get array address which can be used by native code.
499-
inMemSorter.reset();
500-
sorterArray = allocator.allocateArray(initialSize);
501-
inMemSorter.expandPointerArray(sorterArray);
502-
}
503-
504-
void setSpillInfo(SpillInfo spillInfo) {
505-
this.spillInfo = spillInfo;
506-
}
507-
508-
public int numRecords() {
509-
return this.inMemSorter.numRecords();
510-
}
511-
512-
public void writeSortedFileNative(boolean isLastFile, boolean tracingEnabled)
513-
throws IOException {
514-
// This call performs the actual sort.
515-
long arrayAddr = this.sorterArray.getBaseOffset();
516-
int pos = inMemSorter.numRecords();
517-
nativeLib.sortRowPartitionsNative(arrayAddr, pos, tracingEnabled);
518-
ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
519-
new ShuffleInMemorySorter.ShuffleSorterIterator(pos, this.sorterArray, 0);
520-
521-
// If there are no sorted records, we don't need to create an empty spill file.
522-
if (!sortedRecords.hasNext()) {
523-
return;
524-
}
525-
526-
final ShuffleWriteMetricsReporter writeMetricsToUse;
527-
528-
if (isLastFile) {
529-
// We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
530-
writeMetricsToUse = writeMetrics;
531-
} else {
532-
// We're spilling, so bytes written should be counted towards spill rather than write.
533-
// Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
534-
// them towards shuffle bytes written.
535-
writeMetricsToUse = new ShuffleWriteMetrics();
536-
}
537-
538-
int currentPartition = -1;
539-
540-
final RowPartition rowPartition = new RowPartition(initialSize);
541-
542-
while (sortedRecords.hasNext()) {
543-
sortedRecords.loadNext();
544-
final int partition = sortedRecords.packedRecordPointer.getPartitionId();
545-
assert (partition >= currentPartition);
546-
if (partition != currentPartition) {
547-
// Switch to the new partition
548-
if (currentPartition != -1) {
549-
550-
if (partitionChecksums.length > 0) {
551-
// If checksum is enabled, we need to update the checksum for the current partition.
552-
setChecksum(partitionChecksums[currentPartition]);
553-
setChecksumAlgo(checksumAlgorithm);
554-
}
555-
556-
long written =
557-
doSpilling(
558-
dataTypes,
559-
spillInfo.file,
560-
rowPartition,
561-
writeMetricsToUse,
562-
preferDictionaryRatio,
563-
compressionCodec,
564-
compressionLevel,
565-
tracingEnabled);
566-
spillInfo.partitionLengths[currentPartition] = written;
567-
568-
// Store the checksum for the current partition.
569-
partitionChecksums[currentPartition] = getChecksum();
570-
}
571-
currentPartition = partition;
572-
}
573-
574-
final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
575-
final long recordOffsetInPage = allocator.getOffsetInPage(recordPointer);
576-
// Note that we need to skip over record key (partition id)
577-
// Note that we already use off-heap memory for serialized rows, so recordPage is always
578-
// null.
579-
int recordSizeInBytes = UnsafeAlignedOffset.getSize(null, recordOffsetInPage) - 4;
580-
long recordReadPosition = recordOffsetInPage + uaoSize + 4; // skip over record length too
581-
rowPartition.addRow(recordReadPosition, recordSizeInBytes);
582-
}
583-
584-
if (currentPartition != -1) {
585-
long written =
586-
doSpilling(
587-
dataTypes,
588-
spillInfo.file,
589-
rowPartition,
590-
writeMetricsToUse,
591-
preferDictionaryRatio,
592-
compressionCodec,
593-
compressionLevel,
594-
tracingEnabled);
595-
spillInfo.partitionLengths[currentPartition] = written;
596-
597-
synchronized (spills) {
598-
spills.add(spillInfo);
599-
}
600-
}
601-
602-
if (!isLastFile) { // i.e. this is a spill file
603-
// The current semantics of `shuffleRecordsWritten` seem to be that it's updated when
604-
// records
605-
// are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
606-
// relies on its `recordWritten()` method being called in order to trigger periodic updates
607-
// to
608-
// `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
609-
// counter at a higher-level, then the in-progress metrics for records written and bytes
610-
// written would get out of sync.
611-
//
612-
// When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
613-
// in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
614-
// metrics to the true write metrics here. The reason for performing this copying is so that
615-
// we can avoid reporting spilled bytes as shuffle write bytes.
616-
//
617-
// Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
618-
// Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
619-
// SPARK-3577 tracks the spill time separately.
620-
621-
// This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning
622-
// of this method.
623-
synchronized (writeMetrics) {
624-
writeMetrics.incRecordsWritten(
625-
((ShuffleWriteMetrics) writeMetricsToUse).recordsWritten());
626-
taskContext
627-
.taskMetrics()
628-
.incDiskBytesSpilled(((ShuffleWriteMetrics) writeMetricsToUse).bytesWritten());
629-
}
630-
}
631-
}
632-
633-
public boolean hasSpaceForAnotherRecord() {
634-
return inMemSorter.hasSpaceForAnotherRecord();
635-
}
636-
637-
public void expandPointerArray(LongArray newArray) {
638-
inMemSorter.expandPointerArray(newArray);
639-
this.sorterArray = newArray;
640-
}
641-
642-
public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) {
643-
final Object base = currentPage.getBaseObject();
644-
final long recordAddress = allocator.encodePageNumberAndOffset(currentPage, pageCursor);
645-
UnsafeAlignedOffset.putSize(base, pageCursor, length);
646-
pageCursor += uaoSize;
647-
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
648-
pageCursor += length;
649-
inMemSorter.insertRecord(recordAddress, partitionId);
650-
}
651-
}
652426
}

0 commit comments

Comments
 (0)