diff --git a/lance-spark-3.4_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java b/lance-spark-3.4_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java index b81d3496..64ba8b99 100644 --- a/lance-spark-3.4_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java +++ b/lance-spark-3.4_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java @@ -212,21 +212,16 @@ private static class PositionDeltaWriteFactory implements DeltaWriterFactory { @Override public DeltaWriter createWriter(int partitionId, long taskId) { int batchSize = writeOptions.getBatchSize(); - boolean useQueuedBuffer = writeOptions.isUseQueuedWriteBuffer(); + int poolSize = writeOptions.getQueueDepth(); boolean useLargeVarTypes = writeOptions.isUseLargeVarTypes(); + long maxBatchBytes = writeOptions.getMaxBatchBytes(); // Merge initial storage options with write options WriteParams params = writeOptions.toWriteParams(initialStorageOptions); - // Select buffer type based on configuration - ArrowBatchWriteBuffer writeBuffer; - if (useQueuedBuffer) { - int queueDepth = writeOptions.getQueueDepth(); - writeBuffer = - new QueuedArrowBatchWriteBuffer(sparkSchema, batchSize, queueDepth, useLargeVarTypes); - } else { - writeBuffer = new SemaphoreArrowBatchWriteBuffer(sparkSchema, batchSize, useLargeVarTypes); - } + ArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + sparkSchema, batchSize, poolSize, useLargeVarTypes, maxBatchBytes); // Create fragment in background thread Callable> fragmentCreator = diff --git a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java index 9eae0d99..20d43a0e 100644 --- a/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java +++ b/lance-spark-3.5_2.12/src/main/java/org/lance/spark/write/SparkPositionDeltaWrite.java @@ -212,21 +212,16 @@ private static class PositionDeltaWriteFactory implements DeltaWriterFactory { @Override public DeltaWriter createWriter(int partitionId, long taskId) { int batchSize = writeOptions.getBatchSize(); - boolean useQueuedBuffer = writeOptions.isUseQueuedWriteBuffer(); + int poolSize = writeOptions.getQueueDepth(); boolean useLargeVarTypes = writeOptions.isUseLargeVarTypes(); + long maxBatchBytes = writeOptions.getMaxBatchBytes(); // Merge initial storage options with write options WriteParams params = writeOptions.toWriteParams(initialStorageOptions); - // Select buffer type based on configuration - ArrowBatchWriteBuffer writeBuffer; - if (useQueuedBuffer) { - int queueDepth = writeOptions.getQueueDepth(); - writeBuffer = - new QueuedArrowBatchWriteBuffer(sparkSchema, batchSize, queueDepth, useLargeVarTypes); - } else { - writeBuffer = new SemaphoreArrowBatchWriteBuffer(sparkSchema, batchSize, useLargeVarTypes); - } + ArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + sparkSchema, batchSize, poolSize, useLargeVarTypes, maxBatchBytes); // Create fragment in background thread Callable> fragmentCreator = diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/ArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/ArrowBatchWriteBuffer.java index 607e7b82..2d7cde28 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/ArrowBatchWriteBuffer.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/ArrowBatchWriteBuffer.java @@ -27,9 +27,8 @@ * Abstract 
base class for Arrow batch write buffers that bridge Spark row writing and Lance * fragment creation. * - *

- * <p>Both {@link SemaphoreArrowBatchWriteBuffer} (lock-based) and {@link
- * QueuedArrowBatchWriteBuffer} (queue-based) extend this class, allowing the write path to be
- * configured at runtime.
+ *
+ * <p>
{@link PooledArrowBatchWriteBuffer} extends this class, pre-allocating a pool of + * VectorSchemaRoots that are reused across batches to minimize allocation overhead. */ public abstract class ArrowBatchWriteBuffer extends ArrowReader { diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java index 498ece2c..de12f565 100644 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/LanceDataWriter.java @@ -131,24 +131,16 @@ protected WriterFactory( @Override public DataWriter createWriter(int partitionId, long taskId) { int batchSize = writeOptions.getBatchSize(); - boolean useQueuedBuffer = writeOptions.isUseQueuedWriteBuffer(); + int poolSize = writeOptions.getQueueDepth(); boolean useLargeVarTypes = writeOptions.isUseLargeVarTypes(); long maxBatchBytes = writeOptions.getMaxBatchBytes(); // Merge initial storage options with write options WriteParams params = writeOptions.toWriteParams(initialStorageOptions); - // Select buffer type based on configuration - ArrowBatchWriteBuffer writeBuffer; - if (useQueuedBuffer) { - int queueDepth = writeOptions.getQueueDepth(); - writeBuffer = - new QueuedArrowBatchWriteBuffer( - schema, batchSize, queueDepth, useLargeVarTypes, maxBatchBytes); - } else { - writeBuffer = - new SemaphoreArrowBatchWriteBuffer(schema, batchSize, useLargeVarTypes, maxBatchBytes); - } + ArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + schema, batchSize, poolSize, useLargeVarTypes, maxBatchBytes); // Create fragment in background thread Callable> fragmentCreator = diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/PooledArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/PooledArrowBatchWriteBuffer.java new file mode 100644 index 00000000..a495fb59 --- /dev/null +++ b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/PooledArrowBatchWriteBuffer.java @@ -0,0 +1,379 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.lance.spark.write; + +import org.lance.spark.LanceRuntime; +import org.lance.spark.LanceSparkWriteOptions; + +import com.google.common.base.Preconditions; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseLargeVariableWidthVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.LanceArrowUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; + +/** + * Pool-based buffer for Arrow batches that pre-allocates a fixed set of VectorSchemaRoots and + * cycles through them. No per-batch allocation after initialization — vectors are reused via {@code + * reset()}, which preserves underlying buffer capacity. + * + *

+ * <p>This replaces both the semaphore-based and queue-based approaches:
+ *
+ * <ul>
+ *   <li>Unlike the semaphore approach, there is no lock contention on every row write and no
+ *       per-batch vector reallocation.
+ *   <li>Unlike the queue approach, there is no per-batch child allocator overhead.
+ * </ul>
+ *
+ * <p>Batches are flushed when either the row count reaches {@code batchSize} or the estimated
+ * memory for the current batch exceeds {@code maxBatchBytes}, whichever comes first. Byte tracking
+ * uses row-level measurement: fixed-width vector costs are precomputed from the schema, and
+ * variable-width vector growth is recomputed on each write via {@code getBufferSizeFor(rowCount)}
+ * (the data buffer's {@code readableBytes()} is not advanced until {@code setValueCount}, so it
+ * cannot be used for incremental tracking).
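+ *
+ * <p>As a rough sketch (using this class's actual field names, and ignoring the
+ * {@code Long.MAX_VALUE} fast path), the producer-side flush check amounts to:
+ *
+ * <pre>{@code
+ * long estimated = fixedBytesPerRow * producerRowCount + currentVarBytes;
+ * boolean flush = producerRowCount >= batchSize
+ *     || (producerRowCount > 0 && estimated >= maxBatchBytes);
+ * }</pre>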

+ * <p>Architecture:
+ *
+ * <pre>

+ * Producer (Spark thread):              Consumer (Fragment creation thread):
+ * - Grabs free root from pool           - Takes filled root from readyQueue
+ * - Fills rows, resets on reuse         - Reads batches via ArrowReader interface
+ * - When full, puts in readyQueue       - Returns root to freePool when done
+ * - Only blocks if pool is exhausted    - Processes in parallel with producer
+ * </pre>
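+ *
+ * <p>A minimal usage sketch, mirroring the unit tests in this PR ({@code rows} and
+ * {@code consume} are placeholders, and error handling is omitted):
+ *
+ * <pre>{@code
+ * PooledArrowBatchWriteBuffer buf =
+ *     new PooledArrowBatchWriteBuffer(allocator, schema, sparkSchema, 1024, 4, Long.MAX_VALUE);
+ * // Producer thread:
+ * while (rows.hasNext()) { buf.write(rows.next()); }
+ * buf.setFinished();
+ * // Consumer thread:
+ * while (buf.loadNextBatch()) { consume(buf.getVectorSchemaRoot()); }
+ * buf.close();
+ * }</pre>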
+ */ +public class PooledArrowBatchWriteBuffer extends ArrowBatchWriteBuffer { + private static final int DEFAULT_POOL_SIZE = 4; + + private final Schema schema; + private final StructType sparkSchema; + private final int batchSize; + private final long maxBatchBytes; + private final int poolSize; + + /** Free roots ready for the producer to fill. */ + private final BlockingQueue freePool; + + /** Filled roots ready for the consumer to read. */ + private final BlockingQueue readyQueue; + + // -- Producer state (only touched by producer thread) -- + private VectorSchemaRoot producerBatch; + private org.lance.spark.arrow.LanceArrowWriter producerArrowWriter; + private int producerRowCount = 0; + private volatile boolean producerFinished = false; + + // -- Consumer state (only touched by consumer thread) -- + private VectorSchemaRoot consumerBatch; + private boolean consumerFinished = false; + + // -- Byte tracking -- + /** Precomputed fixed-width bytes per row (sum of type widths + validity bytes). */ + private final long fixedBytesPerRow; + + /** Indices of variable-width vectors for per-row byte tracking. */ + private final int[] variableWidthIndices; + + /** Accumulated variable-width bytes for current batch. */ + private long currentVarBytes = 0; + + public PooledArrowBatchWriteBuffer( + BufferAllocator allocator, Schema schema, StructType sparkSchema, int batchSize) { + this( + allocator, + schema, + sparkSchema, + batchSize, + DEFAULT_POOL_SIZE, + LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + } + + public PooledArrowBatchWriteBuffer( + BufferAllocator allocator, + Schema schema, + StructType sparkSchema, + int batchSize, + int poolSize, + long maxBatchBytes) { + super(allocator); + Preconditions.checkNotNull(schema); + Preconditions.checkArgument(batchSize > 0, "Batch size must be positive"); + Preconditions.checkArgument(poolSize > 0, "Pool size must be positive"); + Preconditions.checkArgument(maxBatchBytes > 0, "maxBatchBytes must be positive"); + + this.schema = schema; + this.sparkSchema = sparkSchema; + this.batchSize = batchSize; + this.maxBatchBytes = maxBatchBytes; + this.poolSize = poolSize; + this.freePool = new ArrayBlockingQueue<>(poolSize); + this.readyQueue = new ArrayBlockingQueue<>(poolSize); + + // Precompute byte tracking metadata from schema + VectorSchemaRoot probe = VectorSchemaRoot.create(schema, allocator); + long fixedBytes = 0; + List varIndices = new ArrayList<>(); + for (int i = 0; i < probe.getFieldVectors().size(); i++) { + FieldVector vec = probe.getFieldVectors().get(i); + if (vec instanceof BaseVariableWidthVector || vec instanceof BaseLargeVariableWidthVector) { + varIndices.add(i); + } else if (vec instanceof BaseFixedWidthVector) { + // Fixed-width: type width + 1 bit validity (amortized to 1 byte per 8 rows, + // but we approximate as 1 byte per row for simplicity) + int typeWidth = ((BaseFixedWidthVector) vec).getTypeWidth(); + fixedBytes += typeWidth + 1; // data + validity byte + } + } + probe.close(); + this.fixedBytesPerRow = fixedBytes; + this.variableWidthIndices = varIndices.stream().mapToInt(Integer::intValue).toArray(); + + // Pre-allocate all roots + for (int i = 0; i < poolSize; i++) { + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + freePool.add(root); + } + + // Grab the first root for the producer + producerBatch = freePool.poll(); + producerArrowWriter = + org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(producerBatch, sparkSchema); + } + + /** Simplified constructor that uses 
LanceRuntime allocator and converts Spark schema to Arrow. */ + public PooledArrowBatchWriteBuffer(StructType sparkSchema, int batchSize, int poolSize) { + this(sparkSchema, batchSize, poolSize, false, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + } + + /** Constructor with large var types support, using LanceRuntime allocator. */ + public PooledArrowBatchWriteBuffer( + StructType sparkSchema, int batchSize, int poolSize, boolean useLargeVarTypes) { + this( + sparkSchema, + batchSize, + poolSize, + useLargeVarTypes, + LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); + } + + /** Constructor with all tuning parameters, using LanceRuntime allocator. */ + public PooledArrowBatchWriteBuffer( + StructType sparkSchema, + int batchSize, + int poolSize, + boolean useLargeVarTypes, + long maxBatchBytes) { + this( + LanceRuntime.allocator(), + LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, useLargeVarTypes), + sparkSchema, + batchSize, + poolSize, + maxBatchBytes); + } + + /** Returns whether the current batch should be flushed based on byte size. */ + private boolean isBatchFullByBytes() { + if (maxBatchBytes == Long.MAX_VALUE) { + return false; + } + long estimatedBytes = fixedBytesPerRow * producerRowCount + currentVarBytes; + return estimatedBytes >= maxBatchBytes; + } + + @Override + public void write(InternalRow row) { + Preconditions.checkNotNull(row); + Preconditions.checkState(!producerFinished, "Cannot write after setFinished() is called"); + + checkForError(); + + producerArrowWriter.write(row); + producerRowCount++; + + // Track variable-width byte growth. Use getBufferSizeFor(rowCount) because setSafe() + // updates the offset buffer but never advances the data buffer's writerIndex + // (that only happens at setValueCount), so readableBytes() would stay at 0. 
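+    // Rough illustration (hypothetical numbers): for a single VARCHAR column, after writing
+    // 3 values of 10 bytes each, getBufferSizeFor(3) is roughly 30 data bytes + 16 offset
+    // bytes + 1 validity byte, regardless of how much capacity the reused buffers still
+    // retain from earlier batches.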
+ if (variableWidthIndices.length > 0) { + long varBytes = 0; + for (int idx : variableWidthIndices) { + FieldVector vec = producerBatch.getVector(idx); + if (vec instanceof BaseVariableWidthVector) { + varBytes += ((BaseVariableWidthVector) vec).getBufferSizeFor(producerRowCount); + } else if (vec instanceof BaseLargeVariableWidthVector) { + varBytes += ((BaseLargeVariableWidthVector) vec).getBufferSizeFor(producerRowCount); + } + } + currentVarBytes = varBytes; + } + + if (producerRowCount >= batchSize || (producerRowCount > 0 && isBatchFullByBytes())) { + flushAndAcquireNext(); + } + } + + private void flushAndAcquireNext() { + producerArrowWriter.finish(); + producerBatch.setRowCount(producerRowCount); + + try { + while (!readyQueue.offer(producerBatch, 100, TimeUnit.MILLISECONDS)) { + checkForError(); + } + + // Acquire a free root for next batch + VectorSchemaRoot next = null; + while (next == null) { + next = freePool.poll(100, TimeUnit.MILLISECONDS); + if (next == null) { + checkForError(); + } + } + + // Reset for reuse — preserves buffer capacity + producerBatch = next; + producerArrowWriter = + org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(producerBatch, sparkSchema); + producerArrowWriter.reset(); + producerRowCount = 0; + currentVarBytes = 0; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while queuing batch", e); + } + } + + @Override + public void setFinished() { + if (producerFinished) { + return; + } + + try { + if (producerRowCount > 0) { + producerArrowWriter.finish(); + producerBatch.setRowCount(producerRowCount); + while (!readyQueue.offer(producerBatch, 100, TimeUnit.MILLISECONDS)) { + checkForError(); + } + } else { + freePool.offer(producerBatch); + } + producerBatch = null; + producerArrowWriter = null; + + // Signal completion only after the final batch is safely in the queue + producerFinished = true; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while finishing", e); + } + } + + // ========== ArrowReader interface for consumer ========== + + @Override + public boolean loadNextBatch() throws IOException { + if (consumerFinished) { + return false; + } + + try { + // Return previous batch to pool + if (consumerBatch != null) { + // Reset vectors for reuse + consumerBatch.setRowCount(0); + for (FieldVector v : consumerBatch.getFieldVectors()) { + v.reset(); + } + freePool.offer(consumerBatch); + consumerBatch = null; + } + + while (true) { + VectorSchemaRoot batch = readyQueue.poll(100, TimeUnit.MILLISECONDS); + if (batch != null) { + consumerBatch = batch; + return true; + } + if (producerFinished && readyQueue.isEmpty()) { + consumerFinished = true; + return false; + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while waiting for batch", e); + } + } + + @Override + public VectorSchemaRoot getVectorSchemaRoot() { + if (consumerBatch != null) { + return consumerBatch; + } + // Return an empty root for initial schema access + try { + return super.getVectorSchemaRoot(); + } catch (IOException e) { + throw new RuntimeException("Failed to get vector schema root", e); + } + } + + @Override + protected void prepareLoadNextBatch() throws IOException { + // No-op — batch is already prepared by producer + } + + @Override + public long bytesRead() { + return 0; + } + + @Override + protected void closeReadSource() throws IOException { + if (consumerBatch != 
null) { + consumerBatch.close(); + consumerBatch = null; + } + VectorSchemaRoot r; + while ((r = freePool.poll()) != null) { + r.close(); + } + while ((r = readyQueue.poll()) != null) { + r.close(); + } + } + + @Override + protected Schema readSchema() { + return this.schema; + } + + /** Returns the pool size (for monitoring/debugging). */ + public int getPoolSize() { + return poolSize; + } +} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java deleted file mode 100644 index 2fdc5c2b..00000000 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/QueuedArrowBatchWriteBuffer.java +++ /dev/null @@ -1,415 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.lance.spark.write; - -import org.lance.spark.LanceRuntime; -import org.lance.spark.LanceSparkWriteOptions; - -import com.google.common.base.Preconditions; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.LanceArrowUtils; - -import java.io.IOException; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; - -/** - * Queue-based buffer for Arrow batches that enables pipelined batch processing. - * - *

- * <p>Unlike the semaphore-based {@link SemaphoreArrowBatchWriteBuffer} which blocks on every row
- * write, this implementation uses a bounded queue to allow multiple batches to be in flight
- * simultaneously. This enables better pipelining between Spark row ingestion and Lance fragment
- * creation.
- *
- * <p>Batches are flushed when either the row count reaches {@code batchSize} or the allocated
- * memory for the current batch exceeds {@code maxBatchBytes}, whichever comes first. This limits
- * the size of each individual batch when rows are very large (e.g., rows with large binary/string
- * columns or embeddings).
- *
- * <p>Because this implementation is queue-based, multiple completed batches can be buffered at the
- * same time. As a result, total in-flight Arrow memory can be roughly {@code queueDepth *
- * maxBatchBytes}, plus the current producer batch and allocator overhead. Users may need to tune
- * {@code queueDepth} and {@code maxBatchBytes} together to stay within memory limits.
- *
- * <p>Architecture:
- *
- * <pre>

- * Producer (Spark thread):           Consumer (Fragment creation thread):
- * - Writes rows to local batch       - Takes completed batches from queue
- * - When batch full, puts in queue   - Writes batches to Lance format
- * - Only blocks if queue is full     - Processes batches in parallel with producer
- * </pre>
- */ -public class QueuedArrowBatchWriteBuffer extends ArrowBatchWriteBuffer { - /** Default queue depth - number of batches that can be buffered. */ - private static final int DEFAULT_QUEUE_DEPTH = 8; - - private final Schema schema; - private final StructType sparkSchema; - private final int batchSize; - private final long maxBatchBytes; - private final int queueDepth; - - /** - * Queue holding completed batches ready for consumption. Each entry pairs a VectorSchemaRoot with - * its dedicated child allocator, so the consumer can free both when done. - */ - private final BlockingQueue batchQueue; - - /** Child allocator for producer batches (shares root with consumer for buffer transfer). */ - private final BufferAllocator producerAllocator; - - /** Current batch being filled by producer. */ - private VectorSchemaRoot currentBatch; - - /** Child allocator dedicated to the current batch for accurate byte tracking. */ - private BufferAllocator currentBatchAllocator; - - /** Arrow writer for current batch. */ - private org.lance.spark.arrow.LanceArrowWriter currentArrowWriter; - - /** Count of rows in current batch. */ - private final AtomicInteger currentBatchRowCount = new AtomicInteger(0); - - /** Flag indicating producer has finished. */ - private volatile boolean producerFinished = false; - - /** Flag indicating consumer has consumed all batches. */ - private volatile boolean consumerFinished = false; - - /** Current batch being read by consumer (for ArrowReader interface). */ - private BatchEntry consumerCurrentEntry; - - /** Pairs a VectorSchemaRoot with its allocator for proper lifecycle management. */ - private static class BatchEntry { - final VectorSchemaRoot batch; - final BufferAllocator allocator; - - BatchEntry(VectorSchemaRoot batch, BufferAllocator allocator) { - this.batch = batch; - this.allocator = allocator; - } - - void close() { - try { - batch.close(); - } finally { - allocator.close(); - } - } - } - - public QueuedArrowBatchWriteBuffer( - BufferAllocator allocator, Schema schema, StructType sparkSchema, int batchSize) { - this( - allocator, - schema, - sparkSchema, - batchSize, - DEFAULT_QUEUE_DEPTH, - LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); - } - - /** Simplified constructor that uses LanceRuntime allocator and converts Spark schema to Arrow. */ - public QueuedArrowBatchWriteBuffer(StructType sparkSchema, int batchSize, int queueDepth) { - this(sparkSchema, batchSize, queueDepth, false, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); - } - - /** Constructor with large var types support, using LanceRuntime allocator. */ - public QueuedArrowBatchWriteBuffer( - StructType sparkSchema, int batchSize, int queueDepth, boolean useLargeVarTypes) { - this( - sparkSchema, - batchSize, - queueDepth, - useLargeVarTypes, - LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); - } - - /** Constructor with all tuning parameters, using LanceRuntime allocator. 
*/ - public QueuedArrowBatchWriteBuffer( - StructType sparkSchema, - int batchSize, - int queueDepth, - boolean useLargeVarTypes, - long maxBatchBytes) { - this( - LanceRuntime.allocator(), - LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, useLargeVarTypes), - sparkSchema, - batchSize, - queueDepth, - maxBatchBytes); - } - - public QueuedArrowBatchWriteBuffer( - BufferAllocator allocator, - Schema schema, - StructType sparkSchema, - int batchSize, - int queueDepth) { - this( - allocator, - schema, - sparkSchema, - batchSize, - queueDepth, - LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); - } - - public QueuedArrowBatchWriteBuffer( - BufferAllocator allocator, - Schema schema, - StructType sparkSchema, - int batchSize, - int queueDepth, - long maxBatchBytes) { - super(allocator); - Preconditions.checkNotNull(schema); - Preconditions.checkArgument(batchSize > 0, "Batch size must be positive"); - Preconditions.checkArgument(queueDepth > 0, "Queue depth must be positive"); - Preconditions.checkArgument(maxBatchBytes > 0, "maxBatchBytes must be positive"); - - this.schema = schema; - this.sparkSchema = sparkSchema; - this.batchSize = batchSize; - this.maxBatchBytes = maxBatchBytes; - this.queueDepth = queueDepth; - this.batchQueue = new ArrayBlockingQueue<>(queueDepth); - - // Create a child allocator for producer that shares the same root as the consumer - // allocator. This is required for Arrow buffer transfer between producer and consumer. - this.producerAllocator = - allocator.newChildAllocator("queued-buffer-producer", 0, Long.MAX_VALUE); - - // Initialize first batch for producer - allocateNewBatch(); - } - - /** Allocates a new batch for the producer to fill. */ - private void allocateNewBatch() { - // Each batch gets its own child allocator so getAllocatedMemory() accurately reflects - // only this batch's memory, unaffected by concurrent consumer freeing of older batches. - currentBatchAllocator = producerAllocator.newChildAllocator("batch", 0, Long.MAX_VALUE); - try { - currentBatch = VectorSchemaRoot.create(schema, currentBatchAllocator); - currentBatch.allocateNew(); - } catch (Exception e) { - if (currentBatch != null) { - currentBatch.close(); - } - currentBatchAllocator.close(); - throw e; - } - currentArrowWriter = - org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(currentBatch, sparkSchema); - currentBatchRowCount.set(0); - } - - /** Returns whether the current batch should be flushed based on byte size. */ - private boolean isBatchFullByBytes() { - if (maxBatchBytes == Long.MAX_VALUE) { - return false; - } - return currentBatchAllocator.getAllocatedMemory() >= maxBatchBytes; - } - - /** - * Writes a row to the current batch. When the batch is full (by row count or byte size), it is - * queued for consumption and a new batch is allocated. - * - *

This method only blocks if the queue is full, allowing the producer to continue writing - * while the consumer processes previous batches. - * - * @param row The row to write - */ - @Override - public void write(InternalRow row) { - Preconditions.checkNotNull(row); - Preconditions.checkState(!producerFinished, "Cannot write after setFinished() is called"); - - checkForError(); - - try { - currentArrowWriter.write(row); - - int count = currentBatchRowCount.incrementAndGet(); - - if (count >= batchSize || (count > 0 && isBatchFullByBytes())) { - // Batch is full - finalize and queue it - currentArrowWriter.finish(); - currentBatch.setRowCount(count); - - BatchEntry entry = new BatchEntry(currentBatch, currentBatchAllocator); - // Put in queue, periodically checking for consumer errors - try { - while (!batchQueue.offer(entry, 100, TimeUnit.MILLISECONDS)) { - checkForError(); - } - } catch (RuntimeException e) { - entry.close(); - throw e; - } - - // Allocate new batch for next writes - allocateNewBatch(); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted while queuing batch", e); - } - } - - /** - * Signals that the producer has finished writing. Any partial batch is queued for consumption. - */ - @Override - public void setFinished() { - if (producerFinished) { - return; - } - - try { - // Queue any remaining partial batch before signaling completion to avoid - // a race where the consumer sees producerFinished=true with an empty queue - // and exits before the final batch is enqueued. - int remainingRows = currentBatchRowCount.get(); - if (remainingRows > 0) { - currentArrowWriter.finish(); - currentBatch.setRowCount(remainingRows); - BatchEntry entry = new BatchEntry(currentBatch, currentBatchAllocator); - try { - while (!batchQueue.offer(entry, 100, TimeUnit.MILLISECONDS)) { - checkForError(); - } - } catch (RuntimeException e) { - entry.close(); - throw e; - } - } else { - // No remaining rows, close the empty batch and its allocator - currentBatch.close(); - currentBatchAllocator.close(); - } - currentBatch = null; - currentBatchAllocator = null; - currentArrowWriter = null; - - // Signal completion only after the final batch is safely in the queue - producerFinished = true; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Interrupted while finishing", e); - } - } - - // ========== ArrowReader interface for consumer ========== - - @Override - public boolean loadNextBatch() throws IOException { - if (consumerFinished) { - return false; - } - - try { - // Close previous batch if any - if (consumerCurrentEntry != null) { - consumerCurrentEntry.close(); - consumerCurrentEntry = null; - } - - // Try to get next batch from queue - while (true) { - // Use poll with timeout to check for producer finished - BatchEntry entry = batchQueue.poll(100, TimeUnit.MILLISECONDS); - - if (entry != null) { - consumerCurrentEntry = entry; - return true; - } - - // Check if producer is done and queue is empty - if (producerFinished && batchQueue.isEmpty()) { - consumerFinished = true; - return false; - } - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IOException("Interrupted while waiting for batch", e); - } - } - - @Override - public VectorSchemaRoot getVectorSchemaRoot() { - if (consumerCurrentEntry != null) { - return consumerCurrentEntry.batch; - } - // Return an empty root for initial schema access - try { - return 
super.getVectorSchemaRoot(); - } catch (IOException e) { - throw new RuntimeException("Failed to get vector schema root", e); - } - } - - @Override - protected void prepareLoadNextBatch() throws IOException { - // No-op - batch is already prepared by producer - } - - @Override - public long bytesRead() { - return 0; // Not tracked - } - - @Override - protected void closeReadSource() throws IOException { - // Close any remaining batch - if (consumerCurrentEntry != null) { - consumerCurrentEntry.close(); - consumerCurrentEntry = null; - } - - // Drain and close any batches left in queue - BatchEntry entry; - while ((entry = batchQueue.poll()) != null) { - entry.close(); - } - - // Close producer allocator - producerAllocator.close(); - } - - @Override - protected Schema readSchema() { - return this.schema; - } - - /** Returns the queue depth (for monitoring/debugging). */ - public int getQueueDepth() { - return queueDepth; - } - - /** Returns the current queue size (for monitoring/debugging). */ - public int getCurrentQueueSize() { - return batchQueue.size(); - } -} diff --git a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java b/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java deleted file mode 100644 index 0c5cd62b..00000000 --- a/lance-spark-base_2.12/src/main/java/org/lance/spark/write/SemaphoreArrowBatchWriteBuffer.java +++ /dev/null @@ -1,267 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.lance.spark.write; - -import org.lance.spark.LanceRuntime; -import org.lance.spark.LanceSparkWriteOptions; - -import com.google.common.base.Preconditions; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.LanceArrowUtils; - -import javax.annotation.concurrent.GuardedBy; - -import java.io.IOException; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.ReentrantLock; - -/** - * Buffers Spark rows into Arrow batches for consumption by Lance fragment creation. - * - *

- * <p>This class bridges the producer (Spark thread writing rows) and consumer (Lance native code
- * pulling batches via ArrowReader interface). It uses a lock with conditions to synchronize between
- * the two threads - the producer blocks until the consumer is ready for more data, and vice versa.
- *
- * <p>Batches are flushed when either the row count reaches {@code batchSize} or the cumulative
- * bytes written in the current batch exceeds {@code maxBatchBytes}, whichever comes first. This
- * prevents OOM when individual rows are very large (e.g., rows with large binary/string columns or
- * embeddings).
- *
- * <p>
Because this buffer reuses a single VectorSchemaRoot across batches, the allocator retains - * buffer capacity from previous batches. The byte-based flush tracks per-row allocator growth - * (before/after each write) to accurately measure each batch's memory usage regardless of retained - * capacity. - * - * @see QueuedArrowBatchWriteBuffer for a queue-based alternative with per-batch allocators - */ -public class SemaphoreArrowBatchWriteBuffer extends ArrowBatchWriteBuffer { - private final Schema schema; - private final StructType sparkSchema; - private final int batchSize; - private final long maxBatchBytes; - - private final ReentrantLock lock = new ReentrantLock(); - private final Condition canWrite = lock.newCondition(); - private final Condition batchReady = lock.newCondition(); - - @GuardedBy("lock") - private boolean finished = false; - - @GuardedBy("lock") - private int count; - - /** - * Tracks per-batch memory usage for byte-based flushing. {@code batchStartBytes} captures the - * allocator memory after {@code clear()+allocateNew()}, and {@code currentBatchBytes} is the - * delta from that baseline. This is necessary because the shared allocator retains capacity from - * previous batches, so absolute memory is not reliable for per-batch tracking. - */ - @GuardedBy("lock") - private long currentBatchBytes; - - private long batchStartBytes; - - private org.lance.spark.arrow.LanceArrowWriter arrowWriter = null; - - public SemaphoreArrowBatchWriteBuffer( - BufferAllocator allocator, - Schema schema, - StructType sparkSchema, - int batchSize, - long maxBatchBytes) { - // Pass a child allocator to ArrowReader so VectorSchemaRoot allocation is tracked - super(allocator.newChildAllocator("semaphore-buffer", 0, Long.MAX_VALUE)); - Preconditions.checkNotNull(schema); - Preconditions.checkArgument(batchSize > 0); - Preconditions.checkArgument(maxBatchBytes > 0, "maxBatchBytes must be positive"); - this.schema = schema; - this.sparkSchema = sparkSchema; - this.batchSize = batchSize; - this.maxBatchBytes = maxBatchBytes; - // Start with count = batchSize so the writer blocks on canWrite.await() until the - // reader's prepareLoadNextBatch() initializes arrowWriter and resets count to 0. - this.count = batchSize; - } - - public SemaphoreArrowBatchWriteBuffer( - BufferAllocator allocator, Schema schema, StructType sparkSchema, int batchSize) { - this(allocator, schema, sparkSchema, batchSize, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); - } - - /** Simplified constructor that uses LanceRuntime allocator and converts Spark schema to Arrow. */ - public SemaphoreArrowBatchWriteBuffer(StructType sparkSchema, int batchSize) { - this(sparkSchema, batchSize, false, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); - } - - /** Constructor with large var types support, using LanceRuntime allocator. */ - public SemaphoreArrowBatchWriteBuffer( - StructType sparkSchema, int batchSize, boolean useLargeVarTypes) { - this(sparkSchema, batchSize, useLargeVarTypes, LanceSparkWriteOptions.DEFAULT_MAX_BATCH_BYTES); - } - - /** Constructor with all tuning parameters, using LanceRuntime allocator. 
*/ - public SemaphoreArrowBatchWriteBuffer( - StructType sparkSchema, int batchSize, boolean useLargeVarTypes, long maxBatchBytes) { - this( - LanceRuntime.allocator(), - LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, useLargeVarTypes), - sparkSchema, - batchSize, - maxBatchBytes); - } - - @Override - public void onTaskComplete() { - lock.lock(); - try { - canWrite.signalAll(); - batchReady.signalAll(); - } finally { - lock.unlock(); - } - } - - /** Returns whether the current batch should be flushed based on byte size. */ - private boolean isBatchFullByBytes() { - if (maxBatchBytes == Long.MAX_VALUE) { - return false; - } - return currentBatchBytes >= maxBatchBytes; - } - - /** Returns whether the current batch should be flushed (by row count or byte size). */ - private boolean isBatchFull() { - return count >= batchSize || (count > 0 && isBatchFullByBytes()); - } - - @Override - public void write(InternalRow row) { - Preconditions.checkNotNull(row); - lock.lock(); - try { - checkForError(); - - // wait until prepareLoadNextBatch signals that writes are available - while (isBatchFull()) { - canWrite.await(); - checkForError(); - } - - arrowWriter.write(row); - currentBatchBytes = this.allocator.getAllocatedMemory() - batchStartBytes; - count++; - - if (isBatchFull()) { - batchReady.signal(); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } finally { - lock.unlock(); - } - } - - @Override - public void setFinished() { - lock.lock(); - try { - finished = true; - batchReady.signal(); - canWrite.signalAll(); - } finally { - lock.unlock(); - } - } - - @Override - public void prepareLoadNextBatch() throws IOException { - // Don't call super.prepareLoadNextBatch() which does clear()+allocateNew(). - // Arrow's allocateNew() remembers previous allocation sizes and pre-allocates - // that much capacity, which defeats per-batch byte tracking (the delta stays 0 - // because writes fit within pre-allocated capacity). Instead, reset each vector - // (releasing memory and clearing allocation hints) then allocateNew() from scratch. - VectorSchemaRoot root = this.getVectorSchemaRoot(); - for (FieldVector v : root.getFieldVectors()) { - v.clear(); - v.setInitialCapacity(1); - v.allocateNew(); - } - root.setRowCount(0); - arrowWriter = org.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(root, sparkSchema); - lock.lock(); - try { - count = 0; - currentBatchBytes = 0; - batchStartBytes = this.allocator.getAllocatedMemory(); - canWrite.signalAll(); - } finally { - lock.unlock(); - } - } - - @Override - public boolean loadNextBatch() throws IOException { - prepareLoadNextBatch(); - lock.lock(); - try { - if (finished && count == 0) { - // Clear any buffers allocated by prepareLoadNextBatch() since no rows were written - this.getVectorSchemaRoot().clear(); - return false; - } - - // wait until batch is full (by rows or bytes) or finished - while (!isBatchFull() && !finished) { - batchReady.await(); - checkForError(); - } - - arrowWriter.finish(); - - if (!finished) { - return true; - } else { - return count > 0; - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } finally { - lock.unlock(); - } - } - - @Override - public long bytesRead() { - throw new UnsupportedOperationException(); - } - - @Override - protected synchronized void closeReadSource() throws IOException { - // Close the child allocator that was created for byte tracking. 
- // The VectorSchemaRoot is closed by ArrowReader.close() before this is called. - this.allocator.close(); - } - - @Override - protected Schema readSchema() { - return this.schema; - } -} diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/QueuedArrowBatchWriteBufferTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/PooledArrowBatchWriteBufferTest.java similarity index 59% rename from lance-spark-base_2.12/src/test/java/org/lance/spark/write/QueuedArrowBatchWriteBufferTest.java rename to lance-spark-base_2.12/src/test/java/org/lance/spark/write/PooledArrowBatchWriteBufferTest.java index a640804a..eef77583 100644 --- a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/QueuedArrowBatchWriteBufferTest.java +++ b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/PooledArrowBatchWriteBufferTest.java @@ -27,7 +27,6 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.UTF8String; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -42,7 +41,7 @@ import static org.junit.jupiter.api.Assertions.*; -public class QueuedArrowBatchWriteBufferTest { +public class PooledArrowBatchWriteBufferTest { private Schema createIntSchema() { Field field = @@ -58,6 +57,54 @@ private StructType createIntSparkSchema() { new StructField[] {DataTypes.createStructField("column1", DataTypes.IntegerType, true)}); } + private void runWriterReader( + PooledArrowBatchWriteBuffer writeBuffer, + int totalRows, + AtomicInteger rowsWritten, + AtomicInteger rowsRead, + AtomicReference writerError, + AtomicReference readerError) + throws InterruptedException { + Thread writerThread = + new Thread( + () -> { + try { + for (int i = 0; i < totalRows; i++) { + InternalRow row = + new GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()}); + writeBuffer.write(row); + } + writeBuffer.setFinished(); + } catch (Throwable e) { + writerError.set(e); + } + }); + + Thread readerThread = + new Thread( + () -> { + try { + while (writeBuffer.loadNextBatch()) { + VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot(); + int rowCount = root.getRowCount(); + int baseValue = rowsRead.get(); + rowsRead.addAndGet(rowCount); + for (int i = 0; i < rowCount; i++) { + int value = (int) root.getVector("column1").getObject(i); + assertEquals(baseValue + i + 1, value); + } + } + } catch (Throwable e) { + readerError.set(e); + } + }); + + writerThread.start(); + readerThread.start(); + writerThread.join(); + readerThread.join(); + } + @Test public void testBasicWriteAndRead() throws Exception { try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { @@ -66,57 +113,17 @@ public void testBasicWriteAndRead() throws Exception { final int totalRows = 125; final int batchSize = 34; - final int queueDepth = 4; - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize, queueDepth); + final int poolSize = 4; + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + allocator, schema, sparkSchema, batchSize, poolSize, Long.MAX_VALUE); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); AtomicReference writerError = new AtomicReference<>(); AtomicReference readerError = new AtomicReference<>(); - Thread writerThread = - new Thread( - () -> { - try { - for (int i = 0; i < totalRows; i++) { - InternalRow row = - new 
GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()}); - writeBuffer.write(row); - } - } catch (Throwable e) { - writerError.set(e); - e.printStackTrace(); - } finally { - writeBuffer.setFinished(); - } - }); - - Thread readerThread = - new Thread( - () -> { - try { - while (writeBuffer.loadNextBatch()) { - VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot(); - int rowCount = root.getRowCount(); - int baseValue = rowsRead.get(); - rowsRead.addAndGet(rowCount); - for (int i = 0; i < rowCount; i++) { - int value = (int) root.getVector("column1").getObject(i); - assertEquals(baseValue + i + 1, value); - } - } - } catch (Throwable e) { - readerError.set(e); - e.printStackTrace(); - } - }); - - writerThread.start(); - readerThread.start(); - - writerThread.join(); - readerThread.join(); + runWriterReader(writeBuffer, totalRows, rowsWritten, rowsRead, writerError, readerError); assertNull(writerError.get(), "Writer thread should not have errors"); assertNull(readerError.get(), "Reader thread should not have errors"); @@ -128,16 +135,15 @@ public void testBasicWriteAndRead() throws Exception { @Test public void testPartialBatch() throws Exception { - // Test that partial batches (when totalRows % batchSize != 0) are handled correctly try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Schema schema = createIntSchema(); StructType sparkSchema = createIntSparkSchema(); final int totalRows = 50; - final int batchSize = 34; // Will have 1 full batch (34) + 1 partial batch (16) - final int queueDepth = 2; - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize, queueDepth); + final int batchSize = 34; + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + allocator, schema, sparkSchema, batchSize, 4, Long.MAX_VALUE); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); @@ -182,22 +188,16 @@ public void testPartialBatch() throws Exception { @Test public void testEmptyWrite() throws Exception { - // Test that calling setFinished without writing any rows works correctly try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Schema schema = createIntSchema(); StructType sparkSchema = createIntSparkSchema(); - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, 100, 2); + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer(allocator, schema, sparkSchema, 100, 2, Long.MAX_VALUE); AtomicInteger batchCount = new AtomicInteger(0); - Thread writerThread = - new Thread( - () -> { - // Don't write anything, just finish - writeBuffer.setFinished(); - }); + Thread writerThread = new Thread(writeBuffer::setFinished); Thread readerThread = new Thread( @@ -223,16 +223,16 @@ public void testEmptyWrite() throws Exception { @Test public void testLargeDataset() throws Exception { - // Test with a larger dataset to ensure queue pipelining works try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Schema schema = createIntSchema(); StructType sparkSchema = createIntSparkSchema(); final int totalRows = 10000; final int batchSize = 512; - final int queueDepth = 8; - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize, queueDepth); + final int poolSize = 8; + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + allocator, 
schema, sparkSchema, batchSize, poolSize, Long.MAX_VALUE); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); @@ -278,50 +278,27 @@ public void testLargeDataset() throws Exception { } @Test - public void testQueueDepthOne() throws Exception { - // Test with minimum queue depth of 1 + public void testPoolSizeOne() throws Exception { + // Pool size 1 = serial producer/consumer, equivalent to old semaphore behavior try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Schema schema = createIntSchema(); StructType sparkSchema = createIntSparkSchema(); final int totalRows = 100; final int batchSize = 10; - final int queueDepth = 1; - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize, queueDepth); + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + allocator, schema, sparkSchema, batchSize, 1, Long.MAX_VALUE); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); + AtomicReference writerError = new AtomicReference<>(); + AtomicReference readerError = new AtomicReference<>(); - Thread writerThread = - new Thread( - () -> { - for (int i = 0; i < totalRows; i++) { - InternalRow row = - new GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()}); - writeBuffer.write(row); - } - writeBuffer.setFinished(); - }); - - Thread readerThread = - new Thread( - () -> { - try { - while (writeBuffer.loadNextBatch()) { - VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot(); - rowsRead.addAndGet(root.getRowCount()); - } - } catch (Exception e) { - e.printStackTrace(); - } - }); - - writerThread.start(); - readerThread.start(); - writerThread.join(); - readerThread.join(); + runWriterReader(writeBuffer, totalRows, rowsWritten, rowsRead, writerError, readerError); + assertNull(writerError.get()); + assertNull(readerError.get()); assertEquals(totalRows, rowsWritten.get()); assertEquals(totalRows, rowsRead.get()); writeBuffer.close(); @@ -330,7 +307,6 @@ public void testQueueDepthOne() throws Exception { @Test public void testMultipleColumns() throws Exception { - // Test with multiple columns of different types try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Field intField = new Field( @@ -353,9 +329,9 @@ public void testMultipleColumns() throws Exception { final int totalRows = 200; final int batchSize = 50; - final int queueDepth = 4; - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize, queueDepth); + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + allocator, schema, sparkSchema, batchSize, 4, Long.MAX_VALUE); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); @@ -405,32 +381,36 @@ public void testMultipleColumns() throws Exception { } @Test - public void testDefaultQueueDepth() throws Exception { - // Test using the constructor with default queue depth + public void testExactBatchBoundary() throws Exception { try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Schema schema = createIntSchema(); StructType sparkSchema = createIntSparkSchema(); - final int totalRows = 100; - final int batchSize = 20; - // Use constructor without queueDepth - should use default of 8 - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize); - - 
assertEquals(8, writeBuffer.getQueueDepth()); + final int batchSize = 10; + final int totalRows = 30; // Exactly 3 batches + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + allocator, schema, sparkSchema, batchSize, 4, Long.MAX_VALUE); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); + AtomicInteger batchCount = new AtomicInteger(0); + AtomicReference writerError = new AtomicReference<>(); + AtomicReference readerError = new AtomicReference<>(); Thread writerThread = new Thread( () -> { - for (int i = 0; i < totalRows; i++) { - InternalRow row = - new GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()}); - writeBuffer.write(row); + try { + for (int i = 0; i < totalRows; i++) { + InternalRow row = + new GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()}); + writeBuffer.write(row); + } + writeBuffer.setFinished(); + } catch (Throwable e) { + writerError.set(e); } - writeBuffer.setFinished(); }); Thread readerThread = @@ -438,11 +418,13 @@ public void testDefaultQueueDepth() throws Exception { () -> { try { while (writeBuffer.loadNextBatch()) { + batchCount.incrementAndGet(); VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot(); rowsRead.addAndGet(root.getRowCount()); + assertEquals(batchSize, root.getRowCount()); } - } catch (Exception e) { - e.printStackTrace(); + } catch (Throwable e) { + readerError.set(e); } }); @@ -451,83 +433,53 @@ public void testDefaultQueueDepth() throws Exception { writerThread.join(); readerThread.join(); + assertNull(writerError.get()); + assertNull(readerError.get()); assertEquals(totalRows, rowsWritten.get()); assertEquals(totalRows, rowsRead.get()); + assertEquals(3, batchCount.get()); writeBuffer.close(); } } @Test - public void testSlowConsumer() throws Exception { - // Test that the queue buffers batches when consumer is slow + public void testSingleRowBatch() throws Exception { try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Schema schema = createIntSchema(); StructType sparkSchema = createIntSparkSchema(); - final int totalRows = 100; - final int batchSize = 10; - final int queueDepth = 4; - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize, queueDepth); + final int batchSize = 1; + final int totalRows = 5; + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + allocator, schema, sparkSchema, batchSize, 4, Long.MAX_VALUE); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); - AtomicInteger maxQueueSize = new AtomicInteger(0); - - Thread writerThread = - new Thread( - () -> { - for (int i = 0; i < totalRows; i++) { - InternalRow row = - new GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()}); - writeBuffer.write(row); - // Track max queue size - int currentSize = writeBuffer.getCurrentQueueSize(); - maxQueueSize.updateAndGet(prev -> Math.max(prev, currentSize)); - } - writeBuffer.setFinished(); - }); - - Thread readerThread = - new Thread( - () -> { - try { - while (writeBuffer.loadNextBatch()) { - // Simulate slow consumer - Thread.sleep(10); - VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot(); - rowsRead.addAndGet(root.getRowCount()); - } - } catch (Exception e) { - e.printStackTrace(); - } - }); + AtomicReference writerError = new AtomicReference<>(); + AtomicReference readerError = new AtomicReference<>(); - writerThread.start(); - 
readerThread.start(); - writerThread.join(); - readerThread.join(); + runWriterReader(writeBuffer, totalRows, rowsWritten, rowsRead, writerError, readerError); + assertNull(writerError.get()); + assertNull(readerError.get()); assertEquals(totalRows, rowsWritten.get()); assertEquals(totalRows, rowsRead.get()); - // Queue should have been used (max size > 0 at some point) - assertTrue(maxQueueSize.get() >= 0); writeBuffer.close(); } } @Test public void testWriteErrorPropagation() throws Exception { - // Test that the queue buffers batches when consumer is slow try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { Schema schema = createIntSchema(); StructType sparkSchema = createIntSparkSchema(); final int totalRows = 100; final int batchSize = 10; - final int queueDepth = 4; - final QueuedArrowBatchWriteBuffer writeBuffer = - new QueuedArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize, queueDepth); + final PooledArrowBatchWriteBuffer writeBuffer = + new PooledArrowBatchWriteBuffer( + allocator, schema, sparkSchema, batchSize, 4, Long.MAX_VALUE); AtomicInteger rowsWritten = new AtomicInteger(0); AtomicInteger rowsRead = new AtomicInteger(0); @@ -539,20 +491,16 @@ public void testWriteErrorPropagation() throws Exception { VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot(); rowsRead.addAndGet(root.getRowCount()); readerConsumedBatch.countDown(); - - // Throw a mock exception after reading a batch throw new RuntimeException("Mock exception"); } return rowsRead.get(); }; - // Start background thread to read from the queue FutureTask readTask = writeBuffer.createTrackedTask(read); Thread readerThread = new Thread(readTask); readerThread.start(); - // Write data to queue until it throws an exception - Assertions.assertThrows( + assertThrows( RuntimeException.class, () -> { try { @@ -562,9 +510,7 @@ public void testWriteErrorPropagation() throws Exception { rowsWritten.incrementAndGet(); if (rowsWritten.get() >= batchSize) { - // Wait for the reader to consume a batch and throw readerConsumedBatch.await(); - // Wait for the reader task to fully complete so checkForError detects it while (!readTask.isDone()) { Thread.sleep(1); } @@ -575,7 +521,7 @@ public void testWriteErrorPropagation() throws Exception { } }); - Assertions.assertThrows(ExecutionException.class, readTask::get); + assertThrows(ExecutionException.class, readTask::get); assertEquals(batchSize, rowsWritten.get()); assertEquals(batchSize, rowsRead.get()); @@ -583,69 +529,32 @@ public void testWriteErrorPropagation() throws Exception { } } - // ========== Byte-based flush tests ========== - - private Schema createStringSchema() { - Field field = - new Field( - "data", - FieldType.nullable(org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()), - null); - return new Schema(Collections.singletonList(field)); - } - - private StructType createStringSparkSchema() { - return new StructType( - new StructField[] {DataTypes.createStructField("data", DataTypes.StringType, true)}); - } - - /** Generate a string of approximately the given size in bytes. */ - private UTF8String generateLargeString(int sizeBytes) { - byte[] data = new byte[sizeBytes]; - java.util.Arrays.fill(data, (byte) 'A'); - return UTF8String.fromBytes(data); - } - @Test - public void testByteBasedFlush() throws Exception { - // Each row is ~100KB. With batchSize=1000 (would be 100MB), but maxBatchBytes=256KB, - // we should see batches flush after ~2-3 rows instead of waiting for 1000. 
+  public void testByteBasedFlushWithSmallRows() throws Exception {
+    // Small rows should never trigger a byte-based flush; only the row count matters
     try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createStringSchema();
-      StructType sparkSchema = createStringSparkSchema();
+      Schema schema = createIntSchema();
+      StructType sparkSchema = createIntSparkSchema();
 
-      final int totalRows = 20;
-      final int batchSize = 1000; // High row limit - should never be reached
-      final long maxBatchBytes = 256 * 1024; // 256KB - should trigger flush after ~2 rows
-      final int rowSizeBytes = 100 * 1024; // ~100KB per row
-      final int queueDepth = 4;
-      final QueuedArrowBatchWriteBuffer writeBuffer =
-          new QueuedArrowBatchWriteBuffer(
-              allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes);
+      final int totalRows = 100;
+      final int batchSize = 50;
+      // 256MB limit; small int rows will never reach it
+      final long maxBatchBytes = 256L * 1024 * 1024;
+
+      final PooledArrowBatchWriteBuffer writeBuffer =
+          new PooledArrowBatchWriteBuffer(
+              allocator, schema, sparkSchema, batchSize, 4, maxBatchBytes);
 
-      AtomicInteger rowsWritten = new AtomicInteger(0);
       AtomicInteger rowsRead = new AtomicInteger(0);
       AtomicInteger batchCount = new AtomicInteger(0);
-      AtomicInteger maxRowsInBatch = new AtomicInteger(0);
-      AtomicReference<Throwable> writerError = new AtomicReference<>();
-      AtomicReference<Throwable> readerError = new AtomicReference<>();
 
       Thread writerThread =
           new Thread(
              () -> {
-                try {
-                  for (int i = 0; i < totalRows; i++) {
-                    UTF8String largeValue = generateLargeString(rowSizeBytes);
-                    InternalRow row = new GenericInternalRow(new Object[] {largeValue});
-                    writeBuffer.write(row);
-                    rowsWritten.incrementAndGet();
-                  }
-                } catch (Throwable e) {
-                  writerError.set(e);
-                  e.printStackTrace();
-                } finally {
-                  writeBuffer.setFinished();
+                for (int i = 0; i < totalRows; i++) {
+                  writeBuffer.write(new GenericInternalRow(new Object[] {i}));
                 }
+                writeBuffer.setFinished();
              });
 
       Thread readerThread =
@@ -653,14 +562,10 @@ public void testByteBasedFlush() throws Exception {
              () -> {
                try {
                  while (writeBuffer.loadNextBatch()) {
-                    VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot();
-                    int rowCount = root.getRowCount();
-                    rowsRead.addAndGet(rowCount);
                     batchCount.incrementAndGet();
-                    maxRowsInBatch.updateAndGet(prev -> Math.max(prev, rowCount));
+                    rowsRead.addAndGet(writeBuffer.getVectorSchemaRoot().getRowCount());
                  }
-                } catch (Throwable e) {
-                  readerError.set(e);
+                } catch (Exception e) {
                   e.printStackTrace();
                 }
              });
 
@@ -670,62 +575,54 @@
       writerThread.join();
       readerThread.join();
 
-      assertNull(writerError.get(), "Writer thread should not have errors");
-      assertNull(readerError.get(), "Reader thread should not have errors");
-      assertEquals(totalRows, rowsWritten.get());
       assertEquals(totalRows, rowsRead.get());
-      // With 100KB rows and 256KB limit, each batch should have at most ~3 rows.
-      // Without byte-based flush, we'd get 1 batch of 20 rows (since batchSize=1000).
-      assertTrue(
-          batchCount.get() > 1,
-          "Should have multiple batches due to byte-based flushing, but got " + batchCount.get());
-      assertTrue(
-          maxRowsInBatch.get() < batchSize,
-          "Max rows per batch ("
-              + maxRowsInBatch.get()
-              + ") should be less than batchSize ("
-              + batchSize
-              + ") due to byte-based flushing");
+      assertEquals(2, batchCount.get()); // 100 rows / 50 rows per batch = 2 batches
 
       writeBuffer.close();
     }
   }
 
   @Test
-  public void testByteBasedFlushWithSmallRows() throws Exception {
-    // With small rows, the row count limit should be reached before byte limit.
+  public void testByteBasedFlushWithLargeStrings() throws Exception {
+    // Large string rows should trigger a byte-based flush before the row-count limit
     try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createIntSchema();
-      StructType sparkSchema = createIntSparkSchema();
+      Field stringField =
+          new Field(
+              "data",
+              FieldType.nullable(org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()),
+              null);
+      Schema schema = new Schema(Collections.singletonList(stringField));
+      StructType sparkSchema =
+          new StructType(
+              new StructField[] {DataTypes.createStructField("data", DataTypes.StringType, true)});
 
-      final int totalRows = 100;
-      final int batchSize = 25;
-      final long maxBatchBytes = 100 * 1024 * 1024; // 100MB - should never be reached
-      final int queueDepth = 4;
-      final QueuedArrowBatchWriteBuffer writeBuffer =
-          new QueuedArrowBatchWriteBuffer(
-              allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes);
+      final int totalRows = 20;
+      final int batchSize = 1000; // High row limit; the byte limit should trigger first
+      // ~100KB strings with a 256KB byte limit → should flush every ~2-3 rows
+      final long maxBatchBytes = 256L * 1024;
+
+      final PooledArrowBatchWriteBuffer writeBuffer =
+          new PooledArrowBatchWriteBuffer(
+              allocator, schema, sparkSchema, batchSize, 4, maxBatchBytes);
+
+      // Build a ~100KB string
+      StringBuilder sb = new StringBuilder();
+      for (int i = 0; i < 100 * 1024; i++) {
+        sb.append('x');
+      }
+      String largeString = sb.toString();
 
-      AtomicInteger rowsWritten = new AtomicInteger(0);
       AtomicInteger rowsRead = new AtomicInteger(0);
       AtomicInteger batchCount = new AtomicInteger(0);
-      AtomicReference<Throwable> writerError = new AtomicReference<>();
-      AtomicReference<Throwable> readerError = new AtomicReference<>();
 
       Thread writerThread =
          new Thread(
              () -> {
-                try {
-                  for (int i = 0; i < totalRows; i++) {
-                    InternalRow row =
-                        new GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()});
-                    writeBuffer.write(row);
-                  }
-                } catch (Throwable e) {
-                  writerError.set(e);
-                  e.printStackTrace();
-                } finally {
-                  writeBuffer.setFinished();
+                for (int i = 0; i < totalRows; i++) {
+                  writeBuffer.write(
+                      new GenericInternalRow(
+                          new Object[] {UTF8String.fromString(largeString + i)}));
                 }
+                writeBuffer.setFinished();
              });
 
       Thread readerThread =
@@ -733,12 +630,15 @@ public void testByteBasedFlushWithSmallRows() throws Exception {
              () -> {
                try {
                  while (writeBuffer.loadNextBatch()) {
+                    batchCount.incrementAndGet();
                     VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot();
                     rowsRead.addAndGet(root.getRowCount());
-                    batchCount.incrementAndGet();
+                    // Each batch should have fewer than batchSize rows
+                    assertTrue(
+                        root.getRowCount() < batchSize,
+                        "Byte-based flush should trigger before the row-count limit");
                  }
-                } catch (Throwable e) {
-                  readerError.set(e);
+                } catch (Exception e) {
                   e.printStackTrace();
                 }
              });
 
@@ -748,54 +648,53 @@
       writerThread.join();
       readerThread.join();
 
-      assertNull(writerError.get(), "Writer thread should not have errors");
-      assertNull(readerError.get(), "Reader thread should not have errors");
-      assertEquals(totalRows, rowsWritten.get());
       assertEquals(totalRows, rowsRead.get());
-      // Should have exactly 4 batches of 25 rows (row-count based flushing)
-      assertEquals(4, batchCount.get());
+      // Should produce more batches than row-count flushing alone would
+      assertTrue(batchCount.get() > 1, "Should have multiple batches from byte-based flush");
 
       writeBuffer.close();
     }
   }
 
   @Test
   public void testByteBasedFlushSingleLargeRow() throws Exception {
-    // A single row exceeds maxBatchBytes - should flush after each row.
+    // A single row exceeding the byte limit should flush as a batch of 1
     try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createStringSchema();
-      StructType sparkSchema = createStringSparkSchema();
+      Field stringField =
+          new Field(
+              "data",
+              FieldType.nullable(org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()),
+              null);
+      Schema schema = new Schema(Collections.singletonList(stringField));
+      StructType sparkSchema =
+          new StructType(
+              new StructField[] {DataTypes.createStructField("data", DataTypes.StringType, true)});
 
       final int totalRows = 5;
       final int batchSize = 1000;
-      final long maxBatchBytes = 1024; // 1KB - each row will exceed this
-      final int rowSizeBytes = 10 * 1024; // 10KB per row
-      final int queueDepth = 4;
-      final QueuedArrowBatchWriteBuffer writeBuffer =
-          new QueuedArrowBatchWriteBuffer(
-              allocator, schema, sparkSchema, batchSize, queueDepth, maxBatchBytes);
+      // 1KB byte limit with ~10KB strings → 1 row per batch
+      final long maxBatchBytes = 1024;
+
+      final PooledArrowBatchWriteBuffer writeBuffer =
+          new PooledArrowBatchWriteBuffer(
+              allocator, schema, sparkSchema, batchSize, 4, maxBatchBytes);
+
+      StringBuilder sb = new StringBuilder();
+      for (int i = 0; i < 10 * 1024; i++) {
+        sb.append('y');
+      }
+      String largeString = sb.toString();
 
-      AtomicInteger rowsWritten = new AtomicInteger(0);
       AtomicInteger rowsRead = new AtomicInteger(0);
       AtomicInteger batchCount = new AtomicInteger(0);
-      AtomicReference<Throwable> writerError = new AtomicReference<>();
-      AtomicReference<Throwable> readerError = new AtomicReference<>();
 
       Thread writerThread =
           new Thread(
              () -> {
-                try {
-                  for (int i = 0; i < totalRows; i++) {
-                    UTF8String largeValue = generateLargeString(rowSizeBytes);
-                    InternalRow row = new GenericInternalRow(new Object[] {largeValue});
-                    writeBuffer.write(row);
-                    rowsWritten.incrementAndGet();
-                  }
-                } catch (Throwable e) {
-                  writerError.set(e);
-                  e.printStackTrace();
-                } finally {
-                  writeBuffer.setFinished();
+                for (int i = 0; i < totalRows; i++) {
+                  writeBuffer.write(
+                      new GenericInternalRow(new Object[] {UTF8String.fromString(largeString)}));
                 }
+                writeBuffer.setFinished();
              });
 
       Thread readerThread =
@@ -803,12 +702,12 @@ public void testByteBasedFlushSingleLargeRow() throws Exception {
              () -> {
                try {
                  while (writeBuffer.loadNextBatch()) {
+                    batchCount.incrementAndGet();
                     VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot();
                     rowsRead.addAndGet(root.getRowCount());
-                    batchCount.incrementAndGet();
+                    assertEquals(1, root.getRowCount(), "Each batch should contain exactly 1 row");
                  }
-                } catch (Throwable e) {
-                  readerError.set(e);
+                } catch (Exception e) {
                   e.printStackTrace();
                 }
              });
 
@@ -818,15 +717,8 @@
       writerThread.join();
       readerThread.join();
 
-      assertNull(writerError.get(), "Writer thread should not have errors");
-      assertNull(readerError.get(), "Reader thread should not have errors");
-      assertEquals(totalRows, rowsWritten.get());
       assertEquals(totalRows, rowsRead.get());
-      // Each row should produce its own batch since each exceeds the byte limit
-      assertEquals(
-          totalRows,
-          batchCount.get(),
-          "Each row should be its own batch when row size exceeds maxBatchBytes");
+      assertEquals(totalRows, batchCount.get()); // 1 row per batch
 
       writeBuffer.close();
     }
   }
diff --git a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java b/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java
deleted file mode 100644
index cf23c249..00000000
--- a/lance-spark-base_2.12/src/test/java/org/lance/spark/write/SemaphoreArrowBatchWriteBufferTest.java
+++ /dev/null
@@ -1,460 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.lance.spark.write;
-
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.memory.RootAllocator;
-import org.apache.arrow.vector.VectorSchemaRoot;
-import org.apache.arrow.vector.VectorUnloader;
-import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
-import org.apache.arrow.vector.types.pojo.Field;
-import org.apache.arrow.vector.types.pojo.FieldType;
-import org.apache.arrow.vector.types.pojo.Schema;
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-import org.apache.spark.unsafe.types.UTF8String;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.Test;
-
-import java.util.Collections;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.FutureTask;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicLong;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-public class SemaphoreArrowBatchWriteBufferTest {
-
-  private Schema createIntSchema() {
-    Field field =
-        new Field(
-            "column1",
-            FieldType.nullable(org.apache.arrow.vector.types.Types.MinorType.INT.getType()),
-            null);
-    return new Schema(Collections.singletonList(field));
-  }
-
-  private StructType createIntSparkSchema() {
-    return new StructType(
-        new StructField[] {DataTypes.createStructField("column1", DataTypes.IntegerType, true)});
-  }
-
-  private void runWriterReader(
-      SemaphoreArrowBatchWriteBuffer writeBuffer,
-      int totalRows,
-      AtomicInteger rowsWritten,
-      AtomicInteger rowsRead)
-      throws Exception {
-    Thread writerThread =
-        new Thread(
-            () -> {
-              for (int i = 0; i < totalRows; i++) {
-                InternalRow row =
-                    new GenericInternalRow(new Object[] {rowsWritten.incrementAndGet()});
-                writeBuffer.write(row);
-              }
-              writeBuffer.setFinished();
-            });
-
-    Callable<Void> readerCallable =
-        () -> {
-          while (writeBuffer.loadNextBatch()) {
-            VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot();
-            int rowCount = root.getRowCount();
-            int baseValue = rowsRead.get();
-            rowsRead.addAndGet(rowCount);
-            for (int i = 0; i < rowCount; i++) {
-              int value = (int) root.getVector("column1").getObject(i);
-              assertEquals(baseValue + i + 1, value);
-            }
-          }
-          return null;
-        };
-
-    FutureTask<Void> readerTask = writeBuffer.createTrackedTask(readerCallable);
-
-    Thread readerThread = new Thread(readerTask);
-    writerThread.start();
-    readerThread.start();
-    writerThread.join();
-    readerThread.join();
-  }
-
-  @Test
-  public void test() throws Exception {
-    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createIntSchema();
-      StructType sparkSchema = createIntSparkSchema();
-
-      final int totalRows = 125;
-      final int batchSize = 34;
-      final SemaphoreArrowBatchWriteBuffer writeBuffer =
-          new SemaphoreArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize);
-
-      AtomicInteger rowsWritten = new AtomicInteger(0);
-      AtomicInteger rowsRead = new AtomicInteger(0);
-
-      runWriterReader(writeBuffer, totalRows, rowsWritten, rowsRead);
-
-      try {
-        assertEquals(totalRows, rowsWritten.get());
-        assertEquals(totalRows, rowsRead.get());
-      } finally {
-        writeBuffer.close();
-      }
-    }
-  }
-
-  @Test
-  public void testEmptyWrite() throws Exception {
-    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createIntSchema();
-      StructType sparkSchema = createIntSparkSchema();
-
-      final SemaphoreArrowBatchWriteBuffer writeBuffer =
-          new SemaphoreArrowBatchWriteBuffer(allocator, schema, sparkSchema, 34);
-
-      AtomicInteger batchCount = new AtomicInteger(0);
-
-      Thread writerThread = new Thread(writeBuffer::setFinished);
-
-      Callable<Void> readerCallable =
-          () -> {
-            while (writeBuffer.loadNextBatch()) {
-              batchCount.incrementAndGet();
-            }
-            return null;
-          };
-
-      FutureTask<Void> readerTask = writeBuffer.createTrackedTask(readerCallable);
-
-      Thread readerThread = new Thread(readerTask);
-      writerThread.start();
-      readerThread.start();
-      writerThread.join();
-      readerThread.join();
-
-      assertEquals(0, batchCount.get());
-      writeBuffer.close();
-    }
-  }
-
-  @Test
-  public void testExactBatchBoundary() throws Exception {
-    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createIntSchema();
-      StructType sparkSchema = createIntSparkSchema();
-
-      final int batchSize = 25;
-      final int totalRows = batchSize * 4; // exactly 100 rows = 4 full batches
-      final SemaphoreArrowBatchWriteBuffer writeBuffer =
-          new SemaphoreArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize);
-
-      AtomicInteger rowsWritten = new AtomicInteger(0);
-      AtomicInteger rowsRead = new AtomicInteger(0);
-
-      runWriterReader(writeBuffer, totalRows, rowsWritten, rowsRead);
-
-      try {
-        assertEquals(totalRows, rowsWritten.get());
-        assertEquals(totalRows, rowsRead.get());
-      } finally {
-        writeBuffer.close();
-      }
-    }
-  }
-
-  @Test
-  public void testSingleRowBatch() throws Exception {
-    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createIntSchema();
-      StructType sparkSchema = createIntSparkSchema();
-
-      final int totalRows = 50;
-      final int batchSize = 1;
-      final SemaphoreArrowBatchWriteBuffer writeBuffer =
-          new SemaphoreArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize);
-
-      AtomicInteger rowsWritten = new AtomicInteger(0);
-      AtomicInteger rowsRead = new AtomicInteger(0);
-
-      runWriterReader(writeBuffer, totalRows, rowsWritten, rowsRead);
-
-      try {
-        assertEquals(totalRows, rowsWritten.get());
-        assertEquals(totalRows, rowsRead.get());
-      } finally {
-        writeBuffer.close();
-      }
-    }
-  }
-
-  @Test
-  public void testWriteErrorPropagation() throws Exception {
-    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createIntSchema();
-      StructType sparkSchema = createIntSparkSchema();
-
-      final int totalRows = 125;
-      final int batchSize = 34;
-      final SemaphoreArrowBatchWriteBuffer writeBuffer =
-          new SemaphoreArrowBatchWriteBuffer(allocator, schema, sparkSchema, batchSize);
-
-      AtomicInteger rowsWritten = new AtomicInteger(0);
-      AtomicInteger rowsRead = new AtomicInteger(0);
-      AtomicLong expectedBytesRead = new AtomicLong(0);
-
-      Callable<Integer> read =
-          () -> {
-            if (writeBuffer.loadNextBatch()) {
-              VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot();
-              int rowCount = root.getRowCount();
-              rowsRead.addAndGet(rowCount);
-              try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) {
-                expectedBytesRead.addAndGet(recordBatch.computeBodyLength());
-              }
-              for (int i = 0; i < rowCount; i++) {
-                int value = (int) root.getVector("column1").getObject(i);
-                assertEquals(value, rowsRead.get() - rowCount + i + 1);
-              }
-
-              // Throw a mock exception after reading a batch
-              throw new RuntimeException("Mock exception");
-            }
-            return rowsRead.get();
-          };
-
-      // Start background thread to read data
-      FutureTask<Integer> readTask = writeBuffer.createTrackedTask(read);
-      Thread readerThread = new Thread(readTask);
-      readerThread.start();
-
-      // Write data
-      Assertions.assertThrows(
-          RuntimeException.class,
-          () -> {
-            for (int i = 0; i < totalRows; i++) {
-              InternalRow row = new GenericInternalRow(new Object[] {i + 1});
-              writeBuffer.write(row);
-              rowsWritten.incrementAndGet();
-            }
-            writeBuffer.setFinished();
-          });
-
-      Assertions.assertThrows(ExecutionException.class, readTask::get);
-
-      assertEquals(batchSize, rowsWritten.get());
-      assertEquals(batchSize, rowsRead.get());
-      writeBuffer.close();
-    }
-  }
-
-  // ========== Byte-based flush tests ==========
-
-  private Schema createStringSchema() {
-    Field field =
-        new Field(
-            "data",
-            FieldType.nullable(org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()),
-            null);
-    return new Schema(Collections.singletonList(field));
-  }
-
-  private StructType createStringSparkSchema() {
-    return new StructType(
-        new StructField[] {DataTypes.createStructField("data", DataTypes.StringType, true)});
-  }
-
-  /** Generate a string of approximately the given size in bytes. */
-  private UTF8String generateLargeString(int sizeBytes) {
-    byte[] data = new byte[sizeBytes];
-    java.util.Arrays.fill(data, (byte) 'A');
-    return UTF8String.fromBytes(data);
-  }
-
-  @Test
-  public void testByteBasedFlushWithSmallRows() throws Exception {
-    // With small rows, the row count limit should be reached before the byte limit.
-    // This verifies that maxBatchBytes does not interfere with normal row-count flushing.
-    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createIntSchema();
-      StructType sparkSchema = createIntSparkSchema();
-
-      final int totalRows = 100;
-      final int batchSize = 25;
-      final long maxBatchBytes = 100 * 1024 * 1024; // 100MB - should never be reached
-      final SemaphoreArrowBatchWriteBuffer writeBuffer =
-          new SemaphoreArrowBatchWriteBuffer(
-              allocator, schema, sparkSchema, batchSize, maxBatchBytes);
-
-      AtomicInteger rowsWritten = new AtomicInteger(0);
-      AtomicInteger rowsRead = new AtomicInteger(0);
-
-      runWriterReader(writeBuffer, totalRows, rowsWritten, rowsRead);
-
-      try {
-        assertEquals(totalRows, rowsWritten.get());
-        assertEquals(totalRows, rowsRead.get());
-      } finally {
-        writeBuffer.close();
-      }
-    }
-  }
-
-  @Test
-  public void testByteBasedFlush() throws Exception {
-    // Each row is ~100KB. With batchSize=1000 (would be 100MB), but maxBatchBytes=256KB,
-    // we should see batches flush after ~2-3 rows instead of waiting for 1000.
-    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createStringSchema();
-      StructType sparkSchema = createStringSparkSchema();
-
-      final int totalRows = 20;
-      final int batchSize = 1000; // High row limit - should never be reached
-      final long maxBatchBytes = 256 * 1024; // 256KB - should trigger flush after ~2 rows
-      final int rowSizeBytes = 100 * 1024; // ~100KB per row
-      final SemaphoreArrowBatchWriteBuffer writeBuffer =
-          new SemaphoreArrowBatchWriteBuffer(
-              allocator, schema, sparkSchema, batchSize, maxBatchBytes);
-
-      AtomicInteger rowsWritten = new AtomicInteger(0);
-      AtomicInteger rowsRead = new AtomicInteger(0);
-      AtomicInteger batchCount = new AtomicInteger(0);
-      AtomicInteger maxRowsInBatch = new AtomicInteger(0);
-
-      Thread writerThread =
-          new Thread(
-              () -> {
-                try {
-                  for (int i = 0; i < totalRows; i++) {
-                    UTF8String largeValue = generateLargeString(rowSizeBytes);
-                    InternalRow row = new GenericInternalRow(new Object[] {largeValue});
-                    writeBuffer.write(row);
-                    rowsWritten.incrementAndGet();
-                  }
-                } finally {
-                  writeBuffer.setFinished();
-                }
-              });
-
-      Callable<Void> readerCallable =
-          () -> {
-            while (writeBuffer.loadNextBatch()) {
-              VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot();
-              int rowCount = root.getRowCount();
-              rowsRead.addAndGet(rowCount);
-              batchCount.incrementAndGet();
-              maxRowsInBatch.updateAndGet(prev -> Math.max(prev, rowCount));
-            }
-            return null;
-          };
-
-      FutureTask<Void> readerTask = writeBuffer.createTrackedTask(readerCallable);
-      Thread readerThread = new Thread(readerTask);
-      writerThread.start();
-      readerThread.start();
-      writerThread.join();
-      readerThread.join();
-
-      try {
-        assertEquals(totalRows, rowsWritten.get());
-        assertEquals(totalRows, rowsRead.get());
-        // With 100KB rows and 256KB limit, each batch should have at most ~3 rows.
-        // Without byte-based flush, we'd get 1 batch of 20 rows (since batchSize=1000).
-        Assertions.assertTrue(
-            batchCount.get() > 1,
-            "Should have multiple batches due to byte-based flushing, but got " + batchCount.get());
-        Assertions.assertTrue(
-            maxRowsInBatch.get() < batchSize,
-            "Max rows per batch ("
-                + maxRowsInBatch.get()
-                + ") should be less than batchSize ("
-                + batchSize
-                + ") due to byte-based flushing");
-      } finally {
-        writeBuffer.close();
-      }
-    }
-  }
-
-  @Test
-  public void testByteBasedFlushSingleLargeRow() throws Exception {
-    // A single row exceeds maxBatchBytes - should flush after each row.
-    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
-      Schema schema = createStringSchema();
-      StructType sparkSchema = createStringSparkSchema();
-
-      final int totalRows = 5;
-      final int batchSize = 1000;
-      final long maxBatchBytes = 1024; // 1KB - each row will exceed this
-      final int rowSizeBytes = 10 * 1024; // 10KB per row
-      final SemaphoreArrowBatchWriteBuffer writeBuffer =
-          new SemaphoreArrowBatchWriteBuffer(
-              allocator, schema, sparkSchema, batchSize, maxBatchBytes);
-
-      AtomicInteger rowsWritten = new AtomicInteger(0);
-      AtomicInteger rowsRead = new AtomicInteger(0);
-      AtomicInteger batchCount = new AtomicInteger(0);
-
-      Thread writerThread =
-          new Thread(
-              () -> {
-                try {
-                  for (int i = 0; i < totalRows; i++) {
-                    UTF8String largeValue = generateLargeString(rowSizeBytes);
-                    InternalRow row = new GenericInternalRow(new Object[] {largeValue});
-                    writeBuffer.write(row);
-                    rowsWritten.incrementAndGet();
-                  }
-                } finally {
-                  writeBuffer.setFinished();
-                }
-              });
-
-      Callable<Void> readerCallable =
-          () -> {
-            while (writeBuffer.loadNextBatch()) {
-              VectorSchemaRoot root = writeBuffer.getVectorSchemaRoot();
-              rowsRead.addAndGet(root.getRowCount());
-              batchCount.incrementAndGet();
-            }
-            return null;
-          };
-
-      FutureTask<Void> readerTask = writeBuffer.createTrackedTask(readerCallable);
-      Thread readerThread = new Thread(readerTask);
-      writerThread.start();
-      readerThread.start();
-      writerThread.join();
-      readerThread.join();
-
-      try {
-        assertEquals(totalRows, rowsWritten.get());
-        assertEquals(totalRows, rowsRead.get());
-        // Each row should produce its own batch since each exceeds the byte limit
-        assertEquals(
-            totalRows,
-            batchCount.get(),
-            "Each row should be its own batch when row size exceeds maxBatchBytes");
-      } finally {
-        writeBuffer.close();
-      }
-    }
-  }
-}
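
For readers skimming the patch, the backpressure and flush contract the new tests exercise is easiest to see in isolation. What follows is a minimal, self-contained sketch of the same bounded producer/consumer pattern, using a plain java.util.concurrent.ArrayBlockingQueue in place of the real PooledArrowBatchWriteBuffer. The PooledBufferSketch class, its POISON end-of-stream marker, and the 100-byte rows are illustrative stand-ins and not part of the patch: the writer flushes a batch when either the row count reaches batchSize or the accumulated bytes reach maxBatchBytes (mirroring the row-count and byte-based flush tests above), and put() blocks once poolSize batches are in flight.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

// Illustrative stand-in for the pooled write buffer; not part of the patch.
public final class PooledBufferSketch {
  // End-of-stream marker, playing the role of setFinished().
  private static final List<byte[]> POISON = new ArrayList<>();

  public static void main(String[] args) throws Exception {
    final int batchSize = 10;        // row-count flush threshold
    final long maxBatchBytes = 1024; // byte-based flush threshold
    final int poolSize = 4;          // max batches in flight (queue depth)
    final int totalRows = 30;

    BlockingQueue<List<byte[]>> pool = new ArrayBlockingQueue<>(poolSize);

    Thread writer =
        new Thread(
            () -> {
              try {
                List<byte[]> batch = new ArrayList<>();
                long batchBytes = 0;
                for (int i = 0; i < totalRows; i++) {
                  byte[] row = new byte[100]; // stand-in for an encoded row
                  batch.add(row);
                  batchBytes += row.length;
                  // Flush on whichever limit is hit first, as the tests assert.
                  if (batch.size() >= batchSize || batchBytes >= maxBatchBytes) {
                    pool.put(batch); // blocks when poolSize batches are queued
                    batch = new ArrayList<>();
                    batchBytes = 0;
                  }
                }
                if (!batch.isEmpty()) {
                  pool.put(batch); // partial final batch
                }
                pool.put(POISON); // signal completion
              } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
              }
            });

    writer.start();
    int rowsRead = 0;
    int batches = 0;
    List<byte[]> batch;
    while ((batch = pool.take()) != POISON) { // analogue of loadNextBatch()
      batches++;
      rowsRead += batch.size();
    }
    writer.join();
    // With 100-byte rows, the row-count limit (10 rows = 1000 bytes) triggers
    // before the 1024-byte limit, so this prints: batches=3 rows=30
    System.out.println("batches=" + batches + " rows=" + rowsRead);
  }
}

Shrinking maxBatchBytes below 1000 in the sketch flips the trigger from the row-count branch to the byte branch, which is exactly the distinction the testByteBasedFlushWithSmallRows, testByteBasedFlushWithLargeStrings, and testByteBasedFlushSingleLargeRow cases pin down against the real buffer.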