@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.files
1818
1919import org .apache .gluten .backendsapi .BackendsApiManager
2020import org .apache .gluten .backendsapi .velox .VeloxBatchType
21+ import org .apache .gluten .columnarbatch .VeloxColumnarBatches
2122import org .apache .gluten .config .GlutenConfig
2223import org .apache .gluten .execution ._
2324import org .apache .gluten .execution .datasource .GlutenFormatFactory
@@ -46,6 +47,7 @@ import org.apache.spark.sql.execution.datasources.FileFormatWriter._
4647import org .apache .spark .sql .execution .metric .SQLMetric
4748import org .apache .spark .sql .internal .SQLConf
4849import org .apache .spark .sql .types .DataType
50+ import org .apache .spark .sql .vectorized .ColumnarBatch
4951import org .apache .spark .util .{SerializableConfiguration , Utils }
5052
5153import org .apache .hadoop .conf .Configuration
@@ -583,42 +585,84 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
583585 record match {
584586 case carrierRow : BatchCarrierRow =>
585587 carrierRow match {
586- case placeholderRow : PlaceholderRow =>
588+ case _ : PlaceholderRow =>
587589 // Do nothing.
588590 case terminalRow : TerminalRow =>
589- val numRows = terminalRow.batch().numRows()
590- if (numRows > 0 ) {
591- val blockStripes = GlutenFormatFactory .rowSplitter
592- .splitBlockByPartitionAndBucket(terminalRow.batch(), partitionColIndice,
593- isBucketed)
594- val iter = blockStripes.iterator()
595- while (iter.hasNext) {
596- val blockStripe = iter.next()
597- val headingRow = blockStripe.getHeadingRow
598- beforeWrite(headingRow)
599- val currentColumnBatch = blockStripe.getColumnarBatch
600- val numRowsOfCurrentColumnarBatch = currentColumnBatch.numRows()
601- assert(numRowsOfCurrentColumnarBatch > 0 )
602- val currentTerminalRow = terminalRow.withNewBatch(currentColumnBatch)
603- currentWriter.write(currentTerminalRow)
604- statsTrackers.foreach {
605- tracker =>
606- tracker.newRow(currentWriter.path, currentTerminalRow)
607- for (_ <- 0 until numRowsOfCurrentColumnarBatch - 1 ) {
608- tracker.newRow(currentWriter.path, new PlaceholderRow ())
609- }
610- }
611- currentColumnBatch.close()
612- }
613- blockStripes.release()
614- recordsInFile += numRows
615- }
591+ writePartitionedBatch(terminalRow)
616592 }
617593 case _ =>
618594 beforeWrite(record)
619595 writeRecord(record)
620596 }
621597 }
598+
/**
 * Writes one batch-carrying row to the current writer and updates the bookkeeping.
 *
 * The row is handed to `currentWriter`, every stats tracker is notified once for the
 * whole batch (not once per row), and `recordsInFile` grows by `rowCount`.
 *
 * @param terminalRow the carrier row wrapping the batch to write
 * @param rowCount number of rows accounted for by this write; must be positive
 */
private def writeCurrentBatch(terminalRow: TerminalRow, rowCount: Int): Unit = {
  assert(rowCount > 0)
  currentWriter.write(terminalRow)
  // One notification per batch; the path is re-read from the current writer each time.
  for (tracker <- statsTrackers) {
    tracker.newRow(currentWriter.path, terminalRow)
  }
  recordsInFile += rowCount
}
605+
/**
 * Writes `columnBatch` in one or more chunks so that no single write pushes
 * `recordsInFile` past `description.maxRecordsPerFile`.
 *
 * When the limit is not positive the whole batch is written in one call. Otherwise each
 * chunk is capped at `maxRecordsPerFile - recordsInFile`, and
 * `renewCurrentWriterIfTooManyRecords` is invoked first whenever `recordsInFile` has
 * already reached the limit.
 * NOTE(review): presumably that call rolls over to a new file and resets `recordsInFile`;
 * the positive-chunk assertion below relies on it - confirm.
 *
 * @param terminalRow carrier row used as the template for each chunk's row
 * @param columnBatch the batch to write; it is never closed here, only slices of it are
 */
private def writeCurrentBatchWithMaxRecords(
    terminalRow: TerminalRow,
    columnBatch: ColumnarBatch): Unit = {
  val totalRows = columnBatch.numRows()

  // Writes the rows in [start, totalRows), one chunk per step.
  @scala.annotation.tailrec
  def writeFrom(start: Int): Unit = {
    if (start < totalRows) {
      val pending = totalRows - start
      val chunkSize =
        if (description.maxRecordsPerFile <= 0) {
          pending
        } else {
          if (recordsInFile >= description.maxRecordsPerFile) {
            renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId)
          }
          math.min(pending.toLong, description.maxRecordsPerFile - recordsInFile).toInt
        }

      assert(chunkSize > 0)
      // Reuse the input batch when the chunk covers all of it; otherwise take a slice.
      val coversWholeBatch = start == 0 && chunkSize == totalRows
      val chunk =
        if (coversWholeBatch) columnBatch
        else VeloxColumnarBatches.slice(columnBatch, start, chunkSize)
      try {
        writeCurrentBatch(terminalRow.withNewBatch(chunk), chunkSize)
      } finally {
        // Only batches created here are closed; the caller's batch is left alone.
        if (chunk ne columnBatch) {
          chunk.close()
        }
      }
      writeFrom(start + chunkSize)
    }
  }

  writeFrom(0)
}
639+
/**
 * Writes a single stripe produced by the partition/bucket splitter.
 *
 * `beforeWrite` is called with the stripe's heading row before the stripe's batch is
 * fetched; the batch is then written via `writeCurrentBatchWithMaxRecords` and closed
 * afterwards, whether or not the write succeeds.
 *
 * @param terminalRow carrier row of the batch the stripe was split from
 * @param blockStripe the stripe to write; its batch must contain at least one row
 */
private def writePartitionStripe(terminalRow: TerminalRow, blockStripe: BlockStripe): Unit = {
  val headingRow = blockStripe.getHeadingRow
  beforeWrite(headingRow)
  val stripeBatch = blockStripe.getColumnarBatch
  try {
    assert(stripeBatch.numRows() > 0)
    writeCurrentBatchWithMaxRecords(terminalRow, stripeBatch)
  } finally {
    stripeBatch.close()
  }
}
650+
/**
 * Splits the batch carried by `terminalRow` by partition (and bucket, per `isBucketed`)
 * and writes each resulting stripe in iteration order.
 *
 * Batches with no rows are skipped entirely. The split result is released once all
 * stripes have been processed, including when a stripe write throws.
 *
 * @param terminalRow carrier row wrapping the batch to split and write
 */
private def writePartitionedBatch(terminalRow: TerminalRow): Unit = {
  val hasRows = terminalRow.batch().numRows() > 0
  if (hasRows) {
    val stripes = GlutenFormatFactory.rowSplitter.splitBlockByPartitionAndBucket(
      terminalRow.batch(),
      partitionColIndice,
      isBucketed)
    try {
      val stripeIter = stripes.iterator()
      while (stripeIter.hasNext) {
        val stripe = stripeIter.next()
        writePartitionStripe(terminalRow, stripe)
      }
    } finally {
      stripes.release()
    }
  }
}
622666 }
623667}
624668// spotless:on
0 commit comments