@@ -26,7 +26,7 @@ import java.nio.channels.Channels
2626import scala .jdk .CollectionConverters ._
2727
2828import org .apache .arrow .c .CDataDictionaryProvider
29- import org .apache .arrow .vector .{ BigIntVector , BitVector , DateDayVector , DecimalVector , FieldVector , FixedSizeBinaryVector , Float4Vector , Float8Vector , IntVector , NullVector , SmallIntVector , TimeStampMicroTZVector , TimeStampMicroVector , TinyIntVector , ValueVector , VarBinaryVector , VarCharVector , VectorSchemaRoot }
29+ import org .apache .arrow .vector ._
3030import org .apache .arrow .vector .complex .{ListVector , MapVector , StructVector }
3131import org .apache .arrow .vector .dictionary .DictionaryProvider
3232import org .apache .arrow .vector .ipc .{ArrowStreamReader , ArrowStreamWriter }
@@ -234,6 +234,7 @@ object Utils extends CometTypeShim with Logging {
234234
235235 /**
236236 * Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
237+ *
237238 * @param bytes
238239 * the serialized batches
239240 * @param source
@@ -264,10 +265,11 @@ object Utils extends CometTypeShim with Logging {
264265 * re-serialize once via ArrowStreamWriter. This is done on the driver (not per-task) so the
265266 * cost is paid once rather than once per consumer partition.
266267 */
267- def coalesceBroadcastBatches (input : Iterator [ChunkedByteBuffer ]): Array [ChunkedByteBuffer ] = {
268+ def coalesceBroadcastBatches (
269+ input : Iterator [ChunkedByteBuffer ]): (Array [ChunkedByteBuffer ], Long , Long ) = {
268270 val buffers = input.filterNot(_.size == 0 ).toArray
269271 if (buffers.isEmpty) {
270- return Array .empty
272+ return ( Array .empty, 0L , 0L )
271273 }
272274
273275 val allocator = org.apache.comet.CometArrowAllocator
@@ -308,7 +310,7 @@ object Utils extends CometTypeShim with Logging {
308310 }
309311
310312 if (targetRoot == null ) {
311- return Array .empty
313+ return ( Array .empty, 0L , 0L )
312314 }
313315
314316 assert(
@@ -320,7 +322,7 @@ object Utils extends CometTypeShim with Logging {
320322 val outCodec = CompressionCodec .createCodec(SparkEnv .get.conf)
321323 val cbbos = new ChunkedByteBufferOutputStream (1024 * 1024 , ByteBuffer .allocate)
322324 val out = new DataOutputStream (outCodec.compressedOutputStream(cbbos))
323- // null provider is safe here — we assert no dictionary-encoded columns above
325+ // null provider is safe here because we assert no dictionary-encoded columns above
324326 val writer = new ArrowStreamWriter (targetRoot, null , Channels .newChannel(out))
325327 try {
326328 writer.start()
@@ -329,7 +331,7 @@ object Utils extends CometTypeShim with Logging {
329331 writer.close()
330332 }
331333
332- Array (cbbos.toChunkedByteBuffer)
334+ ( Array (cbbos.toChunkedByteBuffer), batchCount.toLong, totalRows )
333335 } finally {
334336 if (targetRoot != null ) {
335337 targetRoot.close()
0 commit comments