Skip to content

Commit 4c754b1

Browse files
committed
fix
1 parent 0d55ba7 commit 4c754b1

3 files changed

Lines changed: 261 additions & 3 deletions

File tree

compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ protected ArrowBuf doCompress(BufferAllocator allocator, ArrowBuf uncompressedBu
4444
long bytesWritten =
4545
Zstd.compressUnsafe(
4646
compressedBuffer.memoryAddress() + CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH,
47-
dstSize,
47+
maxSize,
4848
/*src*/ uncompressedBuffer.memoryAddress(),
4949
/* srcSize= */ uncompressedBuffer.writerIndex(),
5050
/* level= */ this.compressionLevel);

compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838
import org.apache.arrow.memory.BufferAllocator;
3939
import org.apache.arrow.memory.RootAllocator;
4040
import org.apache.arrow.util.AutoCloseables;
41+
import org.apache.arrow.vector.FieldVector;
4142
import org.apache.arrow.vector.IntVector;
43+
import org.apache.arrow.vector.TimeStampMilliVector;
4244
import org.apache.arrow.vector.VarBinaryVector;
4345
import org.apache.arrow.vector.VarCharVector;
4446
import org.apache.arrow.vector.VectorSchemaRoot;
@@ -53,12 +55,15 @@
5355
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
5456
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
5557
import org.apache.arrow.vector.ipc.message.IpcOption;
58+
import org.apache.arrow.vector.types.TimeUnit;
5659
import org.apache.arrow.vector.types.pojo.ArrowType;
5760
import org.apache.arrow.vector.types.pojo.Field;
61+
import org.apache.arrow.vector.types.pojo.FieldType;
5862
import org.apache.arrow.vector.types.pojo.Schema;
5963
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
6064
import org.junit.jupiter.api.AfterEach;
6165
import org.junit.jupiter.api.BeforeEach;
66+
import org.junit.jupiter.api.Test;
6267
import org.junit.jupiter.params.ParameterizedTest;
6368
import org.junit.jupiter.params.provider.Arguments;
6469
import org.junit.jupiter.params.provider.MethodSource;
@@ -347,6 +352,253 @@ void testUnloadCompressed(CompressionUtil.CodecType codec) {
347352
});
348353
}
349354

355+
/**
 * Test multi-batch streaming with ZSTD compression, wide schema, VectorSchemaRoot reuse, and
 * all-null columns. This reproduces the scenario from GH-1116 where the 8-byte
 * uncompressed-length prefix of a compressed buffer could be incorrectly written as 0.
 *
 * <p>Writes {@code batchCount} batches of {@code fieldCount} columns (cycling int / timestamp /
 * varchar by {@code col % 3}) through an {@link ArrowStreamWriter} configured with ZSTD, then
 * reads the stream back and verifies every value and null pattern round-trips intact.
 */
