Skip to content

Commit af3cb41

Browse files
committed
fix nullability handling
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent cfc3383 commit af3cb41

5 files changed

Lines changed: 257 additions & 45 deletions

File tree

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ pub struct TurboQuantData {
3737

3838
/// Child arrays stored as slots. See [`Slot`] for positions:
3939
///
40-
/// - [`Codes`](Slot::Codes): `FixedSizeListArray<u8>` with `list_size == padded_dim`. Each row
41-
/// holds one u8 centroid index per padded coordinate. The cascade compressor handles packing
42-
/// to the actual `bit_width` on disk. The validity of the entire array is stored with this.
40+
/// - [`Codes`](Slot::Codes): Non-nullable `FixedSizeListArray<u8>` with
41+
/// `list_size == padded_dim`. Each row holds one u8 centroid index per padded coordinate.
42+
/// Null vectors are represented by all-zero codes. The cascade compressor handles packing
43+
/// to the actual `bit_width` on disk.
4344
///
4445
/// - [`Norms`](Slot::Norms): Per-vector L2 norms, one per row. The dtype matches the element
45-
/// type of the Vector (e.g., f64 norms for f64 vectors). Exact norms are stored during
46-
/// compression, enabling O(1) L2 norm readthrough without decompression.
46+
/// type of the Vector (e.g., f64 norms for f64 vectors) and carries the nullability of the
47+
/// parent dtype. Null vectors have null norms. This child determines the validity of the
48+
/// entire TurboQuant array. Exact norms enable O(1) L2 norm readthrough without decompression.
4749
///
4850
/// - [`Centroids`](Slot::Centroids): `PrimitiveArray<f32>` codebook with `2^bit_width` entries
4951
/// that is shared across all rows. We always store these as f32 regardless of the input
@@ -101,10 +103,11 @@ impl TurboQuantData {
101103
///
102104
/// - `dtype` is a [`Vector`](crate::vector::Vector) extension type whose storage list size
103105
/// is >= 3.
104-
/// - `codes` is a `FixedSizeListArray<u8>` with `list_size == padded_dim` and
105-
/// `codes.len() == norms.len()`.
106+
/// - `codes` is a non-nullable `FixedSizeListArray<u8>` with `list_size == padded_dim` and
107+
/// `codes.len() == norms.len()`. Null vectors are represented by all-zero codes.
106108
/// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage
107-
/// dtype. This must match the validity of the `codes` array.
109+
/// dtype. The nullability must match `dtype.nullability()`. Norms carry the validity of the
110+
/// entire array, since null vectors have null norms.
108111
/// - `centroids` is a non-nullable `PrimitiveArray<f32>` whose length is a power of 2 in
109112
/// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays.
110113
/// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays.
@@ -166,11 +169,12 @@ impl TurboQuantData {
166169
let dimension = extension_list_size(ext)?;
167170
let padded_dim = dimension.next_power_of_two();
168171

169-
// Codes must be a FixedSizeList<u8> with list_size == padded_dim.
172+
// Codes must be a non-nullable FixedSizeList<u8> with list_size == padded_dim.
173+
// Null vectors are represented by all-zero codes since validity lives in the norms array.
170174
let expected_codes_dtype = DType::FixedSizeList(
171-
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), // FIX THIS!!!
175+
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)),
172176
padded_dim,
173-
dtype.nullability(),
177+
Nullability::NonNullable,
174178
);
175179
vortex_ensure_eq!(
176180
*codes.dtype(),
@@ -185,10 +189,6 @@ impl TurboQuantData {
185189
"norms length must match codes length",
186190
);
187191

188-
// TODO(connor): Should we check that the codes and norms have the same validity? We could
189-
// also make it so that norms holds the validity and any null vectors encoded as codes is
190-
// just 0...
191-
192192
// Degenerate (empty) case: all children must be empty, and bit_width is 0.
193193
if num_rows == 0 {
194194
vortex_ensure!(
@@ -219,13 +219,14 @@ impl TurboQuantData {
219219
"derived bit_width must be 1-8, got {bit_width}"
220220
);
221221

222-
// Norms dtype must match the element ptype of the Vector.
222+
// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
223+
// Norms carry the validity of the entire TurboQuant array.
223224
let element_ptype = extension_element_ptype(ext)?;
224-
let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); // FIX THIS!!!
225+
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
225226
vortex_ensure_eq!(
226227
*norms.dtype(),
227228
expected_norms_dtype,
228-
"norms dtype does not match expected (must match Vector element type)",
229+
"norms dtype does not match expected {expected_norms_dtype}",
229230
);
230231

231232
// Centroids are always f32 regardless of element type.

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ struct QuantizationResult {
7676
rotation: RotationMatrix,
7777
centroids: Vec<f32>,
7878
all_indices: BufferMut<u8>,
79-
/// Native-precision norms (matching the Vector element type).
79+
/// Native-precision norms (matching the Vector element type). Carries validity: null vectors
80+
/// have null norms.
8081
norms_array: ArrayRef,
8182
padded_dim: usize,
8283
}
@@ -85,19 +86,22 @@ struct QuantizationResult {
8586
/// normalize/rotate/quantize all rows.
8687
///
8788
/// Norms are computed in the native element precision via the [`L2Norm`] scalar function.
88-
/// The rotation and centroid lookup happen in f32.
89+
/// The rotation and centroid lookup happen in f32. Null rows (per the input validity) produce
90+
/// all-zero codes.
8991
#[allow(clippy::cast_possible_truncation)]
9092
fn turboquant_quantize_core(
9193
ext: &ExtensionArray,
9294
fsl: &FixedSizeListArray,
9395
seed: u64,
9496
bit_width: u8,
97+
validity: &Validity,
9598
ctx: &mut ExecutionCtx,
9699
) -> VortexResult<QuantizationResult> {
97100
let dimension = fsl.list_size() as usize;
98101
let num_rows = fsl.len();
99102

100-
// Compute native-precision norms via the L2Norm scalar fn.
103+
// Compute native-precision norms via the L2Norm scalar fn. L2Norm propagates validity from
104+
// the input, so null vectors get null norms automatically.
101105
let norms_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, ext.as_ref().clone(), num_rows)?;
102106
let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?;
103107
let norms_prim: PrimitiveArray = norms_array.to_canonical()?.into_primitive();
@@ -125,6 +129,12 @@ fn turboquant_quantize_core(
125129

126130
let f32_slice = f32_elements.as_slice::<f32>();
127131
for row in 0..num_rows {
132+
// Null vectors get all-zero codes.
133+
if !validity.is_valid(row)? {
134+
all_indices.extend(std::iter::repeat_n(0u8, padded_dim));
135+
continue;
136+
}
137+
128138
let x = &f32_slice[row * dimension..(row + 1) * dimension];
129139
let norm = f32_norms[row];
130140

@@ -189,12 +199,10 @@ fn build_turboquant(
189199
)
190200
}
191201

192-
/// Encode a [`Vector`] extension array into a `TurboQuantArray`.
193-
///
194-
/// The input must be a non-nullable [`Vector`] extension array. TurboQuant is a lossy encoding
195-
/// that does not preserve null positions; callers must handle validity externally.
202+
/// Encode a [`Vector`](crate::vector::Vector) extension array into a `TurboQuantArray`.
196203
///
197-
/// [`Vector`]: crate::vector::Vector
204+
/// Nullable inputs are supported: null vectors get all-zero codes and null norms. The validity
205+
/// of the resulting TurboQuant array is carried by the norms child.
198206
pub fn turboquant_encode(
199207
ext: &ExtensionArray,
200208
config: &TurboQuantConfig,
@@ -204,10 +212,6 @@ pub fn turboquant_encode(
204212
let storage = ext.storage_array();
205213
let fsl = storage.to_canonical()?.into_fixed_size_list();
206214

207-
vortex_ensure!(
208-
fsl.dtype().nullability() == Nullability::NonNullable,
209-
"TurboQuant requires non-nullable input, got nullable FixedSizeListArray"
210-
);
211215
vortex_ensure!(
212216
config.bit_width >= 1 && config.bit_width <= 8,
213217
"bit_width must be 1-8, got {}",
@@ -228,10 +232,11 @@ pub fn turboquant_encode(
228232
0,
229233
)?;
230234

231-
// Norms dtype matches the element type.
235+
// Norms dtype matches the element type and carries the parent's nullability.
232236
let element_ptype = fsl.elements().dtype().as_ptype();
237+
let norms_nullability = ext_dtype.nullability();
233238
let empty_norms: ArrayRef = match_each_float_ptype!(element_ptype, |T| {
234-
PrimitiveArray::empty::<T>(Nullability::NonNullable).into_array()
239+
PrimitiveArray::empty::<T>(norms_nullability).into_array()
235240
});
236241

237242
let empty_centroids = PrimitiveArray::empty::<f32>(Nullability::NonNullable);
@@ -246,8 +251,9 @@ pub fn turboquant_encode(
246251
.into_array());
247252
}
248253

254+
let validity = ext.as_ref().validity()?;
249255
let seed = config.seed.unwrap_or(42);
250-
let core = turboquant_quantize_core(ext, &fsl, seed, config.bit_width, ctx)?;
256+
let core = turboquant_quantize_core(ext, &fsl, seed, config.bit_width, &validity, ctx)?;
251257

252258
Ok(build_turboquant(&fsl, core, ext_dtype)?.into_array())
253259
}

vortex-tensor/src/encodings/turboquant/decompress.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use vortex_array::arrays::ExtensionArray;
1313
use vortex_array::arrays::FixedSizeListArray;
1414
use vortex_array::arrays::PrimitiveArray;
1515
use vortex_array::dtype::NativePType;
16+
use vortex_array::dtype::Nullability;
1617
use vortex_array::match_each_float_ptype;
1718
use vortex_array::validity::Validity;
1819
use vortex_buffer::BufferMut;
@@ -39,15 +40,17 @@ pub fn execute_decompress(
3940
let element_ptype = extension_element_ptype(&ext_dtype)?;
4041

4142
if num_rows == 0 {
42-
let nn = vortex_array::dtype::Nullability::NonNullable;
43+
let fsl_validity = Validity::from(ext_dtype.storage_dtype().nullability());
44+
4345
match_each_float_ptype!(element_ptype, |T| {
44-
let elements = PrimitiveArray::empty::<T>(nn);
46+
let elements = PrimitiveArray::empty::<T>(Nullability::NonNullable);
4547
let fsl = FixedSizeListArray::try_new(
4648
elements.into_array(),
4749
array.dimension(),
48-
Validity::NonNullable,
50+
fsl_validity,
4951
0,
5052
)?;
53+
5154
return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array());
5255
})
5356
}
@@ -70,8 +73,9 @@ pub fn execute_decompress(
7073
let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive();
7174
let indices = codes_prim.as_slice::<u8>();
7275

73-
// Read norms in their native precision.
76+
// Read norms in their native precision. Norms carry the validity of the array.
7477
let norms_prim = array.norms().clone().execute::<PrimitiveArray>(ctx)?;
78+
let output_validity = array.norms().validity()?;
7579

7680
// MSE decode: dequantize (f32) -> inverse rotate (f32) -> scale by norm -> cast to T.
7781
// The rotation and centroid lookup always happen in f32. The final output is cast to the
@@ -90,7 +94,7 @@ pub fn execute_decompress(
9094
let fsl = FixedSizeListArray::try_new(
9195
elements.into_array(),
9296
array.dimension(),
93-
Validity::NonNullable,
97+
output_validity,
9498
num_rows,
9599
)?;
96100
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())

0 commit comments

Comments
 (0)