Skip to content

Commit c1b9223

Browse files
authored
Fix Array::validity (#7285)
The vtables refactor ended up obscuring the validity function in favor of ValidityHelper, which was incorrect in a few places. --------- Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 16bbd12 commit c1b9223

153 files changed

Lines changed: 703 additions & 383 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

encodings/alp/src/alp/array.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ mod tests {
571571
use vortex_array::arrays::PrimitiveArray;
572572
use vortex_array::assert_arrays_eq;
573573
use vortex_array::session::ArraySession;
574+
use vortex_error::VortexExpect;
574575
use vortex_session::VortexSession;
575576

576577
use super::*;
@@ -775,7 +776,11 @@ mod tests {
775776
for idx in 0..slice_len {
776777
let expected_value = values[slice_start + idx];
777778

778-
let result_valid = result_primitive.validity().is_valid(idx).unwrap();
779+
let result_valid = result_primitive
780+
.validity()
781+
.vortex_expect("result validity should be derivable")
782+
.is_valid(idx)
783+
.unwrap();
779784
assert_eq!(
780785
result_valid,
781786
expected_value.is_some(),

encodings/alp/src/alp/compress.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ where
6666
let (exponents, encoded, exceptional_positions, exceptional_values, mut chunk_offsets) =
6767
T::encode(values_slice, exponents);
6868

69-
let encoded_array = PrimitiveArray::new(encoded, values.validity()).into_array();
69+
let encoded_array = PrimitiveArray::new(encoded, values.validity()?).into_array();
7070

7171
let validity = values.validity_mask()?;
7272
// exceptional_positions may contain exceptions at invalid positions (which contain garbage

encodings/alp/src/alp/decompress.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use vortex_array::dtype::DType;
1111
use vortex_array::match_each_unsigned_integer_ptype;
1212
use vortex_array::patches::Patches;
1313
use vortex_buffer::BufferMut;
14+
use vortex_error::VortexExpect;
1415
use vortex_error::VortexResult;
1516

1617
use crate::ALPArray;
@@ -102,7 +103,9 @@ fn decompress_chunked_core(
102103
patches: &Patches,
103104
dtype: DType,
104105
) -> PrimitiveArray {
105-
let validity = encoded.validity();
106+
let validity = encoded
107+
.validity()
108+
.vortex_expect("ALP validity should be derivable");
106109
let ptype = dtype.as_ptype();
107110
let array_len = encoded.len();
108111
let offset_within_chunk = patches.offset_within_chunk().unwrap_or(0);
@@ -152,7 +155,7 @@ fn decompress_unchunked_core(
152155
dtype: DType,
153156
ctx: &mut ExecutionCtx,
154157
) -> VortexResult<PrimitiveArray> {
155-
let validity = encoded.validity();
158+
let validity = encoded.validity()?;
156159
let ptype = dtype.as_ptype();
157160

158161
let decoded = match_each_alp_float_ptype!(ptype, |T| {

encodings/alp/src/alp_rd/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,12 @@ impl RDEncoder {
227227
}
228228

229229
// Bit-pack down the encoded left-parts array that have been dictionary encoded.
230-
let primitive_left = PrimitiveArray::new(left_parts, array.validity());
230+
let primitive_left = PrimitiveArray::new(
231+
left_parts,
232+
array
233+
.validity()
234+
.vortex_expect("ALP RD validity should be derivable"),
235+
);
231236
// SAFETY: by construction, all values in left_parts can be packed to left_bit_width.
232237
let packed_left = unsafe {
233238
bitpack_encode_unchecked(primitive_left, left_bit_width as _)

encodings/bytebool/public-api.lock

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub type vortex_bytebool::ByteBool::ArrayData = vortex_bytebool::ByteBoolData
2626

2727
pub type vortex_bytebool::ByteBool::OperationsVTable = vortex_bytebool::ByteBool
2828

29-
pub type vortex_bytebool::ByteBool::ValidityVTable = vortex_array::array::vtable::validity::ValidityVTableFromValidityHelper
29+
pub type vortex_bytebool::ByteBool::ValidityVTable = vortex_bytebool::ByteBool
3030

3131
pub fn vortex_bytebool::ByteBool::array_eq(array: &vortex_bytebool::ByteBoolData, other: &vortex_bytebool::ByteBoolData, precision: vortex_array::hash::Precision) -> bool
3232

@@ -62,6 +62,10 @@ impl vortex_array::array::vtable::operations::OperationsVTable<vortex_bytebool::
6262

6363
pub fn vortex_bytebool::ByteBool::scalar_at(array: vortex_array::array::view::ArrayView<'_, vortex_bytebool::ByteBool>, index: usize, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
6464

65+
impl vortex_array::array::vtable::validity::ValidityVTable<vortex_bytebool::ByteBool> for vortex_bytebool::ByteBool
66+
67+
pub fn vortex_bytebool::ByteBool::validity(array: vortex_array::array::view::ArrayView<'_, vortex_bytebool::ByteBool>) -> vortex_error::VortexResult<vortex_array::validity::Validity>
68+
6569
impl vortex_array::arrays::dict::take::TakeExecute for vortex_bytebool::ByteBool
6670

6771
pub fn vortex_bytebool::ByteBool::take(array: vortex_array::array::view::ArrayView<'_, Self>, indices: &vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::erased::ArrayRef>>
@@ -96,6 +100,8 @@ pub fn vortex_bytebool::ByteBoolData::new(buffer: vortex_array::buffer::BufferHa
96100

97101
pub fn vortex_bytebool::ByteBoolData::validate(buffer: &vortex_array::buffer::BufferHandle, validity: &vortex_array::validity::Validity, dtype: &vortex_array::dtype::DType, len: usize) -> vortex_error::VortexResult<()>
98102

103+
pub fn vortex_bytebool::ByteBoolData::validity(&self) -> vortex_array::validity::Validity
104+
99105
pub fn vortex_bytebool::ByteBoolData::validity_mask(&self) -> vortex_mask::Mask
100106

101107
impl core::clone::Clone for vortex_bytebool::ByteBoolData
@@ -114,8 +120,4 @@ impl core::fmt::Debug for vortex_bytebool::ByteBoolData
114120

115121
pub fn vortex_bytebool::ByteBoolData::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
116122

117-
impl vortex_array::array::vtable::validity::ValidityHelper for vortex_bytebool::ByteBoolData
118-
119-
pub fn vortex_bytebool::ByteBoolData::validity(&self) -> &vortex_array::validity::Validity
120-
121123
pub type vortex_bytebool::ByteBoolArray = vortex_array::array::typed::Array<vortex_bytebool::ByteBool>

encodings/bytebool/src/array.rs

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@ use vortex_array::Precision;
1717
use vortex_array::arrays::BoolArray;
1818
use vortex_array::buffer::BufferHandle;
1919
use vortex_array::dtype::DType;
20+
use vortex_array::dtype::Nullability;
2021
use vortex_array::scalar::Scalar;
2122
use vortex_array::serde::ArrayChildren;
2223
use vortex_array::validity::Validity;
2324
use vortex_array::vtable;
2425
use vortex_array::vtable::OperationsVTable;
2526
use vortex_array::vtable::VTable;
26-
use vortex_array::vtable::ValidityHelper;
27-
use vortex_array::vtable::ValidityVTableFromValidityHelper;
27+
use vortex_array::vtable::ValidityVTable;
28+
use vortex_array::vtable::child_to_validity;
2829
use vortex_array::vtable::validity_to_child;
2930
use vortex_buffer::BitBuffer;
3031
use vortex_buffer::ByteBuffer;
@@ -43,24 +44,25 @@ impl VTable for ByteBool {
4344
type ArrayData = ByteBoolData;
4445

4546
type OperationsVTable = Self;
46-
type ValidityVTable = ValidityVTableFromValidityHelper;
47+
type ValidityVTable = Self;
4748

4849
fn id(&self) -> ArrayId {
4950
Self::ID
5051
}
5152

5253
fn validate(&self, data: &Self::ArrayData, dtype: &DType, len: usize) -> VortexResult<()> {
53-
ByteBoolData::validate(data.buffer(), data.validity(), dtype, len)
54+
let validity = data.validity();
55+
ByteBoolData::validate(data.buffer(), &validity, dtype, len)
5456
}
5557

5658
fn array_hash<H: std::hash::Hasher>(array: &ByteBoolData, state: &mut H, precision: Precision) {
5759
array.buffer.array_hash(state, precision);
58-
array.validity.array_hash(state, precision);
60+
array.validity().array_hash(state, precision);
5961
}
6062

6163
fn array_eq(array: &ByteBoolData, other: &ByteBoolData, precision: Precision) -> bool {
6264
array.buffer.array_eq(&other.buffer, precision)
63-
&& array.validity.array_eq(&other.validity, precision)
65+
&& array.validity().array_eq(&other.validity(), precision)
6466
}
6567

6668
fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
@@ -132,10 +134,6 @@ impl VTable for ByteBool {
132134
NUM_SLOTS,
133135
slots.len()
134136
);
135-
array.validity = match &slots[VALIDITY_SLOT] {
136-
Some(arr) => Validity::Array(arr.clone()),
137-
None => Validity::from(array.validity.nullability()),
138-
};
139137
array.slots = slots;
140138
Ok(())
141139
}
@@ -150,7 +148,7 @@ impl VTable for ByteBool {
150148

151149
fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
152150
let boolean_buffer = BitBuffer::from(array.as_slice());
153-
let validity = array.validity().clone();
151+
let validity = array.validity()?;
154152
Ok(ExecutionResult::done(
155153
BoolArray::new(boolean_buffer, validity).into_array(),
156154
))
@@ -174,7 +172,7 @@ pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["validity"];
174172
#[derive(Clone, Debug)]
175173
pub struct ByteBoolData {
176174
buffer: BufferHandle,
177-
validity: Validity,
175+
nullability: Nullability,
178176
pub(super) slots: Vec<Option<ArrayRef>>,
179177
}
180178

@@ -194,15 +192,15 @@ impl ByteBool {
194192
/// Construct a [`ByteBoolArray`] from a `Vec<bool>` and validity.
195193
pub fn from_vec<V: Into<Validity>>(data: Vec<bool>, validity: V) -> ByteBoolArray {
196194
let data = ByteBoolData::from_vec(data, validity);
197-
let dtype = DType::Bool(data.validity.nullability());
195+
let dtype = DType::Bool(data.nullability);
198196
let len = data.len();
199197
unsafe { Array::from_parts_unchecked(ArrayParts::new(ByteBool, dtype, len, data)) }
200198
}
201199

202200
/// Construct a [`ByteBoolArray`] from optional bools.
203201
pub fn from_option_vec(data: Vec<Option<bool>>) -> ByteBoolArray {
204202
let data = ByteBoolData::from(data);
205-
let dtype = DType::Bool(data.validity.nullability());
203+
let dtype = DType::Bool(data.nullability);
206204
let len = data.len();
207205
unsafe { Array::from_parts_unchecked(ArrayParts::new(ByteBool, dtype, len, data)) }
208206
}
@@ -235,6 +233,10 @@ impl ByteBoolData {
235233
vec![validity_to_child(validity, len)]
236234
}
237235

236+
pub fn validity(&self) -> Validity {
237+
child_to_validity(&self.slots[VALIDITY_SLOT], self.nullability)
238+
}
239+
238240
pub fn new(buffer: BufferHandle, validity: Validity) -> Self {
239241
let length = buffer.len();
240242
if let Some(vlen) = validity.maybe_len()
@@ -249,7 +251,7 @@ impl ByteBoolData {
249251
let slots = Self::make_slots(&validity, length);
250252
Self {
251253
buffer,
252-
validity,
254+
nullability: validity.nullability(),
253255
slots,
254256
}
255257
}
@@ -266,7 +268,7 @@ impl ByteBoolData {
266268

267269
/// Returns the validity mask for this array.
268270
pub fn validity_mask(&self) -> Mask {
269-
self.validity.to_mask(self.len())
271+
self.validity().to_mask(self.len())
270272
}
271273

272274
// TODO(ngates): deprecate construction from vec
@@ -287,9 +289,9 @@ impl ByteBoolData {
287289
}
288290
}
289291

290-
impl ValidityHelper for ByteBoolData {
291-
fn validity(&self) -> &Validity {
292-
&self.validity
292+
impl ValidityVTable<ByteBool> for ByteBool {
293+
fn validity(array: ArrayView<'_, ByteBool>) -> VortexResult<Validity> {
294+
Ok(array.data().validity())
293295
}
294296
}
295297

encodings/bytebool/src/compute.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ impl CastReduce for ByteBool {
2626
// If just changing nullability, we can optimize
2727
if array.dtype().eq_ignore_nullability(dtype) {
2828
let new_validity = array
29-
.validity()
30-
.clone()
29+
.validity()?
3130
.cast_nullability(dtype.nullability(), array.len())?;
3231

3332
return Ok(Some(
@@ -45,10 +44,7 @@ impl MaskReduce for ByteBool {
4544
Ok(Some(
4645
ByteBool::new(
4746
array.buffer().clone(),
48-
array
49-
.validity()
50-
.clone()
51-
.and(Validity::Array(mask.clone()))?,
47+
array.validity()?.and(Validity::Array(mask.clone()))?,
5248
)
5349
.into_array(),
5450
))
@@ -65,7 +61,7 @@ impl TakeExecute for ByteBool {
6561
let bools = array.as_slice();
6662

6763
// This handles combining validity from both source array and nullable indices
68-
let validity = array.validity().take(&indices.clone().into_array())?;
64+
let validity = array.validity()?.take(&indices.clone().into_array())?;
6965

7066
let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
7167
indices

encodings/bytebool/src/slice.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ impl SliceReduce for ByteBool {
1616
Ok(Some(
1717
ByteBool::new(
1818
array.buffer().slice(range.clone()),
19-
array.validity().slice(range)?,
19+
array.validity()?.slice(range)?,
2020
)
2121
.into_array(),
2222
))

encodings/datetime-parts/src/canonical.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,12 @@ mod test {
160160
.execute::<PrimitiveArray>(&mut ctx)?;
161161

162162
assert_arrays_eq!(primitive_values, milliseconds);
163-
assert!(primitive_values.validity().mask_eq(&validity, &mut ctx)?);
163+
assert!(
164+
primitive_values
165+
.validity()
166+
.unwrap()
167+
.mask_eq(&validity, &mut ctx)?
168+
);
164169
Ok(())
165170
}
166171
}

encodings/datetime-parts/src/compress.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub fn split_temporal(array: TemporalArray) -> VortexResult<TemporalParts> {
5151
}
5252

5353
Ok(TemporalParts {
54-
days: PrimitiveArray::new(days, temporal_values.validity()).into_array(),
54+
days: PrimitiveArray::new(days, temporal_values.validity()?).into_array(),
5555
seconds: seconds.into_array(),
5656
subseconds: subseconds.into_array(),
5757
})
@@ -83,6 +83,7 @@ mod tests {
8383
use vortex_array::extension::datetime::TimeUnit;
8484
use vortex_array::validity::Validity;
8585
use vortex_buffer::buffer;
86+
use vortex_error::VortexExpect;
8687

8788
use crate::TemporalParts;
8889
use crate::split_temporal;
@@ -114,15 +115,22 @@ mod tests {
114115
assert!(
115116
days.to_primitive()
116117
.validity()
118+
.vortex_expect("days validity should be derivable")
117119
.mask_eq(&validity, &mut ctx)
118120
.unwrap()
119121
);
120122
assert!(matches!(
121-
seconds.to_primitive().validity(),
123+
seconds
124+
.to_primitive()
125+
.validity()
126+
.vortex_expect("seconds validity should be derivable"),
122127
Validity::NonNullable
123128
));
124129
assert!(matches!(
125-
subseconds.to_primitive().validity(),
130+
subseconds
131+
.to_primitive()
132+
.validity()
133+
.vortex_expect("subseconds validity should be derivable"),
126134
Validity::NonNullable
127135
));
128136
}

0 commit comments

Comments
 (0)