Skip to content

Commit fb6bbcf

Browse files
committed
fix minor things
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent b59e8b8 commit fb6bbcf

5 files changed

Lines changed: 131 additions & 68 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,14 @@ pub mod vortex_tensor::encodings
44

55
pub mod vortex_tensor::encodings::turboquant
66

7-
pub mod vortex_tensor::encodings::turboquant::scheme
8-
9-
pub struct vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
10-
11-
impl core::clone::Clone for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
12-
13-
pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::clone(&self) -> vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
14-
15-
impl core::cmp::Eq for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
16-
17-
impl core::cmp::PartialEq for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
18-
19-
pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::eq(&self, other: &vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme) -> bool
20-
21-
impl core::fmt::Debug for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
22-
23-
pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
24-
25-
impl core::marker::Copy for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
26-
27-
impl core::marker::StructuralPartialEq for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
28-
29-
impl vortex_compressor::scheme::Scheme for vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
30-
31-
pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::compress(&self, compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
32-
33-
pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::expected_compression_ratio(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult<f64>
34-
35-
pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::matches(&self, canonical: &vortex_array::canonical::Canonical) -> bool
36-
37-
pub fn vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme::scheme_name(&self) -> &'static str
38-
39-
pub static vortex_tensor::encodings::turboquant::scheme::TURBOQUANT_SCHEME: vortex_tensor::encodings::turboquant::scheme::TurboQuantScheme
40-
417
pub struct vortex_tensor::encodings::turboquant::TurboQuant
428

439
impl vortex_tensor::encodings::turboquant::TurboQuant
4410

4511
pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId
4612

13+
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef>
14+
4715
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant
4816

4917
pub fn vortex_tensor::encodings::turboquant::TurboQuant::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuant
@@ -56,7 +24,7 @@ impl vortex_array::array::vtable::VTable for vortex_tensor::encodings::turboquan
5624

5725
pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_tensor::encodings::turboquant::TurboQuantData
5826

59-
pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_tensor::encodings::turboquant::array::TurboQuantMetadata
27+
pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_tensor::encodings::turboquant::TurboQuantMetadata
6028

6129
pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant
6230

@@ -176,9 +144,47 @@ impl vortex_array::array::IntoArray for vortex_tensor::encodings::turboquant::Tu
176144

177145
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::into_array(self) -> vortex_array::array::erased::ArrayRef
178146

179-
pub const vortex_tensor::encodings::turboquant::FIXED_SHAPE_TENSOR_EXT_ID: &str
147+
pub struct vortex_tensor::encodings::turboquant::TurboQuantMetadata
148+
149+
pub vortex_tensor::encodings::turboquant::TurboQuantMetadata::bit_width: u8
150+
151+
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantMetadata
152+
153+
pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantMetadata
154+
155+
impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantMetadata
156+
157+
pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
158+
159+
pub struct vortex_tensor::encodings::turboquant::TurboQuantScheme
160+
161+
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantScheme
162+
163+
pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantScheme
164+
165+
impl core::cmp::Eq for vortex_tensor::encodings::turboquant::TurboQuantScheme
166+
167+
impl core::cmp::PartialEq for vortex_tensor::encodings::turboquant::TurboQuantScheme
168+
169+
pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::eq(&self, other: &vortex_tensor::encodings::turboquant::TurboQuantScheme) -> bool
170+
171+
impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantScheme
172+
173+
pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
174+
175+
impl core::marker::Copy for vortex_tensor::encodings::turboquant::TurboQuantScheme
176+
177+
impl core::marker::StructuralPartialEq for vortex_tensor::encodings::turboquant::TurboQuantScheme
178+
179+
impl vortex_compressor::scheme::Scheme for vortex_tensor::encodings::turboquant::TurboQuantScheme
180+
181+
pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::compress(&self, compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
182+
183+
pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::expected_compression_ratio(&self, _compressor: &vortex_compressor::compressor::CascadingCompressor, data: &mut vortex_compressor::stats::cache::ArrayAndStats, _ctx: vortex_compressor::ctx::CompressorContext) -> vortex_error::VortexResult<f64>
184+
185+
pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::matches(&self, canonical: &vortex_array::canonical::Canonical) -> bool
180186

181-
pub const vortex_tensor::encodings::turboquant::VECTOR_EXT_ID: &str
187+
pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self) -> &'static str
182188

183189
pub fn vortex_tensor::encodings::turboquant::initialize(session: &mut vortex_session::VortexSession)
184190

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::Arc;
5+
46
use vortex_array::ArrayRef;
57
use vortex_array::dtype::DType;
68
use vortex_array::dtype::Nullability;
@@ -22,7 +24,7 @@ use crate::utils::extension_list_size;
2224
/// extension arrays. It stores quantized coordinate codes and per-vector norms, along with shared
2325
/// codebook centroids and SRHT rotation signs.
2426
///
25-
/// See the [module docs](super) for algorithmic details.
27+
/// See the [module docs](crate::encodings::turboquant) for algorithmic details.
2628
///
2729
/// A degenerate TurboQuant array has zero rows and `bit_width == 0`, with all slots empty.
2830
#[derive(Clone, Debug)]
@@ -128,7 +130,7 @@ impl TurboQuantData {
128130
let bit_width = if centroids.is_empty() {
129131
0
130132
} else {
131-
// Guaranteed to be 0-8 by validate().
133+
// Guaranteed to be 1-8 by validate().
132134
#[expect(clippy::cast_possible_truncation)]
133135
{
134136
centroids.len().trailing_zeros() as u8
@@ -162,6 +164,19 @@ impl TurboQuantData {
162164
) -> VortexResult<()> {
163165
let ext = TurboQuant::validate_dtype(dtype)?;
164166
let dimension = extension_list_size(ext)?;
167+
let padded_dim = dimension.next_power_of_two();
168+
169+
// Codes must be a FixedSizeList<u8> with list_size == padded_dim.
170+
let expected_codes_dtype = DType::FixedSizeList(
171+
Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), // FIX THIS!!!
172+
padded_dim,
173+
dtype.nullability(),
174+
);
175+
vortex_ensure_eq!(
176+
*codes.dtype(),
177+
expected_codes_dtype,
178+
"codes dtype does not match expected {expected_codes_dtype}",
179+
);
165180

166181
let num_rows = codes.len();
167182
vortex_ensure_eq!(
@@ -206,27 +221,25 @@ impl TurboQuantData {
206221

207222
// Norms dtype must match the element ptype of the Vector.
208223
let element_ptype = extension_element_ptype(ext)?;
209-
let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable);
224+
let expected_norms_dtype = DType::Primitive(element_ptype, Nullability::NonNullable); // FIX THIS!!!
210225
vortex_ensure_eq!(
211226
*norms.dtype(),
212227
expected_norms_dtype,
213-
"norms dtype does not match expected {expected_norms_dtype} \
214-
(must match Vector element type)",
228+
"norms dtype does not match expected (must match Vector element type)",
215229
);
216230

217231
// Centroids are always f32 regardless of element type.
218232
let centroids_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
219233
vortex_ensure_eq!(
220234
*centroids.dtype(),
221235
centroids_dtype,
222-
"centroids dtype must be non-nullable f32",
236+
"centroids dtype must be non-nullable f32",
223237
);
224238

225239
// Rotation signs count must be 3 * padded_dim.
226-
let padded_dim = dimension.next_power_of_two() as usize;
227240
vortex_ensure_eq!(
228241
rotation_signs.len(),
229-
3 * padded_dim,
242+
3 * padded_dim as usize,
230243
"rotation_signs length does not match expected 3 * {padded_dim}",
231244
);
232245

vortex-tensor/src/encodings/turboquant/array/scheme.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::utils::extension_list_size;
2828
/// use vortex_tensor::encodings::turboquant::TurboQuantScheme;
2929
///
3030
/// let compressor = BtrBlocksCompressorBuilder::default()
31-
/// .with_scheme(&TurboQuantScheme)
31+
/// .with_new_scheme(&TurboQuantScheme)
3232
/// .build();
3333
/// ```
3434
///

vortex-tensor/src/encodings/turboquant/decompress.rs

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,24 @@
33

44
//! TurboQuant decoding (dequantization) logic.
55
6+
use num_traits::FromPrimitive;
7+
use num_traits::Zero;
68
use vortex_array::Array;
79
use vortex_array::ArrayRef;
810
use vortex_array::ExecutionCtx;
911
use vortex_array::IntoArray;
1012
use vortex_array::arrays::ExtensionArray;
1113
use vortex_array::arrays::FixedSizeListArray;
1214
use vortex_array::arrays::PrimitiveArray;
15+
use vortex_array::dtype::NativePType;
16+
use vortex_array::match_each_float_ptype;
1317
use vortex_array::validity::Validity;
1418
use vortex_buffer::BufferMut;
1519
use vortex_error::VortexResult;
1620

1721
use crate::encodings::turboquant::TurboQuant;
1822
use crate::encodings::turboquant::array::rotation::RotationMatrix;
23+
use crate::utils::extension_element_ptype;
1924

2025
/// Decompress a `TurboQuantArray` into a [`Vector`] extension array.
2126
///
@@ -31,19 +36,23 @@ pub fn execute_decompress(
3136
let padded_dim = array.padded_dim() as usize;
3237
let num_rows = array.norms().len();
3338
let ext_dtype = array.dtype.as_extension().clone();
39+
let element_ptype = extension_element_ptype(&ext_dtype)?;
3440

3541
if num_rows == 0 {
36-
let elements = PrimitiveArray::empty::<f32>(ext_dtype.storage_dtype().nullability());
37-
let fsl = FixedSizeListArray::try_new(
38-
elements.into_array(),
39-
array.dimension(),
40-
Validity::NonNullable,
41-
0,
42-
)?;
43-
return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array());
42+
let nn = vortex_array::dtype::Nullability::NonNullable;
43+
match_each_float_ptype!(element_ptype, |T| {
44+
let elements = PrimitiveArray::empty::<T>(nn);
45+
let fsl = FixedSizeListArray::try_new(
46+
elements.into_array(),
47+
array.dimension(),
48+
Validity::NonNullable,
49+
0,
50+
)?;
51+
return Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array());
52+
})
4453
}
4554

46-
// Read stored centroids -- no recomputation.
55+
// Read stored centroids (always f32).
4756
let centroids_prim = array.centroids().clone().execute::<PrimitiveArray>(ctx)?;
4857
let centroids = centroids_prim.as_slice::<f32>();
4958

@@ -61,11 +70,47 @@ pub fn execute_decompress(
6170
let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive();
6271
let indices = codes_prim.as_slice::<u8>();
6372

73+
// Read norms in their native precision.
6474
let norms_prim = array.norms().clone().execute::<PrimitiveArray>(ctx)?;
65-
let norms = norms_prim.as_slice::<f32>();
6675

67-
// MSE decode: dequantize -> inverse rotate -> scale by norm.
68-
let mut output = BufferMut::<f32>::with_capacity(num_rows * dim);
76+
// MSE decode: dequantize (f32) -> inverse rotate (f32) -> scale by norm -> cast to T.
77+
// The rotation and centroid lookup always happen in f32. The final output is cast to the
78+
// Vector's element type to match the original storage dtype.
79+
match_each_float_ptype!(element_ptype, |T| {
80+
decompress_typed::<T>(
81+
&norms_prim,
82+
centroids,
83+
&rotation,
84+
indices,
85+
dim,
86+
padded_dim,
87+
num_rows,
88+
)
89+
.and_then(|elements| {
90+
let fsl = FixedSizeListArray::try_new(
91+
elements.into_array(),
92+
array.dimension(),
93+
Validity::NonNullable,
94+
num_rows,
95+
)?;
96+
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
97+
})
98+
})
99+
}
100+
101+
/// Typed decompress: reads norms as `T`, dequantizes in f32, and produces output as `T`.
102+
fn decompress_typed<T: NativePType + FromPrimitive + Zero>(
103+
norms_prim: &PrimitiveArray,
104+
centroids: &[f32],
105+
rotation: &RotationMatrix,
106+
indices: &[u8],
107+
dim: usize,
108+
padded_dim: usize,
109+
num_rows: usize,
110+
) -> VortexResult<PrimitiveArray> {
111+
let norms = norms_prim.as_slice::<T>();
112+
113+
let mut output = BufferMut::<T>::with_capacity(num_rows * dim);
69114
let mut dequantized = vec![0.0f32; padded_dim];
70115
let mut unrotated = vec![0.0f32; padded_dim];
71116

@@ -80,18 +125,14 @@ pub fn execute_decompress(
80125
rotation.inverse_rotate(&dequantized, &mut unrotated);
81126

82127
for idx in 0..dim {
83-
unrotated[idx] *= norm;
128+
// Convert f32 dequantized value to T, then scale by the native-precision norm.
129+
let val = T::from_f32(unrotated[idx]).unwrap_or_else(T::zero) * norm;
130+
output.push(val);
84131
}
85-
86-
output.extend_from_slice(&unrotated[..dim]);
87132
}
88133

89-
let elements = PrimitiveArray::new::<f32>(output.freeze(), Validity::NonNullable);
90-
let fsl = FixedSizeListArray::try_new(
91-
elements.into_array(),
92-
array.dimension(),
134+
Ok(PrimitiveArray::new::<T>(
135+
output.freeze(),
93136
Validity::NonNullable,
94-
num_rows,
95-
)?;
96-
Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array())
137+
))
97138
}

vortex-tensor/src/encodings/turboquant/tests.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
14
use std::sync::LazyLock;
25

36
use rand::SeedableRng;

0 commit comments

Comments
 (0)