Skip to content

Commit a8efc57

Browse files
committed
Handle child nullability better
1 parent d8f0dd1 commit a8efc57

2 files changed

Lines changed: 44 additions & 17 deletions

File tree

encodings/parquet-variant/src/variant_get/mod.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,20 @@ fn variant_get_impl(
101101
.ok_or_else(|| vortex_err!("variant_get did not return a StructArray"))?,
102102
)
103103
.map_err(|e| vortex_err!("failed to create VariantArray from result: {e}"))?;
104+
let value_nullable = result_variant
105+
.inner()
106+
.fields()
107+
.iter()
108+
.find(|field| field.name() == "value")
109+
.map(|field| field.is_nullable())
110+
.unwrap_or(false);
111+
let typed_value_nullable = result_variant
112+
.inner()
113+
.fields()
114+
.iter()
115+
.find(|field| field.name() == "typed_value")
116+
.map(|field| field.is_nullable())
117+
.unwrap_or(false);
104118

105119
// Ensure the result is always nullable (matching variant_get's return_dtype).
106120
// Arrow may return a non-nullable result when no nulls are present.
@@ -121,11 +135,11 @@ fn variant_get_impl(
121135
)?;
122136
let value = result_variant
123137
.value_field()
124-
.map(|v| ArrayRef::from_arrow(v as &dyn arrow_array::Array, true))
138+
.map(|v| ArrayRef::from_arrow(v as &dyn arrow_array::Array, value_nullable))
125139
.transpose()?;
126140
let typed_value = result_variant
127141
.typed_value_field()
128-
.map(|tv| ArrayRef::from_arrow(tv.as_ref(), true))
142+
.map(|tv| ArrayRef::from_arrow(tv.as_ref(), typed_value_nullable))
129143
.transpose()?;
130144

131145
let pv = ParquetVariant::try_new(validity, metadata, value, typed_value)?;

encodings/parquet-variant/src/variant_get/tests.rs

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,31 @@ fn vortex_to_arrow_variant(arr: &ArrayRef) -> ArrowVariantArray {
6868
pv.to_arrow(&mut ctx).unwrap()
6969
}
7070

71+
fn assert_variant_storage_matches(expected: &ArrowVariantArray, actual: &ArrowVariantArray) {
72+
assert_eq!(actual.len(), expected.len(), "length mismatch");
73+
assert_eq!(
74+
actual.inner().column_names(),
75+
expected.inner().column_names(),
76+
"column mismatch"
77+
);
78+
assert_eq!(actual.inner().nulls(), expected.inner().nulls());
79+
assert_eq!(
80+
actual.inner().fields().len(),
81+
expected.inner().fields().len()
82+
);
83+
84+
for (expected, actual) in expected
85+
.inner()
86+
.fields()
87+
.iter()
88+
.zip(actual.inner().fields().iter())
89+
{
90+
assert_eq!(actual.name(), expected.name());
91+
assert_eq!(actual.data_type(), expected.data_type());
92+
assert_eq!(actual.is_nullable(), expected.is_nullable());
93+
}
94+
}
95+
7196
/// Run variant_get through both Arrow and Vortex on the same input, and assert
7297
/// the per-row results (value + validity) are identical by comparing at the Arrow level.
7398
fn assert_matches_arrow(json_rows: &[&str], field: &str) {
@@ -89,11 +114,7 @@ fn assert_matches_arrow(json_rows: &[&str], field: &str) {
89114
let vortex_result = apply_variant_get!(&vortex_input, field).unwrap();
90115
let vortex_as_arrow = vortex_to_arrow_variant(&vortex_result);
91116

92-
assert_eq!(
93-
vortex_as_arrow.len(),
94-
arrow_result_variant.len(),
95-
"length mismatch"
96-
);
117+
assert_variant_storage_matches(&arrow_result_variant, &vortex_as_arrow);
97118

98119
for i in 0..arrow_result_variant.len() {
99120
let arrow_is_null = arrow_result_variant.is_null(i);
@@ -139,11 +160,7 @@ fn assert_matches_arrow_with_path(
139160
let vortex_result = apply_variant_get!(&vortex_input, path).unwrap();
140161
let vortex_as_arrow = vortex_to_arrow_variant(&vortex_result);
141162

142-
assert_eq!(
143-
vortex_as_arrow.len(),
144-
arrow_result_variant.len(),
145-
"length mismatch"
146-
);
163+
assert_variant_storage_matches(&arrow_result_variant, &vortex_as_arrow);
147164

148165
for i in 0..arrow_result_variant.len() {
149166
let arrow_is_null = arrow_result_variant.is_null(i);
@@ -226,11 +243,7 @@ fn assert_matches_arrow_nullable(json_rows: &[&str], validity: &[bool], field: &
226243
let vortex_result = apply_variant_get!(&vortex_input, field).unwrap();
227244
let vortex_as_arrow = vortex_to_arrow_variant(&vortex_result);
228245

229-
assert_eq!(
230-
vortex_as_arrow.len(),
231-
arrow_result_variant.len(),
232-
"length mismatch"
233-
);
246+
assert_variant_storage_matches(&arrow_result_variant, &vortex_as_arrow);
234247

235248
for i in 0..arrow_result_variant.len() {
236249
let arrow_is_null = arrow_result_variant.is_null(i);

0 commit comments

Comments
 (0)