Skip to content

Commit ff30b01

Browse files
committed
make sure we can roundtrip
Signed-off-by: Adam Gutglick <adam@spiraldb.com>
1 parent 12364bb commit ff30b01

2 files changed

Lines changed: 138 additions & 36 deletions

File tree

encodings/parquet-variant/src/lib.rs

Lines changed: 133 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ use vortex_array::vtable::ArrayId;
4848
use vortex_array::vtable::NotSupported;
4949
use vortex_array::vtable::VTable;
5050
use vortex_array::vtable::ValidityVTable;
51+
use vortex_array::vtable::validity_nchildren;
52+
use vortex_array::vtable::validity_to_child;
53+
use vortex_buffer::BitBuffer;
5154
use vortex_error::VortexExpect;
5255
use vortex_error::VortexResult;
5356
use vortex_error::vortex_bail;
@@ -102,6 +105,7 @@ struct ParquetVariantMetadataProto {
102105
/// where nested struct/list elements themselves contain value/typed_value children.
103106
#[derive(Clone, Debug)]
104107
pub struct ParquetVariantArray {
108+
validity: Validity,
105109
metadata: ArrayRef,
106110
value: Option<ArrayRef>,
107111
typed_value: Option<ArrayRef>,
@@ -116,12 +120,28 @@ impl ParquetVariantArray {
116120
metadata: ArrayRef,
117121
value: Option<ArrayRef>,
118122
typed_value: Option<ArrayRef>,
123+
) -> VortexResult<Self> {
124+
Self::try_new_with_validity(Validity::AllValid, metadata, value, typed_value)
125+
}
126+
127+
/// Creates a new ParquetVariantArray with explicit parent validity.
128+
pub fn try_new_with_validity(
129+
validity: Validity,
130+
metadata: ArrayRef,
131+
value: Option<ArrayRef>,
132+
typed_value: Option<ArrayRef>,
119133
) -> VortexResult<Self> {
120134
vortex_ensure!(
121135
value.is_some() || typed_value.is_some(),
122136
"at least one of value or typed_value must be present"
123137
);
124138
let len = metadata.len();
139+
if let Some(validity_len) = validity.maybe_len() {
140+
vortex_ensure!(
141+
validity_len == len,
142+
"validity length must match metadata length"
143+
);
144+
}
125145
if let Some(ref v) = value {
126146
vortex_ensure!(v.len() == len, "value length must match metadata length");
127147
}
@@ -132,6 +152,7 @@ impl ParquetVariantArray {
132152
);
133153
}
134154
Ok(Self {
155+
validity,
135156
metadata,
136157
value,
137158
typed_value,
@@ -154,30 +175,62 @@ impl ParquetVariantArray {
154175
self.typed_value.as_ref()
155176
}
156177

178+
/// Returns the parent row validity for the variant storage struct.
179+
pub fn validity(&self) -> &Validity {
180+
&self.validity
181+
}
182+
157183
/// Converts an Arrow `parquet_variant_compute::VariantArray` into a Vortex `ArrayRef`
158184
/// wrapping `VariantArray(ParquetVariantArray(...))`.
159185
pub fn from_arrow_variant(
160186
arrow_variant: &parquet_variant_compute::VariantArray,
161187
) -> VortexResult<ArrayRef> {
188+
let storage = arrow_variant.inner();
189+
let value_nullable = storage
190+
.fields()
191+
.iter()
192+
.find(|field| field.name() == "value")
193+
.map(|field| field.is_nullable())
194+
.unwrap_or(false);
195+
let typed_value_nullable = storage
196+
.fields()
197+
.iter()
198+
.find(|field| field.name() == "typed_value")
199+
.map(|field| field.is_nullable())
200+
.unwrap_or(false);
201+
let validity = arrow_variant
202+
.nulls()
203+
.map(|nulls| {
204+
if nulls.null_count() == nulls.len() {
205+
Validity::AllInvalid
206+
} else {
207+
Validity::from(BitBuffer::from(nulls.inner().clone()))
208+
}
209+
})
210+
.unwrap_or(Validity::AllValid);
162211
let metadata =
163212
ArrayRef::from_arrow(arrow_variant.metadata_field() as &dyn ArrowArray, false)?;
164213

165214
let value = arrow_variant
166215
.value_field()
167-
.map(|v| ArrayRef::from_arrow(v as &dyn ArrowArray, false))
216+
.map(|v| ArrayRef::from_arrow(v as &dyn ArrowArray, value_nullable))
168217
.transpose()?;
169218

170219
let typed_value = arrow_variant
171220
.typed_value_field()
172-
.map(|tv| ArrayRef::from_arrow(tv.as_ref(), tv.is_nullable()))
221+
.map(|tv| ArrayRef::from_arrow(tv.as_ref(), typed_value_nullable))
173222
.transpose()?;
174223

175-
let pv = ParquetVariantArray::try_new(metadata, value, typed_value)?;
224+
let pv =
225+
ParquetVariantArray::try_new_with_validity(validity, metadata, value, typed_value)?;
176226
Ok(VariantArray::new(pv.into_array()).into_array())
177227
}
178228

179229
fn nchildren(&self) -> usize {
180-
1 + self.value.is_some() as usize + self.typed_value.is_some() as usize
230+
validity_nchildren(&self.validity)
231+
+ 1
232+
+ self.value.is_some() as usize
233+
+ self.typed_value.is_some() as usize
181234
}
182235
}
183236

@@ -204,6 +257,7 @@ impl VTable for ParquetVariantVTable {
204257
}
205258

206259
fn array_hash<H: Hasher>(array: &ParquetVariantArray, state: &mut H, precision: Precision) {
260+
array.validity.array_hash(state, precision);
207261
array.metadata.array_hash(state, precision);
208262
if let Some(ref value) = array.value {
209263
value.array_hash(state, precision);
@@ -218,7 +272,9 @@ impl VTable for ParquetVariantVTable {
218272
other: &ParquetVariantArray,
219273
precision: Precision,
220274
) -> bool {
221-
if !array.metadata.array_eq(&other.metadata, precision) {
275+
if !array.validity.array_eq(&other.validity, precision)
276+
|| !array.metadata.array_eq(&other.metadata, precision)
277+
{
222278
return false;
223279
}
224280
match (&array.value, &other.value) {
@@ -254,31 +310,41 @@ impl VTable for ParquetVariantVTable {
254310
}
255311

256312
fn child(array: &ParquetVariantArray, idx: usize) -> ArrayRef {
257-
match idx {
258-
0 => array.metadata.clone(),
259-
1 if array.value.is_some() => array
260-
.value
261-
.clone()
262-
.vortex_expect("ParquetVariantArray missing value child"),
263-
1 => array
264-
.typed_value
265-
.clone()
266-
.vortex_expect("ParquetVariantArray missing typed_value child"),
267-
2 => array
268-
.typed_value
269-
.clone()
270-
.vortex_expect("ParquetVariantArray missing typed_value child"),
271-
_ => vortex_panic!("ParquetVariantArray child index {idx} out of bounds"),
313+
let vc = validity_nchildren(&array.validity);
314+
if idx < vc {
315+
validity_to_child(&array.validity, array.metadata.len())
316+
.vortex_expect("ParquetVariantArray validity child out of bounds")
317+
} else {
318+
match idx - vc {
319+
0 => array.metadata.clone(),
320+
1 if array.value.is_some() => array
321+
.value
322+
.clone()
323+
.vortex_expect("ParquetVariantArray missing value child"),
324+
1 => array
325+
.typed_value
326+
.clone()
327+
.vortex_expect("ParquetVariantArray missing typed_value child"),
328+
2 => array
329+
.typed_value
330+
.clone()
331+
.vortex_expect("ParquetVariantArray missing typed_value child"),
332+
_ => vortex_panic!("ParquetVariantArray child index {idx} out of bounds"),
333+
}
272334
}
273335
}
274336

275337
fn child_name(array: &ParquetVariantArray, idx: usize) -> String {
338+
let vc = validity_nchildren(&array.validity);
276339
match idx {
277-
0 => "metadata".to_string(),
278-
1 if array.value.is_some() => "value".to_string(),
279-
1 => "typed_value".to_string(),
280-
2 => "typed_value".to_string(),
281-
_ => vortex_panic!("ParquetVariantArray child_name index {idx} out of bounds"),
340+
idx if idx < vc => "validity".to_string(),
341+
idx => match idx - vc {
342+
0 => "metadata".to_string(),
343+
1 if array.value.is_some() => "value".to_string(),
344+
1 => "typed_value".to_string(),
345+
2 => "typed_value".to_string(),
346+
_ => vortex_panic!("ParquetVariantArray child_name index {idx} out of bounds"),
347+
},
282348
}
283349
}
284350

@@ -338,13 +404,18 @@ impl VTable for ParquetVariantVTable {
338404

339405
let expected_children = 1 + metadata.has_value as usize + has_typed_value as usize;
340406
vortex_ensure!(
341-
children.len() == expected_children,
342-
"Expected {} children, got {}",
407+
children.len() == expected_children || children.len() == expected_children + 1,
408+
"Expected {} or {} children, got {}",
343409
expected_children,
410+
expected_children + 1,
344411
children.len()
345412
);
346413

347-
let mut child_idx = 0;
414+
let (validity, mut child_idx) = if children.len() == expected_children {
415+
(Validity::AllValid, 0)
416+
} else {
417+
(Validity::Array(children.get(0, &Validity::DTYPE, len)?), 1)
418+
};
348419
let variant_metadata =
349420
children.get(child_idx, &DType::Binary(Nullability::NonNullable), len)?;
350421
child_idx += 1;
@@ -369,7 +440,7 @@ impl VTable for ParquetVariantVTable {
369440
None
370441
};
371442

372-
ParquetVariantArray::try_new(variant_metadata, value, typed_value)
443+
ParquetVariantArray::try_new_with_validity(validity, variant_metadata, value, typed_value)
373444
}
374445

375446
fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
@@ -380,6 +451,12 @@ impl VTable for ParquetVariantVTable {
380451
children.len()
381452
);
382453
let mut iter = children.into_iter();
454+
if validity_nchildren(&array.validity) == 1 {
455+
array.validity = Validity::Array(
456+
iter.next()
457+
.vortex_expect("ParquetVariantArray missing validity child"),
458+
);
459+
}
383460
array.metadata = iter
384461
.next()
385462
.vortex_expect("ParquetVariantArray missing metadata child");
@@ -449,10 +526,8 @@ impl ArrayParentReduceRule<ParquetVariantVTable> for ParquetVariantGetRule {
449526
}
450527

451528
impl ValidityVTable<ParquetVariantVTable> for ParquetVariantVTable {
452-
fn validity(_array: &ParquetVariantArray) -> VortexResult<Validity> {
453-
// Variant is always nullable. Null-ness of individual values is encoded
454-
// within the Parquet Variant binary format itself, not via a separate validity bitmap.
455-
Ok(Validity::AllValid)
529+
fn validity(array: &ParquetVariantArray) -> VortexResult<Validity> {
530+
Ok(array.validity.clone())
456531
}
457532
}
458533

@@ -627,8 +702,31 @@ mod tests {
627702

628703
let mut ctx = LEGACY_SESSION.create_execution_ctx();
629704
let roundtripped = vortex_arr.execute_arrow(None, &mut ctx)?;
705+
let roundtripped = roundtripped.as_struct();
706+
707+
assert_eq!(struct_array.len(), roundtripped.len());
708+
assert_eq!(struct_array.column_names(), roundtripped.column_names());
709+
assert_eq!(struct_array.nulls(), roundtripped.nulls());
710+
assert_eq!(struct_array.fields().len(), roundtripped.fields().len());
711+
712+
for (expected, actual) in struct_array
713+
.fields()
714+
.iter()
715+
.zip(roundtripped.fields().iter())
716+
{
717+
assert_eq!(expected.name(), actual.name());
718+
assert_eq!(expected.data_type(), actual.data_type());
719+
assert_eq!(expected.is_nullable(), actual.is_nullable());
720+
}
721+
722+
for (expected, actual) in struct_array
723+
.columns()
724+
.iter()
725+
.zip(roundtripped.columns().iter())
726+
{
727+
assert_eq!(expected.to_data(), actual.to_data());
728+
}
630729

631-
assert_eq!(struct_array.to_data(), roundtripped.as_struct().to_data());
632730
Ok(())
633731
}
634732

@@ -726,7 +824,7 @@ mod tests {
726824
builder.append_variant(Variant::from("hello"));
727825
builder.append_variant(Variant::from(true));
728826

729-
assert_arrow_variant_storage_roundtrip(builder.build().into_inner().clone())
827+
assert_arrow_variant_storage_roundtrip(builder.build().into_inner())
730828
}
731829

732830
#[test]

vortex-array/src/arrow/executor/variant.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ use crate::ExecutionCtx;
1616
use crate::array::ArrayVisitor;
1717
use crate::arrays::VariantVTable;
1818
use crate::arrow::ArrowArrayExecutor;
19+
use crate::arrow::executor::validity::to_arrow_null_buffer;
1920

2021
pub(super) fn to_arrow_variant(
2122
array: ArrayRef,
2223
target_fields: Option<&Fields>,
2324
ctx: &mut ExecutionCtx,
2425
) -> VortexResult<ArrowArrayRef> {
26+
let len = array.len();
27+
let nulls = to_arrow_null_buffer(array.validity()?, len, ctx)?;
2528
let inner = match array.try_into::<VariantVTable>() {
2629
Ok(variant) => variant.child().clone(),
2730
Err(array) => array,
@@ -38,6 +41,7 @@ pub(super) fn to_arrow_variant(
3841

3942
for (name, child) in named_children {
4043
match name.as_str() {
44+
"validity" => {}
4145
"metadata" => metadata = Some(child),
4246
"value" => value = Some(child),
4347
"typed_value" => typed_value = Some(child),
@@ -105,5 +109,5 @@ pub(super) fn to_arrow_variant(
105109
(Fields::from(fields), arrays)
106110
};
107111

108-
Ok(Arc::new(ArrowStructArray::try_new(fields, arrays, None)?))
112+
Ok(Arc::new(ArrowStructArray::try_new(fields, arrays, nulls)?))
109113
}

0 commit comments

Comments
 (0)