Skip to content

Commit eced293

Browse files
Fix Validity::mask_eq semantics for mixed variants (#8334)
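## Summary

**PR 2 of a 4-PR stack** (stacked on #8333) preparing `Validity` for lazy validity arrays.

`mask_eq` previously returned `false` for any mixed-variant pairing without executing either side — e.g. a `Validity::Array` that resolves to all-true compared against `Validity::AllValid`. With lazy validity arrays, unresolved `Array` variants frequently hold constant masks, so that early `false` is silently wrong rather than merely conservative.

`mask_eq` now takes the `length` of the validity being compared. Pairings of constant variants with known-equal masks (`NonNullable`/`AllValid` with each other, or `AllInvalid` with `AllInvalid`) still return `true` without executing; every other pairing executes both sides into masks of that length and compares them.

---------

Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>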
1 parent 1535ced commit eced293

9 files changed

Lines changed: 81 additions & 33 deletions

File tree

encodings/datetime-parts/src/canonical.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,11 @@ mod test {
143143
&mut ctx,
144144
)?;
145145

146-
assert!(
147-
date_times
148-
.as_array()
149-
.validity()?
150-
.mask_eq(&validity, &mut ctx)?
151-
);
146+
assert!(date_times.as_array().validity()?.mask_eq(
147+
&validity,
148+
milliseconds.len(),
149+
&mut ctx
150+
)?);
152151

153152
let dtype = date_times.dtype().clone();
154153
let parts = DateTimePartsParts {
@@ -163,7 +162,6 @@ mod test {
163162
.execute::<PrimitiveArray>(&mut ctx)?;
164163

165164
assert_arrays_eq!(primitive_values, milliseconds);
166-
assert!(primitive_values.validity()?.mask_eq(&validity, &mut ctx)?);
167165
Ok(())
168166
}
169167
}

encodings/datetime-parts/src/compress.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ mod tests {
103103
days_prim
104104
.validity()
105105
.vortex_expect("days validity should be derivable")
106-
.mask_eq(&validity, &mut ctx)
106+
.mask_eq(&validity, days_prim.len(), &mut ctx)
107107
.unwrap()
108108
);
109109
let seconds_prim = seconds.execute::<PrimitiveArray>(&mut ctx).unwrap();

encodings/pco/src/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ fn test_validity_and_multiple_chunks_and_pages() {
149149
.unwrap()
150150
.mask_eq(
151151
&Validity::Array(BoolArray::from_iter(vec![true, false, true]).into_array()),
152+
primitive.len(),
152153
&mut ctx,
153154
)
154155
.unwrap()

encodings/zstd/src/test.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ fn test_zstd_with_validity_and_multi_frame() {
8989
decompressed
9090
.validity()
9191
.unwrap()
92-
.mask_eq(&array.validity().unwrap(), &mut ctx)
92+
.mask_eq(&array.validity().unwrap(), decompressed.len(), &mut ctx)
9393
.unwrap()
9494
);
9595

@@ -106,6 +106,7 @@ fn test_zstd_with_validity_and_multi_frame() {
106106
.unwrap()
107107
.mask_eq(
108108
&Validity::Array(BoolArray::from_iter(vec![false, true, false]).into_array()),
109+
primitive.len(),
109110
&mut ctx
110111
)
111112
.unwrap()

vortex-array/src/arrays/masked/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ fn test_masked_child_preserves_length(#[case] validity: Validity) {
134134
array
135135
.validity()
136136
.vortex_expect("masked validity should be derivable")
137-
.mask_eq(&validity, &mut ctx)
137+
.mask_eq(&validity, array.len(), &mut ctx)
138138
.unwrap(),
139139
);
140140
}

vortex-array/src/builders/bool.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,11 @@ mod tests {
209209
#[expect(deprecated)]
210210
let into_canon = chunk.to_bool();
211211

212-
assert!(
213-
canon_into
214-
.validity()?
215-
.mask_eq(&into_canon.validity()?, &mut ctx)?
216-
);
212+
assert!(canon_into.validity()?.mask_eq(
213+
&into_canon.validity()?,
214+
canon_into.len(),
215+
&mut ctx
216+
)?);
217217
assert_eq!(canon_into.to_bit_buffer(), into_canon.to_bit_buffer());
218218
Ok(())
219219
}

vortex-array/src/builders/list.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ mod tests {
490490
&expected
491491
.validity()
492492
.vortex_expect("list validity should be derivable"),
493+
actual.len(),
493494
&mut ctx,
494495
)
495496
.unwrap(),

vortex-array/src/validity.rs

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ impl Validity {
244244
}
245245
}
246246

247+
#[inline]
247248
pub fn execute_mask(&self, length: usize, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
248249
match self {
249250
Self::NonNullable | Self::AllValid => Ok(Mask::AllTrue(length)),
@@ -263,18 +264,22 @@ impl Validity {
263264
}
264265
}
265266

266-
/// Compare two Validity values of the same length by executing them into masks if necessary.
267-
pub fn mask_eq(&self, other: &Validity, ctx: &mut ExecutionCtx) -> VortexResult<bool> {
267+
/// Compare the logical masks of two Validity values of the given length, executing them
268+
/// into [`Mask`]s if necessary.
269+
pub fn mask_eq(
270+
&self,
271+
other: &Validity,
272+
length: usize,
273+
ctx: &mut ExecutionCtx,
274+
) -> VortexResult<bool> {
268275
match (self, other) {
269-
(Validity::NonNullable, Validity::NonNullable) => Ok(true),
270-
(Validity::AllValid, Validity::AllValid) => Ok(true),
271-
(Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
272-
(Validity::Array(a), Validity::Array(b)) => {
273-
let a = a.clone().execute::<Mask>(ctx)?;
274-
let b = b.clone().execute::<Mask>(ctx)?;
275-
Ok(a == b)
276-
}
277-
_ => Ok(false),
276+
// Fast paths that avoid executing: constant variants with known-equal masks.
277+
(
278+
Validity::NonNullable | Validity::AllValid,
279+
Validity::NonNullable | Validity::AllValid,
280+
)
281+
| (Validity::AllInvalid, Validity::AllInvalid) => Ok(true),
282+
_ => Ok(self.execute_mask(length, ctx)? == other.execute_mask(length, ctx)?),
278283
}
279284
}
280285

@@ -703,7 +708,7 @@ mod tests {
703708
validity
704709
.patch(len, 0, &indices, &patches, &mut ctx,)
705710
.unwrap()
706-
.mask_eq(&expected, &mut ctx)
711+
.mask_eq(&expected, len, &mut ctx)
707712
.unwrap()
708713
);
709714
}
@@ -768,8 +773,50 @@ mod tests {
768773
validity
769774
.take(&indices)
770775
.unwrap()
771-
.mask_eq(&expected, &mut ctx)
776+
.mask_eq(&expected, indices.len(), &mut ctx)
772777
.unwrap()
773778
);
774779
}
780+
781+
#[rstest]
782+
// Mixed constant variants with equal masks.
783+
#[case(Validity::NonNullable, Validity::AllValid, true)]
784+
#[case(Validity::AllValid, Validity::NonNullable, true)]
785+
#[case(Validity::AllValid, Validity::AllInvalid, false)]
786+
#[case(Validity::NonNullable, Validity::AllInvalid, false)]
787+
// An array that resolves to a constant mask must equal the constant variant.
788+
#[case(
789+
Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
790+
Validity::AllValid,
791+
true
792+
)]
793+
#[case(
794+
Validity::NonNullable,
795+
Validity::Array(BoolArray::from_iter([true, true, true]).into_array()),
796+
true
797+
)]
798+
#[case(
799+
Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
800+
Validity::AllInvalid,
801+
true
802+
)]
803+
#[case(
804+
Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
805+
Validity::AllValid,
806+
false
807+
)]
808+
#[case(
809+
Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
810+
Validity::AllInvalid,
811+
false
812+
)]
813+
fn mask_eq_mixed_variants(
814+
#[case] lhs: Validity,
815+
#[case] rhs: Validity,
816+
#[case] expected: bool,
817+
) -> vortex_error::VortexResult<()> {
818+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
819+
assert_eq!(lhs.mask_eq(&rhs, 3, &mut ctx)?, expected);
820+
Ok(())
821+
}
775822
}

vortex/src/lib.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,11 @@ mod test {
345345
let mut ctx = LEGACY_SESSION.create_execution_ctx();
346346

347347
let recovered_primitive = recovered_array.execute::<PrimitiveArray>(&mut ctx)?;
348-
assert!(
349-
recovered_primitive
350-
.validity()?
351-
.mask_eq(&array.validity()?, &mut ctx)?
352-
);
348+
assert!(recovered_primitive.validity()?.mask_eq(
349+
&array.validity()?,
350+
array.len(),
351+
&mut ctx
352+
)?);
353353
assert_eq!(
354354
recovered_primitive.to_buffer::<u64>(),
355355
array.to_buffer::<u64>()

0 commit comments

Comments
 (0)