Skip to content

Commit d4c1cdf

Browse files
committed
fix casting issues and other minor things
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 54b59b9 commit d4c1cdf

6 files changed

Lines changed: 37 additions & 60 deletions

File tree

vortex-tensor/public-api.lock

Lines changed: 9 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@ impl vortex_tensor::encodings::turboquant::TurboQuant
1010

1111
pub const vortex_tensor::encodings::turboquant::TurboQuant::ID: vortex_array::array::ArrayId
1212

13+
pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>
14+
1315
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef>
1416

1517
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant
@@ -24,8 +26,6 @@ impl vortex_array::array::vtable::VTable for vortex_tensor::encodings::turboquan
2426

2527
pub type vortex_tensor::encodings::turboquant::TurboQuant::ArrayData = vortex_tensor::encodings::turboquant::TurboQuantData
2628

27-
pub type vortex_tensor::encodings::turboquant::TurboQuant::Metadata = vortex_tensor::encodings::turboquant::TurboQuantMetadata
28-
2929
pub type vortex_tensor::encodings::turboquant::TurboQuant::OperationsVTable = vortex_tensor::encodings::turboquant::TurboQuant
3030

3131
pub type vortex_tensor::encodings::turboquant::TurboQuant::ValidityVTable = vortex_array::array::vtable::validity::ValidityVTableFromChild
@@ -38,35 +38,25 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer(_array: vortex_a
3838

3939
pub fn vortex_tensor::encodings::turboquant::TurboQuant::buffer_name(_array: vortex_array::array::view::ArrayView<'_, Self>, _idx: usize) -> core::option::Option<alloc::string::String>
4040

41-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantData>
42-
43-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Metadata>
44-
45-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::dtype(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::dtype::DType
41+
pub fn vortex_tensor::encodings::turboquant::TurboQuant::deserialize(&self, dtype: &vortex_array::dtype::DType, len: usize, metadata: &[u8], _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren, _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::ArrayData>
4642

4743
pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute(array: vortex_array::array::typed::Array<Self>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::executor::ExecutionResult>
4844

4945
pub fn vortex_tensor::encodings::turboquant::TurboQuant::execute_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: &vortex_array::array::erased::ArrayRef, child_idx: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::erased::ArrayRef>>
5046

5147
pub fn vortex_tensor::encodings::turboquant::TurboQuant::id(&self) -> vortex_array::array::ArrayId
5248

53-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::len(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> usize
54-
55-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::metadata(array: vortex_array::array::view::ArrayView<'_, Self>) -> vortex_error::VortexResult<Self::Metadata>
56-
5749
pub fn vortex_tensor::encodings::turboquant::TurboQuant::nbuffers(_array: vortex_array::array::view::ArrayView<'_, Self>) -> usize
5850

5951
pub fn vortex_tensor::encodings::turboquant::TurboQuant::reduce_parent(array: vortex_array::array::view::ArrayView<'_, Self>, parent: &vortex_array::array::erased::ArrayRef, child_idx: usize) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::erased::ArrayRef>>
6052

61-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::serialize(metadata: Self::Metadata) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
53+
pub fn vortex_tensor::encodings::turboquant::TurboQuant::serialize(array: vortex_array::array::view::ArrayView<'_, Self>) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
6254

6355
pub fn vortex_tensor::encodings::turboquant::TurboQuant::slot_name(_array: vortex_array::array::view::ArrayView<'_, Self>, idx: usize) -> alloc::string::String
6456

6557
pub fn vortex_tensor::encodings::turboquant::TurboQuant::slots(array: vortex_array::array::view::ArrayView<'_, Self>) -> &[core::option::Option<vortex_array::array::erased::ArrayRef>]
6658

67-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::stats(array: &vortex_tensor::encodings::turboquant::TurboQuantData) -> &vortex_array::stats::array::ArrayStats
68-
69-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::vtable(_array: &Self::ArrayData) -> &Self
59+
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate(&self, data: &Self::ArrayData, dtype: &vortex_array::dtype::DType, len: usize) -> vortex_error::VortexResult<()>
7060

7161
pub fn vortex_tensor::encodings::turboquant::TurboQuant::with_slots(array: &mut vortex_tensor::encodings::turboquant::TurboQuantData, slots: alloc::vec::Vec<core::option::Option<vortex_array::array::erased::ArrayRef>>) -> vortex_error::VortexResult<()>
7262

@@ -116,46 +106,26 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantData::codes(&self) -> &vo
116106

117107
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::dimension(&self) -> u32
118108

119-
pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> Self
109+
pub unsafe fn vortex_tensor::encodings::turboquant::TurboQuantData::new_unchecked(dtype: &vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> Self
120110

121111
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::norms(&self) -> &vortex_array::array::erased::ArrayRef
122112

123113
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::padded_dim(&self) -> u32
124114

125115
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::rotation_signs(&self) -> &vortex_array::array::erased::ArrayRef
126116

127-
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<Self>
117+
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::try_new(dtype: &vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<Self>
128118

129119
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::validate(dtype: &vortex_array::dtype::DType, codes: &vortex_array::array::erased::ArrayRef, norms: &vortex_array::array::erased::ArrayRef, centroids: &vortex_array::array::erased::ArrayRef, rotation_signs: &vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<()>
130120

131121
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantData
132122

133123
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantData
134124

135-
impl core::convert::From<vortex_tensor::encodings::turboquant::TurboQuantData> for vortex_array::array::erased::ArrayRef
136-
137-
pub fn vortex_array::array::erased::ArrayRef::from(value: vortex_tensor::encodings::turboquant::TurboQuantData) -> vortex_array::array::erased::ArrayRef
138-
139125
impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantData
140126

141127
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
142128

143-
impl vortex_array::array::IntoArray for vortex_tensor::encodings::turboquant::TurboQuantData
144-
145-
pub fn vortex_tensor::encodings::turboquant::TurboQuantData::into_array(self) -> vortex_array::array::erased::ArrayRef
146-
147-
pub struct vortex_tensor::encodings::turboquant::TurboQuantMetadata
148-
149-
pub vortex_tensor::encodings::turboquant::TurboQuantMetadata::bit_width: u8
150-
151-
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantMetadata
152-
153-
pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::clone(&self) -> vortex_tensor::encodings::turboquant::TurboQuantMetadata
154-
155-
impl core::fmt::Debug for vortex_tensor::encodings::turboquant::TurboQuantMetadata
156-
157-
pub fn vortex_tensor::encodings::turboquant::TurboQuantMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
158-
159129
pub struct vortex_tensor::encodings::turboquant::TurboQuantScheme
160130

161131
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuantScheme
@@ -186,10 +156,10 @@ pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::matches(&self, ca
186156

187157
pub fn vortex_tensor::encodings::turboquant::TurboQuantScheme::scheme_name(&self) -> &'static str
188158

189-
pub fn vortex_tensor::encodings::turboquant::initialize(session: &mut vortex_session::VortexSession)
190-
191159
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: &vortex_array::arrays::extension::vtable::ExtensionArray, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
192160

161+
pub type vortex_tensor::encodings::turboquant::TurboQuantArray = vortex_array::array::typed::Array<vortex_tensor::encodings::turboquant::TurboQuant>
162+
193163
pub mod vortex_tensor::fixed_shape
194164

195165
pub struct vortex_tensor::fixed_shape::FixedShapeTensor

vortex-tensor/src/encodings/turboquant/array/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,3 +11,15 @@ pub(crate) mod centroids;
1111
pub(crate) mod rotation;
1212

1313
pub(crate) mod scheme;
14+
15+
use num_traits::Float;
16+
use num_traits::FromPrimitive;
17+
use vortex_error::VortexExpect;
18+
19+
/// Convert an f32 value to a float type `T`.
20+
///
21+
/// `FromPrimitive::from_f32` is infallible for all Vortex float types: f16 saturates via the
22+
/// inherent `f16::from_f32()`, f32 is identity, f64 is lossless widening.
23+
pub(crate) fn float_from_f32<T: Float + FromPrimitive>(v: f32) -> T {
24+
FromPrimitive::from_f32(v).vortex_expect("f32-to-float conversion is infallible")
25+
}

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
//! TurboQuant encoding (quantization) logic.
55
6+
use num_traits::ToPrimitive;
67
use vortex_array::ArrayRef;
78
use vortex_array::ExecutionCtx;
89
use vortex_array::IntoArray;
@@ -15,6 +16,7 @@ use vortex_array::dtype::PType;
1516
use vortex_array::match_each_float_ptype;
1617
use vortex_array::validity::Validity;
1718
use vortex_buffer::BufferMut;
19+
use vortex_error::VortexExpect;
1820
use vortex_error::VortexResult;
1921
use vortex_error::vortex_bail;
2022
use vortex_error::vortex_ensure;
@@ -105,14 +107,18 @@ fn turboquant_quantize_core(
105107
// the input, so null vectors get null norms automatically.
106108
let norms_sfn = L2Norm::try_new_array(&ApproxOptions::Exact, ext.as_ref().clone(), num_rows)?;
107109
let norms_array: ArrayRef = norms_sfn.into_array().execute(ctx)?;
108-
let norms_prim: PrimitiveArray = norms_array.to_canonical()?.into_primitive();
110+
let norms_prim: PrimitiveArray = norms_array.clone().execute(ctx)?;
109111

110112
// Extract f32 norms for the internal quantization loop.
111113
let f32_norms: Vec<f32> = match_each_float_ptype!(norms_prim.ptype(), |T| {
112114
norms_prim
113115
.as_slice::<T>()
114116
.iter()
115-
.map(|&v| num_traits::ToPrimitive::to_f32(&v).unwrap_or(0.0))
117+
.map(|&v| {
118+
// `ToPrimitive::to_f32` is infallible for all float types: f16 -> f32 is lossless,
119+
// f32 is identity, and f64 -> f32 saturates to +-inf.
120+
ToPrimitive::to_f32(&v).vortex_expect("float-to-f32 conversion is infallible")
121+
})
116122
.collect()
117123
});
118124

vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs

Lines changed: 3 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,6 @@
3030
//! usually sufficient -- the relative ordering of cosine similarities is preserved
3131
//! even if the absolute values have bounded error.
3232
33-
use num_traits::FromPrimitive;
34-
use num_traits::Zero;
3533
use vortex_array::ArrayRef;
3634
use vortex_array::ArrayView;
3735
use vortex_array::ExecutionCtx;
@@ -44,19 +42,9 @@ use vortex_error::VortexResult;
4442
use vortex_error::vortex_ensure_eq;
4543

4644
use crate::encodings::turboquant::TurboQuant;
45+
use crate::encodings::turboquant::array::float_from_f32;
4746
use crate::utils::extension_element_ptype;
4847

49-
/// Convert an f32 value to `T`, returning `T::zero()` if the conversion fails.
50-
///
51-
/// This helper exists because `half::f16` has an inherent `from_f32` method that shadows
52-
/// the [`FromPrimitive`] trait method, causing compilation errors when used inside
53-
/// [`match_each_float_ptype!`].
54-
#[inline]
55-
fn f32_to_t<T: FromPrimitive + Zero>(v: f32) -> T {
56-
// TODO(connor): Is this actually correct? How should we handle f64 overflow?
57-
FromPrimitive::from_f32(v).unwrap_or_else(T::zero)
58-
}
59-
6048
/// Compute the per-row unit-norm dot products in f32 (centroids are always f32).
6149
///
6250
/// Returns a `Vec<f32>` of length `num_rows`.
@@ -124,7 +112,7 @@ pub fn cosine_similarity_quantized_column(
124112
let mut result = BufferMut::<T>::with_capacity(dots.len());
125113
for &dot in &dots {
126114
// SAFETY: `result` was allocated with capacity `dots.len()`, and this loop pushes exactly one element per entry of `dots`.
127-
unsafe { result.push_unchecked(f32_to_t(dot)) };
115+
unsafe { result.push_unchecked(float_from_f32(dot)) };
128116
}
129117

130118
// SAFETY: `result` has the same length as the input arrays, matching `validity`.
@@ -164,7 +152,7 @@ pub fn dot_product_quantized_column(
164152

165153
let mut result = BufferMut::<T>::with_capacity(num_rows);
166154
for row in 0..num_rows {
167-
let dot_t: T = f32_to_t(dots[row]);
155+
let dot_t: T = float_from_f32(dots[row]);
168156
// SAFETY: `result` was allocated with capacity `num_rows`, and this loop pushes exactly one element per row.
169157
unsafe { result.push_unchecked(na[row] * nb[row] * dot_t) };
170158
}

vortex-tensor/src/encodings/turboquant/decompress.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,8 @@
33

44
//! TurboQuant decoding (dequantization) logic.
55
6+
use num_traits::Float;
67
use num_traits::FromPrimitive;
7-
use num_traits::Zero;
88
use vortex_array::Array;
99
use vortex_array::ArrayRef;
1010
use vortex_array::ExecutionCtx;
@@ -20,6 +20,7 @@ use vortex_buffer::BufferMut;
2020
use vortex_error::VortexResult;
2121

2222
use crate::encodings::turboquant::TurboQuant;
23+
use crate::encodings::turboquant::array::float_from_f32;
2324
use crate::encodings::turboquant::array::rotation::RotationMatrix;
2425
use crate::utils::extension_element_ptype;
2526

@@ -103,7 +104,7 @@ pub fn execute_decompress(
103104
}
104105

105106
/// Typed decompress: reads norms as `T`, dequantizes in f32, and produces output as `T`.
106-
fn decompress_typed<T: NativePType + FromPrimitive + Zero>(
107+
fn decompress_typed<T: NativePType + Float + FromPrimitive>(
107108
norms_prim: &PrimitiveArray,
108109
centroids: &[f32],
109110
rotation: &RotationMatrix,
@@ -129,8 +130,7 @@ fn decompress_typed<T: NativePType + FromPrimitive + Zero>(
129130
rotation.inverse_rotate(&dequantized, &mut unrotated);
130131

131132
for idx in 0..dim {
132-
// Convert f32 dequantized value to T, then scale by the native-precision norm.
133-
let val = T::from_f32(unrotated[idx]).unwrap_or_else(T::zero) * norm;
133+
let val = float_from_f32::<T>(unrotated[idx]) * norm;
134134
output.push(val);
135135
}
136136
}

vortex-tensor/src/encodings/turboquant/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,7 @@ pub(crate) mod compute;
9898

9999
mod vtable;
100100
pub use vtable::TurboQuant;
101+
pub use vtable::TurboQuantArray;
101102

102103
mod compress;
103104
pub use compress::TurboQuantConfig;

0 commit comments

Comments
 (0)