Skip to content

Commit e2478aa

Browse files
authored
Save children dtypes/length in ZstdBuffersMetadata (#8572)
Before this change, ZstdBuffers::deserialize passed its own dtype and length to every child when reconstructing them. In ZstdBuffers, children belong to the inner encoding and have their own dtypes and lengths, so this caused a deserialization issue. This change stores the children's dtypes and lengths in ZstdBuffersMetadata and reads them back during deserialization. Resolves: #8549 Signed-off-by: Mikhail Kot <mikhail@spiraldb.com>
1 parent 2a19323 commit e2478aa

2 files changed

Lines changed: 68 additions & 2 deletions

File tree

encodings/zstd/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
//! ```
2323
2424
pub use array::*;
25+
use vortex_array::dtype::proto::dtype as pb;
2526
#[cfg(feature = "unstable_encodings")]
2627
pub use zstd_buffers::*;
2728

@@ -73,4 +74,12 @@ pub struct ZstdBuffersMetadata {
7374
/// Alignment of each buffer in bytes (must be a power of two).
7475
#[prost(uint32, repeated, tag = "4")]
7576
pub buffer_alignments: Vec<u32>,
77+
/// DType of child arrays. Children belong to inner encodings, and their
78+
/// dtypes don't persist after serialization, so we need to retrieve them
79+
/// from metadata.
80+
#[prost(message, repeated, tag = "5")]
81+
pub child_dtypes: Vec<pb::DType>,
82+
/// Length of each child array, in the same order as `child_dtypes`.
83+
#[prost(uint64, repeated, tag = "6")]
84+
pub child_lens: Vec<u64>,
7685
}

encodings/zstd/src/zstd_buffers.rs

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,12 +419,21 @@ impl VTable for ZstdBuffers {
419419
array: ArrayView<'_, Self>,
420420
_session: &VortexSession,
421421
) -> VortexResult<Option<Vec<u8>>> {
422+
let children: Vec<&ArrayRef> = array.slots().iter().flatten().collect();
423+
let child_dtypes = children
424+
.iter()
425+
.map(|child| child.dtype().try_into())
426+
.collect::<VortexResult<Vec<_>>>()?;
427+
let child_lens = children.iter().map(|child| child.len() as u64).collect();
428+
422429
Ok(Some(
423430
ZstdBuffersMetadata {
424431
inner_encoding_id: array.inner_encoding_id.to_string(),
425432
inner_metadata: array.inner_metadata.clone(),
426433
uncompressed_sizes: array.uncompressed_sizes.clone(),
427434
buffer_alignments: array.buffer_alignments.clone(),
435+
child_dtypes,
436+
child_lens,
428437
}
429438
.encode_to_vec(),
430439
))
@@ -437,13 +446,23 @@ impl VTable for ZstdBuffers {
437446
metadata: &[u8],
438447
buffers: &[BufferHandle],
439448
children: &dyn ArrayChildren,
440-
_session: &VortexSession,
449+
session: &VortexSession,
441450
) -> VortexResult<ArrayParts<Self>> {
442451
let metadata = ZstdBuffersMetadata::decode(metadata)?;
443452
let compressed_buffers: Vec<BufferHandle> = buffers.to_vec();
444453

454+
// Children belong to inner encodings, and serialization doesn't
455+
// preserve their dtypes and lengths. Check that metadata carries one
456+
// dtype and one length per child so both can be recovered from it.
457+
vortex_ensure_eq!(metadata.child_dtypes.len(), children.len());
458+
vortex_ensure_eq!(metadata.child_lens.len(), children.len());
459+
445460
let slots: ArraySlots = (0..children.len())
446-
.map(|i| children.get(i, dtype, len).map(Some))
461+
.map(|i| {
462+
let child_dtype = DType::from_proto(&metadata.child_dtypes[i], session)?;
463+
let child_len = usize::try_from(metadata.child_lens[i])?;
464+
children.get(i, &child_dtype, child_len).map(Some)
465+
})
447466
.collect::<VortexResult<Vec<_>>>()?
448467
.into();
449468

@@ -506,6 +525,7 @@ impl ValidityVTable<ZstdBuffers> for ZstdBuffers {
506525
#[cfg(test)]
507526
mod tests {
508527
use rstest::rstest;
528+
use vortex_array::ArrayContext;
509529
use vortex_array::ArrayRef;
510530
use vortex_array::IntoArray;
511531
use vortex_array::VortexSessionExecute;
@@ -516,7 +536,12 @@ mod tests {
516536
use vortex_array::expr::stats::Precision;
517537
use vortex_array::expr::stats::Stat;
518538
use vortex_array::expr::stats::StatsProvider;
539+
use vortex_array::serde::SerializeOptions;
540+
use vortex_array::serde::SerializedArray;
541+
use vortex_array::session::ArraySessionExt;
542+
use vortex_buffer::ByteBufferMut;
519543
use vortex_error::VortexResult;
544+
use vortex_session::registry::ReadContext;
520545

521546
use super::*;
522547

@@ -572,6 +597,38 @@ mod tests {
572597
Ok(())
573598
}
574599

600+
#[rstest]
601+
#[case::primitive(make_primitive_array())]
602+
#[case::varbinview(make_varbinview_array())]
603+
#[case::nullable_primitive(make_nullable_primitive_array())]
604+
#[case::nullable_varbinview(make_nullable_varbinview_array())]
605+
#[case::empty_primitive(make_empty_primitive_array())]
606+
#[case::inlined_varbinview(make_inlined_varbinview_array())]
607+
fn test_serde_roundtrip(#[case] input: ArrayRef) -> VortexResult<()> {
608+
let session = array_session();
609+
session.arrays().register(ZstdBuffers);
610+
611+
let compressed = ZstdBuffers::compress(&input, 3, &session)?.into_array();
612+
let dtype = compressed.dtype().clone();
613+
let len = compressed.len();
614+
615+
let array_ctx = ArrayContext::empty();
616+
let serialized =
617+
compressed.serialize(&array_ctx, &session, &SerializeOptions::default())?;
618+
619+
let mut concat = ByteBufferMut::empty();
620+
for buf in serialized {
621+
concat.extend_from_slice(buf.as_ref());
622+
}
623+
let parts = SerializedArray::try_from(concat.freeze())?;
624+
let decoded = parts.decode(&dtype, len, &ReadContext::new(array_ctx.to_ids()), &session)?;
625+
626+
let mut ctx = session.create_execution_ctx();
627+
let decoded = decoded.execute::<ArrayRef>(&mut ctx)?;
628+
assert_arrays_eq!(input, decoded, &mut ctx);
629+
Ok(())
630+
}
631+
575632
#[test]
576633
fn test_compress_inherits_stats() -> VortexResult<()> {
577634
let input = make_primitive_array();

0 commit comments

Comments
 (0)