|
38 | 38 | import org.apache.arrow.memory.BufferAllocator; |
39 | 39 | import org.apache.arrow.memory.RootAllocator; |
40 | 40 | import org.apache.arrow.util.AutoCloseables; |
| 41 | +import org.apache.arrow.vector.FieldVector; |
41 | 42 | import org.apache.arrow.vector.IntVector; |
| 43 | +import org.apache.arrow.vector.TimeStampMilliVector; |
42 | 44 | import org.apache.arrow.vector.VarBinaryVector; |
43 | 45 | import org.apache.arrow.vector.VarCharVector; |
44 | 46 | import org.apache.arrow.vector.VectorSchemaRoot; |
|
53 | 55 | import org.apache.arrow.vector.ipc.ArrowStreamWriter; |
54 | 56 | import org.apache.arrow.vector.ipc.message.ArrowFieldNode; |
55 | 57 | import org.apache.arrow.vector.ipc.message.IpcOption; |
| 58 | +import org.apache.arrow.vector.types.TimeUnit; |
56 | 59 | import org.apache.arrow.vector.types.pojo.ArrowType; |
57 | 60 | import org.apache.arrow.vector.types.pojo.Field; |
| 61 | +import org.apache.arrow.vector.types.pojo.FieldType; |
58 | 62 | import org.apache.arrow.vector.types.pojo.Schema; |
59 | 63 | import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; |
60 | 64 | import org.junit.jupiter.api.AfterEach; |
61 | 65 | import org.junit.jupiter.api.BeforeEach; |
| 66 | +import org.junit.jupiter.api.Test; |
62 | 67 | import org.junit.jupiter.params.ParameterizedTest; |
63 | 68 | import org.junit.jupiter.params.provider.Arguments; |
64 | 69 | import org.junit.jupiter.params.provider.MethodSource; |
@@ -347,6 +352,253 @@ void testUnloadCompressed(CompressionUtil.CodecType codec) { |
347 | 352 | }); |
348 | 353 | } |
349 | 354 |
|
  /**
   * Test multi-batch streaming with ZSTD compression, wide schema, VectorSchemaRoot reuse, and
   * all-null columns. This reproduces the scenario from GH-1116 where the 8-byte
   * uncompressed-length prefix of a compressed buffer could be incorrectly written as 0.
   *
   * <p>The write side and the verify side intentionally use the same modular patterns
   * ({@code batch % 3} for all-null columns, {@code row % 5 / 7 / 9} for scattered nulls) so the
   * reader can recompute the expected value/null layout for every cell without storing it. Keep
   * the two sides in lockstep if either pattern is changed.
   */
  @Test
  void testMultiBatchZstdStreamWithWideSchemaAndAllNulls() throws Exception {
    final int fieldCount = 100;
    final int batchCount = 10;
    final int rowsPerBatch = 500;

    // Build a wide schema: mix of int, timestamp, and varchar fields.
    // Column type is determined solely by (index % 3) — the verify loop relies on this.
    List<Field> fields = new ArrayList<>();
    for (int i = 0; i < fieldCount; i++) {
      switch (i % 3) {
        case 0:
          fields.add(Field.nullable("int_" + i, new ArrowType.Int(32, true)));
          break;
        case 1:
          fields.add(
              Field.nullable("ts_" + i, new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)));
          break;
        case 2:
          fields.add(Field.nullable("str_" + i, ArrowType.Utf8.INSTANCE));
          break;
        default:
          // unreachable: i % 3 is always 0, 1, or 2
          break;
      }
    }
    Schema schema = new Schema(fields);

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
        ArrowStreamWriter writer =
            new ArrowStreamWriter(
                root,
                new DictionaryProvider.MapDictionaryProvider(),
                Channels.newChannel(out),
                IpcOption.DEFAULT,
                CommonsCompressionFactory.INSTANCE,
                CompressionUtil.CodecType.ZSTD)) {
      writer.start();

      for (int batch = 0; batch < batchCount; batch++) {
        // Clear and reallocate — mimics the reporter's reuse pattern (the root is reused
        // across batches rather than recreated, which is part of the GH-1116 scenario).
        root.clear();
        for (FieldVector vector : root.getFieldVectors()) {
          vector.allocateNew();
        }
        root.setRowCount(rowsPerBatch);

        for (int col = 0; col < fieldCount; col++) {
          FieldVector vector = root.getVector(col);
          // Make some batches have all-null columns for certain fields:
          // every timestamp column (col % 3 == 1) in every 3rd batch is entirely null.
          boolean allNull = (batch % 3 == 0) && (col % 3 == 1); // timestamps in every 3rd batch
          switch (col % 3) {
            case 0:
              {
                IntVector iv = (IntVector) vector;
                for (int row = 0; row < rowsPerBatch; row++) {
                  if (allNull || row % 7 == 0) {
                    iv.setNull(row);
                  } else {
                    // Value encodes (batch, row) so the reader can recompute it.
                    iv.setSafe(row, batch * rowsPerBatch + row);
                  }
                }
                break;
              }
            case 1:
              {
                TimeStampMilliVector tv = (TimeStampMilliVector) vector;
                for (int row = 0; row < rowsPerBatch; row++) {
                  if (allNull || row % 5 == 0) {
                    tv.setNull(row);
                  } else {
                    tv.setSafe(row, 1_700_000_000_000L + (long) batch * rowsPerBatch + row);
                  }
                }
                break;
              }
            case 2:
              {
                VarCharVector sv = (VarCharVector) vector;
                for (int row = 0; row < rowsPerBatch; row++) {
                  if (allNull || row % 9 == 0) {
                    sv.setNull(row);
                  } else {
                    sv.setSafe(row, ("val_" + batch + "_" + row).getBytes(StandardCharsets.UTF_8));
                  }
                }
                break;
              }
            default:
              // unreachable: col % 3 is always 0, 1, or 2
              break;
          }
          vector.setValueCount(rowsPerBatch);
        }

        writer.writeBatch();
      }
      writer.end();
    }

    // Read back and verify all batches round-trip correctly.
    try (ArrowStreamReader reader =
        new ArrowStreamReader(
            new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
            allocator,
            CommonsCompressionFactory.INSTANCE)) {
      int batchesRead = 0;
      while (reader.loadNextBatch()) {
        VectorSchemaRoot readRoot = reader.getVectorSchemaRoot();
        assertEquals(rowsPerBatch, readRoot.getRowCount());
        assertEquals(fieldCount, readRoot.getFieldVectors().size());

        // Verify data values, null patterns, and all-null columns.
        // batchesRead plays the role of `batch` from the write loop, so the same
        // modular expressions reproduce the expected layout.
        for (int col = 0; col < fieldCount; col++) {
          FieldVector vector = readRoot.getVector(col);
          boolean allNull =
              (batchesRead % 3 == 0) && (col % 3 == 1); // timestamps in every 3rd batch
          if (allNull) {
            // The key scenario: all-null columns must survive compression round-trip
            assertEquals(
                rowsPerBatch,
                vector.getNullCount(),
                "All-null column col=" + col + " batch=" + batchesRead);
          }
          for (int row = 0; row < rowsPerBatch; row++) {
            switch (col % 3) {
              case 0:
                {
                  IntVector iv = (IntVector) vector;
                  if (allNull || row % 7 == 0) {
                    assertTrue(
                        iv.isNull(row),
                        "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
                  } else {
                    assertEquals(
                        batchesRead * rowsPerBatch + row,
                        iv.get(row),
                        "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
                  }
                  break;
                }
              case 1:
                {
                  TimeStampMilliVector tv = (TimeStampMilliVector) vector;
                  if (allNull || row % 5 == 0) {
                    assertTrue(
                        tv.isNull(row),
                        "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
                  } else {
                    assertEquals(
                        1_700_000_000_000L + (long) batchesRead * rowsPerBatch + row,
                        tv.get(row),
                        "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
                  }
                  break;
                }
              case 2:
                {
                  VarCharVector sv = (VarCharVector) vector;
                  if (allNull || row % 9 == 0) {
                    assertTrue(
                        sv.isNull(row),
                        "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
                  } else {
                    assertArrayEquals(
                        ("val_" + batchesRead + "_" + row).getBytes(StandardCharsets.UTF_8),
                        sv.get(row),
                        "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
                  }
                  break;
                }
              default:
                // unreachable: col % 3 is always 0, 1, or 2
                break;
            }
          }
        }
        batchesRead++;
      }
      assertEquals(batchCount, batchesRead);
    }
  }
| 539 | + |
| 540 | + /** |
| 541 | + * Test that an all-null fixed-width vector compresses and decompresses correctly. The data buffer |
| 542 | + * for such a vector contains all zeros but has a non-zero writerIndex (valueCount * typeWidth). |
| 543 | + * The compressed buffer's uncompressed-length prefix must reflect this non-zero size. |
| 544 | + */ |
| 545 | + @Test |
| 546 | + void testAllNullFixedWidthVectorZstdRoundTrip() throws Exception { |
| 547 | + final int rowCount = 3469; // same count as the reported issue |
| 548 | + final CompressionCodec codec = new ZstdCompressionCodec(); |
| 549 | + |
| 550 | + try (TimeStampMilliVector origVec = |
| 551 | + new TimeStampMilliVector( |
| 552 | + "ts", |
| 553 | + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)), |
| 554 | + allocator)) { |
| 555 | + origVec.allocateNew(rowCount); |
| 556 | + // Set all values to null |
| 557 | + for (int i = 0; i < rowCount; i++) { |
| 558 | + origVec.setNull(i); |
| 559 | + } |
| 560 | + origVec.setValueCount(rowCount); |
| 561 | + |
| 562 | + assertEquals(rowCount, origVec.getNullCount()); |
| 563 | + |
| 564 | + // Compress and decompress each buffer |
| 565 | + List<ArrowBuf> origBuffers = origVec.getFieldBuffers(); |
| 566 | + assertEquals(2, origBuffers.size()); |
| 567 | + |
| 568 | + // The data buffer (index 1) should have non-zero writerIndex even though all values are null |
| 569 | + ArrowBuf dataBuffer = origBuffers.get(1); |
| 570 | + long expectedDataSize = (long) rowCount * 8; // TimestampMilli = 8 bytes per value |
| 571 | + assertEquals(expectedDataSize, dataBuffer.writerIndex()); |
| 572 | + |
| 573 | + // Retain buffers before compressing since compress() closes the input buffer. |
| 574 | + // This mirrors what VectorUnloader.appendNodes() does. |
| 575 | + for (ArrowBuf buf : origBuffers) { |
| 576 | + buf.getReferenceManager().retain(); |
| 577 | + } |
| 578 | + List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers); |
| 579 | + List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers); |
| 580 | + |
| 581 | + assertEquals(2, decompressedBuffers.size()); |
| 582 | + |
| 583 | + // The decompressed data buffer should have the same writerIndex as the original |
| 584 | + assertEquals(expectedDataSize, decompressedBuffers.get(1).writerIndex()); |
| 585 | + |
| 586 | + // Load into a new vector and verify |
| 587 | + try (TimeStampMilliVector newVec = |
| 588 | + new TimeStampMilliVector( |
| 589 | + "ts_new", |
| 590 | + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)), |
| 591 | + allocator)) { |
| 592 | + newVec.loadFieldBuffers(new ArrowFieldNode(rowCount, rowCount), decompressedBuffers); |
| 593 | + assertEquals(rowCount, newVec.getValueCount()); |
| 594 | + for (int i = 0; i < rowCount; i++) { |
| 595 | + assertTrue(newVec.isNull(i)); |
| 596 | + } |
| 597 | + } |
| 598 | + AutoCloseables.close(decompressedBuffers); |
| 599 | + } |
| 600 | + } |
| 601 | + |
350 | 602 | void withRoot( |
351 | 603 | CompressionUtil.CodecType codec, |
352 | 604 | BiConsumer<CompressionCodec.Factory, VectorSchemaRoot> testBody) { |
|
0 commit comments