|
23 | 23 | import java.io.IOException; |
24 | 24 | import java.util.LinkedList; |
25 | 25 | import java.util.concurrent.*; |
26 | | -import javax.annotation.Nullable; |
27 | 26 |
|
28 | 27 | import scala.Tuple2; |
29 | 28 |
|
|
32 | 31 |
|
33 | 32 | import org.apache.spark.SparkConf; |
34 | 33 | import org.apache.spark.TaskContext; |
35 | | -import org.apache.spark.executor.ShuffleWriteMetrics; |
36 | 34 | import org.apache.spark.memory.SparkOutOfMemoryError; |
37 | 35 | import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; |
38 | 36 | import org.apache.spark.shuffle.comet.CometShuffleChecksumSupport; |
|
41 | 39 | import org.apache.spark.sql.comet.execution.shuffle.CometUnsafeShuffleWriter; |
42 | 40 | import org.apache.spark.sql.comet.execution.shuffle.ShuffleThreadPool; |
43 | 41 | import org.apache.spark.sql.comet.execution.shuffle.SpillInfo; |
44 | | -import org.apache.spark.sql.comet.execution.shuffle.SpillWriter; |
45 | 42 | import org.apache.spark.sql.types.StructType; |
46 | 43 | import org.apache.spark.storage.BlockManager; |
47 | 44 | import org.apache.spark.storage.TempShuffleBlockId; |
48 | | -import org.apache.spark.unsafe.Platform; |
49 | 45 | import org.apache.spark.unsafe.UnsafeAlignedOffset; |
50 | 46 | import org.apache.spark.unsafe.array.LongArray; |
51 | 47 | import org.apache.spark.util.Utils; |
52 | 48 |
|
53 | 49 | import org.apache.comet.CometConf$; |
54 | | -import org.apache.comet.Native; |
55 | 50 |
|
56 | 51 | /** |
57 | 52 | * An external sorter that is specialized for sort-based shuffle. |
@@ -169,10 +164,28 @@ public CometShuffleExternalSorter( |
169 | 164 | this.threadPool = null; |
170 | 165 | } |
171 | 166 |
|
172 | | - this.activeSpillSorter = new SpillSorter(); |
173 | | - |
174 | 167 | this.preferDictionaryRatio = |
175 | 168 | (double) CometConf$.MODULE$.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO().get(); |
| 169 | + |
| 170 | + this.activeSpillSorter = createSpillSorter(); |
| 171 | + } |
| 172 | + |
| 173 | + /** Creates a new SpillSorter from this sorter's current fields, passing this::spill as the last argument (presumably its spill callback). Reads preferDictionaryRatio, so invoke only after that field is set. */ |
| 174 | + private SpillSorter createSpillSorter() { |
| 175 | + return new SpillSorter( |
| 176 | + allocator, |
| 177 | + initialSize, |
| 178 | + schema, |
| 179 | + uaoSize, |
| 180 | + preferDictionaryRatio, |
| 181 | + compressionCodec, |
| 182 | + compressionLevel, |
| 183 | + checksumAlgorithm, |
| 184 | + partitionChecksums, |
| 185 | + writeMetrics, |
| 186 | + taskContext, |
| 187 | + spills, |
| 188 | + this::spill); |
176 | 189 | } |
177 | 190 |
|
178 | 191 | public long[] getChecksums() { |
@@ -237,7 +250,7 @@ public void spill() throws IOException { |
237 | 250 | } |
238 | 251 | } |
239 | 252 |
|
240 | | - activeSpillSorter = new SpillSorter(); |
| 253 | + activeSpillSorter = createSpillSorter(); |
241 | 254 | } else { |
242 | 255 | activeSpillSorter.writeSortedFileNative(false, tracingEnabled); |
243 | 256 | final long spillSize = activeSpillSorter.freeMemory(); |
@@ -410,243 +423,4 @@ public SpillInfo[] closeAndGetSpills() throws IOException { |
410 | 423 |
|
411 | 424 | return spills.toArray(new SpillInfo[spills.size()]); |
412 | 425 | } |
413 | | - |
414 | | - class SpillSorter extends SpillWriter { |
415 | | - private boolean freed = false; |
416 | | - |
417 | | - private SpillInfo spillInfo; |
418 | | - |
419 | | - // These variables are reset after spilling: |
420 | | - @Nullable private ShuffleInMemorySorter inMemSorter; |
421 | | - |
422 | | - // This external sorter can call native code to sort partition ids and record pointers of rows. |
423 | | - // In order to do that, we need pass the address of the internal array in the sorter to native. |
424 | | - // But we cannot access it as it is private member in the Spark sorter. Instead, we allocate |
425 | | - // the array and assign the pointer array in the sorter. |
426 | | - private LongArray sorterArray; |
427 | | - |
428 | | - SpillSorter() { |
429 | | - this.spillInfo = null; |
430 | | - |
431 | | - this.allocator = CometShuffleExternalSorter.this.allocator; |
432 | | - |
433 | | - // Allocate array for in-memory sorter. |
434 | | - // As we cannot access the address of the internal array in the sorter, so we need to |
435 | | - // allocate the array manually and expand the pointer array in the sorter. |
436 | | - // We don't want in-memory sorter to allocate memory but the initial size cannot be zero. |
437 | | - try { |
438 | | - this.inMemSorter = new ShuffleInMemorySorter(allocator, 1, true); |
439 | | - } catch (java.lang.IllegalAccessError e) { |
440 | | - throw new java.lang.RuntimeException( |
441 | | - "Error loading in-memory sorter check class path -- see " |
442 | | - + "https://github.com/apache/arrow-datafusion-comet?tab=readme-ov-file#enable-comet-shuffle", |
443 | | - e); |
444 | | - } |
445 | | - sorterArray = allocator.allocateArray(initialSize); |
446 | | - this.inMemSorter.expandPointerArray(sorterArray); |
447 | | - |
448 | | - this.allocatedPages = new LinkedList<>(); |
449 | | - |
450 | | - this.nativeLib = new Native(); |
451 | | - this.dataTypes = serializeSchema(schema); |
452 | | - } |
453 | | - |
454 | | - /** Frees allocated memory pages of this writer */ |
455 | | - @Override |
456 | | - public long freeMemory() { |
457 | | - // We need to synchronize here because we may get the memory usage by calling |
458 | | - // this method in the task thread. |
459 | | - synchronized (this) { |
460 | | - return super.freeMemory(); |
461 | | - } |
462 | | - } |
463 | | - |
464 | | - @Override |
465 | | - public long getMemoryUsage() { |
466 | | - // We need to synchronize here because we may free the memory pages in another thread, |
467 | | - // i.e. when spilling, but this method may be called in the task thread. |
468 | | - synchronized (this) { |
469 | | - long totalPageSize = super.getMemoryUsage(); |
470 | | - |
471 | | - if (freed) { |
472 | | - return totalPageSize; |
473 | | - } else { |
474 | | - return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; |
475 | | - } |
476 | | - } |
477 | | - } |
478 | | - |
479 | | - @Override |
480 | | - protected void spill(int required) throws IOException { |
481 | | - CometShuffleExternalSorter.this.spill(); |
482 | | - } |
483 | | - |
484 | | - /** Free the pointer array held by this sorter. */ |
485 | | - public void freeArray() { |
486 | | - synchronized (this) { |
487 | | - inMemSorter.free(); |
488 | | - freed = true; |
489 | | - } |
490 | | - } |
491 | | - |
492 | | - /** |
493 | | - * Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the |
494 | | - * records. |
495 | | - */ |
496 | | - public void reset() { |
497 | | - // We allocate pointer array outside the sorter. |
498 | | - // So we can get array address which can be used by native code. |
499 | | - inMemSorter.reset(); |
500 | | - sorterArray = allocator.allocateArray(initialSize); |
501 | | - inMemSorter.expandPointerArray(sorterArray); |
502 | | - } |
503 | | - |
504 | | - void setSpillInfo(SpillInfo spillInfo) { |
505 | | - this.spillInfo = spillInfo; |
506 | | - } |
507 | | - |
508 | | - public int numRecords() { |
509 | | - return this.inMemSorter.numRecords(); |
510 | | - } |
511 | | - |
512 | | - public void writeSortedFileNative(boolean isLastFile, boolean tracingEnabled) |
513 | | - throws IOException { |
514 | | - // This call performs the actual sort. |
515 | | - long arrayAddr = this.sorterArray.getBaseOffset(); |
516 | | - int pos = inMemSorter.numRecords(); |
517 | | - nativeLib.sortRowPartitionsNative(arrayAddr, pos, tracingEnabled); |
518 | | - ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = |
519 | | - new ShuffleInMemorySorter.ShuffleSorterIterator(pos, this.sorterArray, 0); |
520 | | - |
521 | | - // If there are no sorted records, so we don't need to create an empty spill file. |
522 | | - if (!sortedRecords.hasNext()) { |
523 | | - return; |
524 | | - } |
525 | | - |
526 | | - final ShuffleWriteMetricsReporter writeMetricsToUse; |
527 | | - |
528 | | - if (isLastFile) { |
529 | | - // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. |
530 | | - writeMetricsToUse = writeMetrics; |
531 | | - } else { |
532 | | - // We're spilling, so bytes written should be counted towards spill rather than write. |
533 | | - // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count |
534 | | - // them towards shuffle bytes written. |
535 | | - writeMetricsToUse = new ShuffleWriteMetrics(); |
536 | | - } |
537 | | - |
538 | | - int currentPartition = -1; |
539 | | - |
540 | | - final RowPartition rowPartition = new RowPartition(initialSize); |
541 | | - |
542 | | - while (sortedRecords.hasNext()) { |
543 | | - sortedRecords.loadNext(); |
544 | | - final int partition = sortedRecords.packedRecordPointer.getPartitionId(); |
545 | | - assert (partition >= currentPartition); |
546 | | - if (partition != currentPartition) { |
547 | | - // Switch to the new partition |
548 | | - if (currentPartition != -1) { |
549 | | - |
550 | | - if (partitionChecksums.length > 0) { |
551 | | - // If checksum is enabled, we need to update the checksum for the current partition. |
552 | | - setChecksum(partitionChecksums[currentPartition]); |
553 | | - setChecksumAlgo(checksumAlgorithm); |
554 | | - } |
555 | | - |
556 | | - long written = |
557 | | - doSpilling( |
558 | | - dataTypes, |
559 | | - spillInfo.file, |
560 | | - rowPartition, |
561 | | - writeMetricsToUse, |
562 | | - preferDictionaryRatio, |
563 | | - compressionCodec, |
564 | | - compressionLevel, |
565 | | - tracingEnabled); |
566 | | - spillInfo.partitionLengths[currentPartition] = written; |
567 | | - |
568 | | - // Store the checksum for the current partition. |
569 | | - partitionChecksums[currentPartition] = getChecksum(); |
570 | | - } |
571 | | - currentPartition = partition; |
572 | | - } |
573 | | - |
574 | | - final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); |
575 | | - final long recordOffsetInPage = allocator.getOffsetInPage(recordPointer); |
576 | | - // Note that we need to skip over record key (partition id) |
577 | | - // Note that we already use off-heap memory for serialized rows, so recordPage is always |
578 | | - // null. |
579 | | - int recordSizeInBytes = UnsafeAlignedOffset.getSize(null, recordOffsetInPage) - 4; |
580 | | - long recordReadPosition = recordOffsetInPage + uaoSize + 4; // skip over record length too |
581 | | - rowPartition.addRow(recordReadPosition, recordSizeInBytes); |
582 | | - } |
583 | | - |
584 | | - if (currentPartition != -1) { |
585 | | - long written = |
586 | | - doSpilling( |
587 | | - dataTypes, |
588 | | - spillInfo.file, |
589 | | - rowPartition, |
590 | | - writeMetricsToUse, |
591 | | - preferDictionaryRatio, |
592 | | - compressionCodec, |
593 | | - compressionLevel, |
594 | | - tracingEnabled); |
595 | | - spillInfo.partitionLengths[currentPartition] = written; |
596 | | - |
597 | | - synchronized (spills) { |
598 | | - spills.add(spillInfo); |
599 | | - } |
600 | | - } |
601 | | - |
602 | | - if (!isLastFile) { // i.e. this is a spill file |
603 | | - // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when |
604 | | - // records |
605 | | - // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter |
606 | | - // relies on its `recordWritten()` method being called in order to trigger periodic updates |
607 | | - // to |
608 | | - // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that |
609 | | - // counter at a higher-level, then the in-progress metrics for records written and bytes |
610 | | - // written would get out of sync. |
611 | | - // |
612 | | - // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter; |
613 | | - // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those |
614 | | - // metrics to the true write metrics here. The reason for performing this copying is so that |
615 | | - // we can avoid reporting spilled bytes as shuffle write bytes. |
616 | | - // |
617 | | - // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. |
618 | | - // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. |
619 | | - // SPARK-3577 tracks the spill time separately. |
620 | | - |
621 | | - // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning |
622 | | - // of this method. |
623 | | - synchronized (writeMetrics) { |
624 | | - writeMetrics.incRecordsWritten( |
625 | | - ((ShuffleWriteMetrics) writeMetricsToUse).recordsWritten()); |
626 | | - taskContext |
627 | | - .taskMetrics() |
628 | | - .incDiskBytesSpilled(((ShuffleWriteMetrics) writeMetricsToUse).bytesWritten()); |
629 | | - } |
630 | | - } |
631 | | - } |
632 | | - |
633 | | - public boolean hasSpaceForAnotherRecord() { |
634 | | - return inMemSorter.hasSpaceForAnotherRecord(); |
635 | | - } |
636 | | - |
637 | | - public void expandPointerArray(LongArray newArray) { |
638 | | - inMemSorter.expandPointerArray(newArray); |
639 | | - this.sorterArray = newArray; |
640 | | - } |
641 | | - |
642 | | - public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) { |
643 | | - final Object base = currentPage.getBaseObject(); |
644 | | - final long recordAddress = allocator.encodePageNumberAndOffset(currentPage, pageCursor); |
645 | | - UnsafeAlignedOffset.putSize(base, pageCursor, length); |
646 | | - pageCursor += uaoSize; |
647 | | - Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); |
648 | | - pageCursor += length; |
649 | | - inMemSorter.insertRecord(recordAddress, partitionId); |
650 | | - } |
651 | | - } |
652 | 426 | } |
0 commit comments