@Test
void testMultiBatchZstdStreamWithWideSchemaAndAllNulls() throws Exception {
  final int fieldCount = 100;
  final int batchCount = 10;
  final int rowsPerBatch = 500;

  // Build a wide schema: mix of int, timestamp, and varchar fields
  List<Field> fields = new ArrayList<>();
  for (int i = 0; i < fieldCount; i++) {
    switch (i % 3) {
      case 0:
        fields.add(Field.nullable("int_" + i, new ArrowType.Int(32, true)));
        break;
      case 1:
        fields.add(
            Field.nullable("ts_" + i, new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)));
        break;
      case 2:
        fields.add(Field.nullable("str_" + i, ArrowType.Utf8.INSTANCE));
        break;
      default:
        break;
    }
  }
  Schema schema = new Schema(fields);

  // NOTE(review): relies on the test class's shared 'allocator' field, declared outside this
  // method's view.
  ByteArrayOutputStream out = new ByteArrayOutputStream();
  try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
      ArrowStreamWriter writer =
          new ArrowStreamWriter(
              root,
              new DictionaryProvider.MapDictionaryProvider(),
              Channels.newChannel(out),
              IpcOption.DEFAULT,
              CommonsCompressionFactory.INSTANCE,
              CompressionUtil.CodecType.ZSTD)) {
    writer.start();

    for (int batch = 0; batch < batchCount; batch++) {
      // Clear and reallocate — mimics the reporter's reuse pattern
      root.clear();
      for (FieldVector vector : root.getFieldVectors()) {
        vector.allocateNew();
      }
      root.setRowCount(rowsPerBatch);

      for (int col = 0; col < fieldCount; col++) {
        FieldVector vector = root.getVector(col);
        // Make some batches have all-null columns for certain fields
        boolean allNull = (batch % 3 == 0) && (col % 3 == 1); // timestamps in every 3rd batch
        switch (col % 3) {
          case 0:
            {
              // int columns: null every 7th row (or all rows when allNull)
              IntVector iv = (IntVector) vector;
              for (int row = 0; row < rowsPerBatch; row++) {
                if (allNull || row % 7 == 0) {
                  iv.setNull(row);
                } else {
                  iv.setSafe(row, batch * rowsPerBatch + row);
                }
              }
              break;
            }
          case 1:
            {
              // timestamp columns: null every 5th row (or all rows when allNull)
              TimeStampMilliVector tv = (TimeStampMilliVector) vector;
              for (int row = 0; row < rowsPerBatch; row++) {
                if (allNull || row % 5 == 0) {
                  tv.setNull(row);
                } else {
                  tv.setSafe(row, 1_700_000_000_000L + (long) batch * rowsPerBatch + row);
                }
              }
              break;
            }
          case 2:
            {
              // varchar columns: null every 9th row (or all rows when allNull)
              VarCharVector sv = (VarCharVector) vector;
              for (int row = 0; row < rowsPerBatch; row++) {
                if (allNull || row % 9 == 0) {
                  sv.setNull(row);
                } else {
                  sv.setSafe(row, ("val_" + batch + "_" + row).getBytes(StandardCharsets.UTF_8));
                }
              }
              break;
            }
          default:
            break;
        }
        vector.setValueCount(rowsPerBatch);
      }

      writer.writeBatch();
    }
    writer.end();
  }

  // Read back and verify all batches round-trip correctly
  try (ArrowStreamReader reader =
      new ArrowStreamReader(
          new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
          allocator,
          CommonsCompressionFactory.INSTANCE)) {
    int batchesRead = 0;
    while (reader.loadNextBatch()) {
      VectorSchemaRoot readRoot = reader.getVectorSchemaRoot();
      assertEquals(rowsPerBatch, readRoot.getRowCount());
      assertEquals(fieldCount, readRoot.getFieldVectors().size());

      // Verify data values, null patterns, and all-null columns.
      // 'batchesRead' plays the role of the writer's 'batch' index, so the allNull
      // predicate below must mirror the one used when writing.
      for (int col = 0; col < fieldCount; col++) {
        FieldVector vector = readRoot.getVector(col);
        boolean allNull =
            (batchesRead % 3 == 0) && (col % 3 == 1); // timestamps in every 3rd batch
        if (allNull) {
          // The key scenario: all-null columns must survive compression round-trip
          assertEquals(
              rowsPerBatch,
              vector.getNullCount(),
              "All-null column col=" + col + " batch=" + batchesRead);
        }
        for (int row = 0; row < rowsPerBatch; row++) {
          switch (col % 3) {
            case 0:
              {
                IntVector iv = (IntVector) vector;
                if (allNull || row % 7 == 0) {
                  assertTrue(
                      iv.isNull(row),
                      "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
                } else {
                  assertEquals(
                      batchesRead * rowsPerBatch + row,
                      iv.get(row),
                      "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
                }
                break;
              }
            case 1:
              {
                TimeStampMilliVector tv = (TimeStampMilliVector) vector;
                if (allNull || row % 5 == 0) {
                  assertTrue(
                      tv.isNull(row),
                      "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
                } else {
                  assertEquals(
                      1_700_000_000_000L + (long) batchesRead * rowsPerBatch + row,
                      tv.get(row),
                      "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
                }
                break;
              }
            case 2:
              {
                VarCharVector sv = (VarCharVector) vector;
                if (allNull || row % 9 == 0) {
                  assertTrue(
                      sv.isNull(row),
                      "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
                } else {
                  assertArrayEquals(
                      ("val_" + batchesRead + "_" + row).getBytes(StandardCharsets.UTF_8),
                      sv.get(row),
                      "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
                }
                break;
              }
            default:
              break;
          }
        }
      }
      batchesRead++;
    }
    assertEquals(batchCount, batchesRead);
  }
}
539+
540+
/**
 * Test that an all-null fixed-width vector compresses and decompresses correctly. The data buffer
 * for such a vector contains all zeros but has a non-zero writerIndex (valueCount * typeWidth).
 * The compressed buffer's uncompressed-length prefix must reflect this non-zero size.
 *
 * <p>Exercises the codec directly (no IPC stream): compresses the vector's validity and data
 * buffers, decompresses them, and loads them into a fresh vector to confirm the null pattern and
 * buffer sizes survive the round trip.
 */
@Test
void testAllNullFixedWidthVectorZstdRoundTrip() throws Exception {
  final int rowCount = 3469; // same count as the reported issue
  final CompressionCodec codec = new ZstdCompressionCodec();

  // NOTE(review): relies on the test class's shared 'allocator' field and on the
  // compressBuffers/deCompressBuffers helpers defined elsewhere in this test class.
  try (TimeStampMilliVector origVec =
      new TimeStampMilliVector(
          "ts",
          FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
          allocator)) {
    origVec.allocateNew(rowCount);
    // Set all values to null
    for (int i = 0; i < rowCount; i++) {
      origVec.setNull(i);
    }
    origVec.setValueCount(rowCount);

    assertEquals(rowCount, origVec.getNullCount());

    // Compress and decompress each buffer (index 0: validity, index 1: data)
    List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
    assertEquals(2, origBuffers.size());

    // The data buffer (index 1) should have non-zero writerIndex even though all values are null
    ArrowBuf dataBuffer = origBuffers.get(1);
    long expectedDataSize = (long) rowCount * 8; // TimestampMilli = 8 bytes per value
    assertEquals(expectedDataSize, dataBuffer.writerIndex());

    // Retain buffers before compressing since compress() closes the input buffer.
    // This mirrors what VectorUnloader.appendNodes() does.
    for (ArrowBuf buf : origBuffers) {
      buf.getReferenceManager().retain();
    }
    List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
    List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);

    assertEquals(2, decompressedBuffers.size());

    // The decompressed data buffer should have the same writerIndex as the original —
    // this is exactly the invariant broken when the uncompressed-length prefix was written as 0.
    assertEquals(expectedDataSize, decompressedBuffers.get(1).writerIndex());

    // Load into a new vector and verify the all-null pattern survived
    try (TimeStampMilliVector newVec =
        new TimeStampMilliVector(
            "ts_new",
            FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
            allocator)) {
      newVec.loadFieldBuffers(new ArrowFieldNode(rowCount, rowCount), decompressedBuffers);
      assertEquals(rowCount, newVec.getValueCount());
      for (int i = 0; i < rowCount; i++) {
        assertTrue(newVec.isNull(i));
      }
    }
    AutoCloseables.close(decompressedBuffers);
  }
}
601+
350602
void withRoot(
351603
CompressionUtil.CodecType codec,
352604
BiConsumer<CompressionCodec.Factory, VectorSchemaRoot> testBody) {

vector/src/main/java/org/apache/arrow/vector/compression/AbstractCompressionCodec.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,14 @@ public abstract class AbstractCompressionCodec implements CompressionCodec {
2929

3030
@Override
3131
public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer) {
32-
if (uncompressedBuffer.writerIndex() == 0L) {
32+
// Capture the uncompressed length once upfront to avoid any inconsistency from
33+
// re-reading writerIndex() at different points. Since the uncompressedBuffer may be
34+
// a shared reference to a vector's internal buffer, reading writerIndex() only once
35+
// ensures the same value is used for the empty-buffer check, compression, size
36+
// comparison, and the 8-byte uncompressed-length prefix.
37+
long uncompressedLength = uncompressedBuffer.writerIndex();
38+
39+
if (uncompressedLength == 0L) {
3340
// shortcut for empty buffer
3441
ArrowBuf compressedBuffer = allocator.buffer(CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH);
3542
compressedBuffer.setLong(0, 0);
@@ -41,7 +48,6 @@ public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer)
4148
ArrowBuf compressedBuffer = doCompress(allocator, uncompressedBuffer);
4249
long compressedLength =
4350
compressedBuffer.writerIndex() - CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH;
44-
long uncompressedLength = uncompressedBuffer.writerIndex();
4551

4652
if (compressedLength > uncompressedLength) {
4753
// compressed buffer is larger, send the raw buffer

0 commit comments

Comments (0)