Skip to content

Commit 8e92de5

Browse files
authored
Fix Scalar de/serialization with proto and outline nulls from ScalarValue (#6309)
This PR is generally just a refactoring of the `vortex-scalar` crate and all of its dependents. Currently, we store an opaque `InnerScalarValue` inside a `ScalarValue`, and the `InnerScalarValue` is allowed to be `Null`. This PR both removes outlines the `InnerScalarValue::Null` variant, where a nullable scalar is now an `Option<ScalarValue>` (instead of just a `ScalarValue`. This also completely removes `InnerScalarValue` in favor of a public `ScalarValue` enum: ```rust // Before: pub struct ScalarValue(pub(crate) InnerScalarValue); // After: pub enum ScalarValue { // No `Null` variant! /// A boolean value. Bool(bool), /// A primitive numeric value. Primitive(PValue), /// A decimal value. Decimal(DecimalValue), /// A UTF-8 encoded string value. Utf8(BufferString), /// A binary (byte array) value. Binary(ByteBuffer), /// A list of potentially null scalar values. List(Vec<Option<ScalarValue>>), /// TODO? // Extension(ExtScalarRef), // ? } ``` (**IMPORTANT CHANGE**) Additionally, all `Scalar`s are verified to be sound on construction by checking that the `DType` of the `Scalar` `is_compatible` with the `Option<&ScalarValue>` that is passed. The stricter construction changes then mean that we need to change how deserialization of scalars work. The protobuf format is not exactly 1-1 with our `Scalar`s (notably, it only supports 64-bit integers). This means that we might write 8-bit integers and the round trip returns a 64-bit integer. So this PR also changes some APIs to allow us to pass a `DType` to the statistics deserializers. TBD on if this needs to happen in more places (`FileStatistics`?). For reviewers: try to look over all of the diffs since a large majority is **not** just renaming variables, they are semantic changes that I am not super confidant is right. ## Breaks Breaks the old `file_stats` method on `VortexFile` to return `FileStatistics` instead of the array of `StatsSet`. We needed to do this in order to correctly read statistics from DataFusion, specifically in the case of schema evolution. ## TODO Some benchmarks are failing, and also I still need to review everything myself to make sure everything is correct. I also want to add more tests in certain places where I'm very scared things are wrong. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent b8d106c commit 8e92de5

File tree

148 files changed

+6298
-5968
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

148 files changed

+6298
-5968
lines changed

encodings/alp/src/alp/compute/between.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ impl BetweenKernel for ALPVTable {
4545
match_each_alp_float_ptype!(array.ptype(), |F| {
4646
between_impl::<F>(
4747
array,
48-
F::try_from(lower)?,
49-
F::try_from(upper)?,
48+
F::try_from(&lower)?,
49+
F::try_from(&upper)?,
5050
nullability,
5151
options,
5252
)

encodings/alp/src/alp/compute/compare.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use vortex_array::register_kernel;
1515
use vortex_dtype::NativePType;
1616
use vortex_error::VortexResult;
1717
use vortex_error::vortex_bail;
18-
use vortex_scalar::PrimitiveScalar;
18+
use vortex_error::vortex_err;
1919
use vortex_scalar::Scalar;
2020

2121
use crate::ALPArray;
@@ -42,7 +42,13 @@ impl CompareKernel for ALPVTable {
4242
}
4343

4444
if let Some(const_scalar) = rhs.as_constant() {
45-
let pscalar = PrimitiveScalar::try_from(&const_scalar)?;
45+
let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
46+
vortex_err!(
47+
"ALP Compare RHS had the wrong type {}, expected {}",
48+
const_scalar,
49+
const_scalar.dtype()
50+
)
51+
})?;
4652

4753
match_each_alp_float_ptype!(pscalar.ptype(), |T| {
4854
match pscalar.typed_value::<T>() {

encodings/alp/src/alp/ops.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,8 @@ impl OperationsVTable<ALPVTable> for ALPVTable {
2222
let encoded_val = array.encoded().scalar_at(index)?;
2323

2424
Ok(match_each_alp_float_ptype!(array.ptype(), |T| {
25-
let encoded_val: <T as ALPFloat>::ALPInt = encoded_val
26-
.as_ref()
27-
.try_into()
28-
.vortex_expect("invalid ALPInt");
25+
let encoded_val: <T as ALPFloat>::ALPInt =
26+
(&encoded_val).try_into().vortex_expect("invalid ALPInt");
2927
Scalar::primitive(
3028
<T as ALPFloat>::decode_single(encoded_val, array.exponents()),
3129
array.dtype().nullability(),

encodings/alp/src/alp_rd/compute/take.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use vortex_array::arrays::TakeExecute;
99
use vortex_array::compute::fill_null;
1010
use vortex_error::VortexResult;
1111
use vortex_scalar::Scalar;
12-
use vortex_scalar::ScalarValue;
1312

1413
use crate::ALPRDArray;
1514
use crate::ALPRDVTable;
@@ -36,7 +35,7 @@ impl TakeExecute for ALPRDVTable {
3635
.transpose()?;
3736
let right_parts = fill_null(
3837
&array.right_parts().take(indices.to_array())?,
39-
&Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
38+
&Scalar::zero_value(array.right_parts().dtype()),
4039
)?;
4140

4241
Ok(Some(

encodings/datetime-parts/src/compute/rules.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ fn try_extract_days_constant(array: &ArrayRef) -> Option<i64> {
171171
fn is_constant_zero(array: &ArrayRef) -> bool {
172172
array
173173
.as_opt::<ConstantVTable>()
174-
.is_some_and(|c| c.scalar().is_zero())
174+
.is_some_and(|c| c.scalar().is_zero() == Some(true))
175175
}
176176

177177
#[cfg(test)]

encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ impl CompareKernel for DecimalBytePartsVTable {
4646
.vortex_expect("checked for null in entry func");
4747

4848
match decimal_value_wrapper_to_primitive(rhs_decimal, lhs.msp.as_primitive_typed().ptype())
49-
.map(|value| Scalar::new(scalar_type.clone(), value))
5049
{
51-
Ok(encoded_scalar) => {
50+
Ok(value) => {
51+
let encoded_scalar = Scalar::try_new(scalar_type, Some(value))?;
5252
let encoded_const = ConstantArray::new(encoded_scalar, rhs.len());
5353
compare(&lhs.msp, &encoded_const.to_array(), operator).map(Some)
5454
}
@@ -165,7 +165,10 @@ mod tests {
165165
)
166166
.unwrap()
167167
.to_array();
168-
let rhs = ConstantArray::new(Scalar::new(dtype, DecimalValue::I64(400).into()), lhs.len());
168+
let rhs = ConstantArray::new(
169+
Scalar::try_new(dtype, Some(DecimalValue::I64(400).into())).unwrap(),
170+
lhs.len(),
171+
);
169172

170173
let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
171174

@@ -215,10 +218,11 @@ mod tests {
215218
.to_array();
216219
// This cannot be converted to a i32.
217220
let rhs = ConstantArray::new(
218-
Scalar::new(
221+
Scalar::try_new(
219222
dtype.clone(),
220-
DecimalValue::I128(-9999999999999965304).into(),
221-
),
223+
Some(DecimalValue::I128(-9999999999999965304).into()),
224+
)
225+
.unwrap(),
222226
lhs.len(),
223227
);
224228

@@ -236,7 +240,7 @@ mod tests {
236240

237241
// This cannot be converted to a i32.
238242
let rhs = ConstantArray::new(
239-
Scalar::new(dtype, DecimalValue::I128(9999999999999965304).into()),
243+
Scalar::try_new(dtype, Some(DecimalValue::I128(9999999999999965304).into())).unwrap(),
240244
lhs.len(),
241245
);
242246

encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use vortex_error::vortex_bail;
4444
use vortex_error::vortex_ensure;
4545
use vortex_scalar::DecimalValue;
4646
use vortex_scalar::Scalar;
47+
use vortex_scalar::ScalarValue;
4748
use vortex_session::VortexSession;
4849

4950
use crate::decimal_byte_parts::compute::kernel::PARENT_KERNELS;
@@ -285,10 +286,10 @@ impl OperationsVTable<DecimalBytePartsVTable> for DecimalBytePartsVTable {
285286
let primitive_scalar = scalar.as_primitive();
286287
// TODO(joe): extend this to support multiple parts.
287288
let value = primitive_scalar.as_::<i64>().vortex_expect("non-null");
288-
Ok(Scalar::new(
289+
Scalar::try_new(
289290
array.dtype.clone(),
290-
DecimalValue::I64(value).into(),
291-
))
291+
Some(ScalarValue::Decimal(DecimalValue::I64(value))),
292+
)
292293
}
293294
}
294295

@@ -319,6 +320,7 @@ mod tests {
319320
use vortex_dtype::Nullability;
320321
use vortex_scalar::DecimalValue;
321322
use vortex_scalar::Scalar;
323+
use vortex_scalar::ScalarValue;
322324

323325
use crate::DecimalBytePartsArray;
324326

@@ -339,11 +341,15 @@ mod tests {
339341

340342
assert_eq!(Scalar::null(dtype.clone()), array.scalar_at(0).unwrap());
341343
assert_eq!(
342-
Scalar::new(dtype.clone(), DecimalValue::I64(200).into()),
344+
Scalar::try_new(
345+
dtype.clone(),
346+
Some(ScalarValue::Decimal(DecimalValue::I64(200)))
347+
)
348+
.unwrap(),
343349
array.scalar_at(1).unwrap()
344350
);
345351
assert_eq!(
346-
Scalar::new(dtype, DecimalValue::I64(400).into()),
352+
Scalar::try_new(dtype, Some(ScalarValue::Decimal(DecimalValue::I64(400)))).unwrap(),
347353
array.scalar_at(2).unwrap()
348354
);
349355
}

encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ mod tests {
262262
.iter()
263263
.enumerate()
264264
.for_each(|(i, v)| {
265-
let scalar: u16 = unpack_single(&compressed, i).try_into().unwrap();
265+
let scalar: u16 = (&unpack_single(&compressed, i)).try_into().unwrap();
266266
assert_eq!(scalar, *v);
267267
});
268268
}

encodings/fastlanes/src/for/array/for_compress.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,7 @@ mod test {
175175
.iter()
176176
.enumerate()
177177
.for_each(|(i, v)| {
178-
assert_eq!(
179-
*v,
180-
i8::try_from(compressed.scalar_at(i).unwrap().as_ref()).unwrap()
181-
);
178+
assert_eq!(*v, i8::try_from(&compressed.scalar_at(i).unwrap()).unwrap());
182179
});
183180
assert_arrays_eq!(decompressed, array);
184181
Ok(())

encodings/fastlanes/src/for/compute/compare.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ use vortex_error::VortexError;
1919
use vortex_error::VortexExpect as _;
2020
use vortex_error::VortexResult;
2121
use vortex_scalar::PValue;
22-
use vortex_scalar::PrimitiveScalar;
2322
use vortex_scalar::Scalar;
2423

2524
use crate::FoRArray;
@@ -33,7 +32,7 @@ impl CompareKernel for FoRVTable {
3332
operator: Operator,
3433
) -> VortexResult<Option<ArrayRef>> {
3534
if let Some(constant) = rhs.as_constant()
36-
&& let Ok(constant) = PrimitiveScalar::try_from(&constant)
35+
&& let Some(constant) = constant.as_primitive_opt()
3736
{
3837
match_each_integer_ptype!(constant.ptype(), |T| {
3938
return compare_constant(

0 commit comments

Comments
 (0)