|
6 | 6 | use std::ops::BitAnd; |
7 | 7 |
|
8 | 8 | use vortex_error::VortexResult; |
9 | | -use vortex_error::vortex_bail; |
10 | 9 | use vortex_mask::Mask; |
11 | 10 |
|
12 | 11 | use crate::Canonical; |
13 | 12 | use crate::IntoArray; |
| 13 | +use crate::ArrayVisitor; |
14 | 14 | use crate::arrays::BoolArray; |
15 | 15 | use crate::arrays::DecimalArray; |
16 | 16 | use crate::arrays::ExtensionArray; |
17 | 17 | use crate::arrays::FixedSizeListArray; |
18 | 18 | use crate::arrays::ListViewArray; |
| 19 | +use crate::arrays::MaskedArray; |
19 | 20 | use crate::arrays::NullArray; |
20 | 21 | use crate::arrays::PrimitiveArray; |
21 | 22 | use crate::arrays::StructArray; |
22 | 23 | use crate::arrays::VarBinViewArray; |
| 24 | +use crate::arrays::VariantArray; |
23 | 25 | use crate::dtype::Nullability; |
24 | 26 | use crate::executor::ExecutionCtx; |
25 | 27 | use crate::match_each_decimal_value_type; |
@@ -54,8 +56,8 @@ pub fn mask_validity_canonical( |
54 | 56 | Canonical::Extension(a) => { |
55 | 57 | Canonical::Extension(mask_validity_extension(a, validity_mask, ctx)?) |
56 | 58 | } |
57 | | - Canonical::Variant(_) => { |
58 | | - vortex_bail!("Variant arrays don't masking validity") |
| 59 | + Canonical::Variant(a) => { |
| 60 | + Canonical::Variant(mask_validity_variant(a, validity_mask, ctx)?) |
59 | 61 | } |
60 | 62 | }) |
61 | 63 | } |
@@ -200,3 +202,35 @@ fn mask_validity_extension( |
200 | 202 | masked_storage, |
201 | 203 | )) |
202 | 204 | } |
| 205 | + |
| 206 | +fn mask_validity_variant( |
| 207 | + array: VariantArray, |
| 208 | + mask: &Mask, |
| 209 | + ctx: &mut ExecutionCtx, |
| 210 | +) -> VortexResult<VariantArray> { |
| 211 | + let child = array.child().clone(); |
| 212 | + let len = child.len(); |
| 213 | + let child_validity = child.validity()?; |
| 214 | + |
| 215 | + match child_validity { |
| 216 | + Validity::NonNullable | Validity::AllValid => { |
| 217 | + // Child has no nulls — wrap in MaskedArray to apply the mask. |
| 218 | + let new_validity = Validity::from_mask(mask.clone(), Nullability::Nullable); |
| 219 | + let masked_child = MaskedArray::try_new(child, new_validity)?; |
| 220 | + Ok(VariantArray::new(masked_child.into_array())) |
| 221 | + } |
| 222 | + Validity::AllInvalid => { |
| 223 | + // Already all-null, ANDing with any mask is still all-null. |
| 224 | + Ok(array) |
| 225 | + } |
| 226 | + Validity::Array(_) => { |
| 227 | + // Child has an array-backed validity stored as its first child. |
| 228 | + // Combine with the mask and replace that child via with_children. |
| 229 | + let combined = combine_validity(&child_validity, mask, len, ctx)?; |
| 230 | + let mut children = child.children(); |
| 231 | + children[0] = combined.to_array(len); |
| 232 | + let new_child = child.with_children(children)?; |
| 233 | + Ok(VariantArray::new(new_child)) |
| 234 | + } |
| 235 | + } |
| 236 | +} |
0 commit comments