Skip to content

Commit c617ba9

Browse files
committed
add normalized vector type
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 30042ee commit c617ba9

29 files changed

Lines changed: 1855 additions & 717 deletions

File tree

vortex-array/src/arrays/extension/array.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,93 @@ impl Array<Extension> {
7878
Self::try_new(ext_dtype, storage_array)
7979
}
8080
}
81+
82+
#[cfg(test)]
83+
mod tests {
84+
use std::sync::Arc;
85+
86+
use vortex_buffer::Buffer;
87+
88+
use super::*;
89+
use crate::IntoArray;
90+
use crate::arrays::ExtensionArray;
91+
use crate::arrays::FixedSizeListArray;
92+
use crate::arrays::PrimitiveArray;
93+
use crate::dtype::Nullability;
94+
use crate::dtype::PType;
95+
use crate::dtype::extension::ExtId;
96+
use crate::extension::EmptyMetadata;
97+
use crate::scalar::ScalarValue;
98+
use crate::validity::Validity;
99+
100+
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
101+
struct TestExt;
102+
103+
impl ExtVTable for TestExt {
104+
type Metadata = EmptyMetadata;
105+
type NativeValue<'a> = &'a ScalarValue;
106+
107+
fn id(&self) -> ExtId {
108+
ExtId::new("vortex.test.extension")
109+
}
110+
111+
fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
112+
Ok(Vec::new())
113+
}
114+
115+
fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult<Self::Metadata> {
116+
Ok(EmptyMetadata)
117+
}
118+
119+
fn validate_dtype(_ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
120+
Ok(())
121+
}
122+
123+
fn unpack_native<'a>(
124+
_ext_dtype: &'a ExtDType<Self>,
125+
storage_value: &'a ScalarValue,
126+
) -> VortexResult<Self::NativeValue<'a>> {
127+
Ok(storage_value)
128+
}
129+
}
130+
131+
fn fsl_dtype(element_nullability: Nullability, list_nullability: Nullability) -> DType {
132+
DType::FixedSizeList(
133+
Arc::new(DType::Primitive(PType::F32, element_nullability)),
134+
2,
135+
list_nullability,
136+
)
137+
}
138+
139+
#[test]
140+
fn extension_storage_allows_outer_nullability_mismatch() -> VortexResult<()> {
141+
let ext_dtype = ExtDType::<TestExt>::try_new(
142+
EmptyMetadata,
143+
fsl_dtype(Nullability::NonNullable, Nullability::NonNullable),
144+
)?
145+
.erased();
146+
147+
let elements = PrimitiveArray::from_iter([1.0f32, 0.0]).into_array();
148+
let storage = FixedSizeListArray::try_new(elements, 2, Validity::AllValid, 1)?.into_array();
149+
150+
ExtensionArray::try_new(ext_dtype, storage)?;
151+
Ok(())
152+
}
153+
154+
#[test]
155+
fn extension_storage_rejects_nested_nullability_mismatch() -> VortexResult<()> {
156+
let ext_dtype = ExtDType::<TestExt>::try_new(
157+
EmptyMetadata,
158+
fsl_dtype(Nullability::NonNullable, Nullability::NonNullable),
159+
)?
160+
.erased();
161+
162+
let elements =
163+
PrimitiveArray::new(Buffer::copy_from([1.0f32, 0.0]), Validity::AllValid).into_array();
164+
let storage =
165+
FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 1)?.into_array();
166+
167+
assert!(ExtensionArray::try_new(ext_dtype, storage).is_err());
168+
Ok(())
169+
}
170+
}

vortex-tensor/public-api.lock

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,64 @@ pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::
250250

251251
pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
252252

253+
pub mod vortex_tensor::normalized_vector
254+
255+
pub struct vortex_tensor::normalized_vector::AnyNormalizedVector
256+
257+
impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::normalized_vector::AnyNormalizedVector
258+
259+
pub type vortex_tensor::normalized_vector::AnyNormalizedVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata
260+
261+
pub fn vortex_tensor::normalized_vector::AnyNormalizedVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
262+
263+
pub struct vortex_tensor::normalized_vector::NormalizedVector
264+
265+
impl vortex_tensor::normalized_vector::NormalizedVector
266+
267+
pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::new_unchecked(storage: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
268+
269+
pub fn vortex_tensor::normalized_vector::NormalizedVector::try_new(storage: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
270+
271+
impl core::clone::Clone for vortex_tensor::normalized_vector::NormalizedVector
272+
273+
pub fn vortex_tensor::normalized_vector::NormalizedVector::clone(&self) -> vortex_tensor::normalized_vector::NormalizedVector
274+
275+
impl core::cmp::Eq for vortex_tensor::normalized_vector::NormalizedVector
276+
277+
impl core::cmp::PartialEq for vortex_tensor::normalized_vector::NormalizedVector
278+
279+
pub fn vortex_tensor::normalized_vector::NormalizedVector::eq(&self, other: &vortex_tensor::normalized_vector::NormalizedVector) -> bool
280+
281+
impl core::default::Default for vortex_tensor::normalized_vector::NormalizedVector
282+
283+
pub fn vortex_tensor::normalized_vector::NormalizedVector::default() -> vortex_tensor::normalized_vector::NormalizedVector
284+
285+
impl core::fmt::Debug for vortex_tensor::normalized_vector::NormalizedVector
286+
287+
pub fn vortex_tensor::normalized_vector::NormalizedVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
288+
289+
impl core::hash::Hash for vortex_tensor::normalized_vector::NormalizedVector
290+
291+
pub fn vortex_tensor::normalized_vector::NormalizedVector::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
292+
293+
impl core::marker::StructuralPartialEq for vortex_tensor::normalized_vector::NormalizedVector
294+
295+
impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::normalized_vector::NormalizedVector
296+
297+
pub type vortex_tensor::normalized_vector::NormalizedVector::Metadata = vortex_array::extension::EmptyMetadata
298+
299+
pub type vortex_tensor::normalized_vector::NormalizedVector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue
300+
301+
pub fn vortex_tensor::normalized_vector::NormalizedVector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult<Self::Metadata>
302+
303+
pub fn vortex_tensor::normalized_vector::NormalizedVector::id(&self) -> vortex_array::dtype::extension::ExtId
304+
305+
pub fn vortex_tensor::normalized_vector::NormalizedVector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult<alloc::vec::Vec<u8>>
306+
307+
pub fn vortex_tensor::normalized_vector::NormalizedVector::unpack_native<'a>(_ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType<Self>, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult<Self::NativeValue>
308+
309+
pub fn vortex_tensor::normalized_vector::NormalizedVector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>
310+
253311
pub mod vortex_tensor::scalar_fns
254312

255313
pub mod vortex_tensor::scalar_fns::cosine_similarity
@@ -382,8 +440,6 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options:
382440

383441
pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
384442

385-
pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>
386-
387443
pub mod vortex_tensor::scalar_fns::l2_norm
388444

389445
pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm
@@ -574,7 +630,9 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32
574630

575631
pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType
576632

577-
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult<Self>
633+
pub fn vortex_tensor::vector::VectorMatcherMetadata::is_normalized(&self) -> bool
634+
635+
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32, is_normalized: bool) -> vortex_error::VortexResult<Self>
578636

579637
impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata
580638

vortex-tensor/src/encodings/l2_denorm.rs

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ use vortex_compressor::scheme::Scheme;
1414
use vortex_compressor::stats::ArrayAndStats;
1515
use vortex_error::VortexResult;
1616

17-
use crate::matcher::AnyTensor;
17+
use crate::normalized_vector::AnyNormalizedVector;
1818
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
19+
use crate::types::vector::AnyVector;
1920

2021
#[derive(Debug)]
2122
pub struct L2DenormScheme;
@@ -26,10 +27,11 @@ impl Scheme for L2DenormScheme {
2627
}
2728

2829
fn matches(&self, canonical: &Canonical) -> bool {
29-
matches!(
30-
canonical,
31-
Canonical::Extension(ext) if ext.ext_dtype().is::<AnyTensor>()
32-
)
30+
let Canonical::Extension(ext) = canonical else {
31+
return false;
32+
};
33+
34+
ext.ext_dtype().is::<AnyVector>() && !ext.ext_dtype().is::<AnyNormalizedVector>()
3335
}
3436

3537
fn expected_compression_ratio(
@@ -38,6 +40,7 @@ impl Scheme for L2DenormScheme {
3840
_compress_ctx: CompressorContext,
3941
_exec_ctx: &mut ExecutionCtx,
4042
) -> CompressionEstimate {
43+
// We almost always want to pre-normalize our data if the vector is not already normalized.
4144
CompressionEstimate::Verdict(EstimateVerdict::AlwaysUse)
4245
}
4346

@@ -52,3 +55,62 @@ impl Scheme for L2DenormScheme {
5255
Ok(l2_denorm.into_array())
5356
}
5457
}
58+
59+
#[cfg(test)]
60+
mod tests {
61+
use std::sync::Arc;
62+
63+
use vortex_array::Canonical;
64+
use vortex_array::IntoArray;
65+
use vortex_array::arrays::ExtensionArray;
66+
use vortex_array::arrays::FixedSizeListArray;
67+
use vortex_array::arrays::PrimitiveArray;
68+
use vortex_array::dtype::DType;
69+
use vortex_array::dtype::Nullability;
70+
use vortex_array::dtype::PType;
71+
use vortex_array::dtype::extension::ExtDType;
72+
use vortex_array::extension::EmptyMetadata;
73+
use vortex_array::validity::Validity;
74+
use vortex_compressor::scheme::Scheme;
75+
use vortex_error::VortexResult;
76+
77+
use super::L2DenormScheme;
78+
use crate::types::fixed_shape::FixedShapeTensor;
79+
use crate::types::fixed_shape::FixedShapeTensorMetadata;
80+
use crate::types::vector::Vector;
81+
82+
fn fsl_storage(elements: &[f32], list_size: u32) -> VortexResult<FixedSizeListArray> {
83+
let len = elements.len() / list_size as usize;
84+
let elements = PrimitiveArray::from_iter(elements.iter().copied()).into_array();
85+
FixedSizeListArray::try_new(elements, list_size, Validity::NonNullable, len)
86+
}
87+
88+
#[test]
89+
fn matches_vector() -> VortexResult<()> {
90+
let fsl = fsl_storage(&[1.0, 0.0], 2)?;
91+
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
92+
let canonical = Canonical::Extension(ExtensionArray::new(ext_dtype, fsl.into_array()));
93+
94+
assert!(L2DenormScheme.matches(&canonical));
95+
Ok(())
96+
}
97+
98+
#[test]
99+
fn rejects_fixed_shape_tensor() -> VortexResult<()> {
100+
let fsl = fsl_storage(&[1.0, 0.0, 0.0, 1.0], 4)?;
101+
let storage_dtype = DType::FixedSizeList(
102+
Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)),
103+
4,
104+
Nullability::NonNullable,
105+
);
106+
let ext_dtype = ExtDType::<FixedShapeTensor>::try_new(
107+
FixedShapeTensorMetadata::new(vec![2, 2]),
108+
storage_dtype,
109+
)?
110+
.erased();
111+
let canonical = Canonical::Extension(ExtensionArray::new(ext_dtype, fsl.into_array()));
112+
113+
assert!(!L2DenormScheme.matches(&canonical));
114+
Ok(())
115+
}
116+
}

vortex-tensor/src/encodings/turboquant/centroids.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Buffer<f32>>> = LazyLock::new
3636
/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
3737
/// quantization levels for the coordinate distribution after random rotation in
3838
/// `dimension`-dimensional space.
39-
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
39+
pub fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
4040
vortex_ensure!(
4141
(1..=MAX_BIT_WIDTH).contains(&bit_width),
4242
"TurboQuant bit_width must be 1-{}, got {bit_width}",
@@ -239,7 +239,7 @@ mod tests {
239239
#[case] bits: u8,
240240
#[case] expected: usize,
241241
) -> VortexResult<()> {
242-
let centroids = get_centroids(dim, bits)?;
242+
let centroids = compute_or_get_centroids(dim, bits)?;
243243
assert_eq!(centroids.len(), expected);
244244
Ok(())
245245
}
@@ -251,7 +251,7 @@ mod tests {
251251
#[case(128, 4)]
252252
#[case(768, 2)]
253253
fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
254-
let centroids = get_centroids(dim, bits)?;
254+
let centroids = compute_or_get_centroids(dim, bits)?;
255255
for window in centroids.windows(2) {
256256
assert!(
257257
window[0] < window[1],
@@ -268,7 +268,7 @@ mod tests {
268268
#[case(256, 2)]
269269
#[case(768, 2)]
270270
fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
271-
let centroids = get_centroids(dim, bits)?;
271+
let centroids = compute_or_get_centroids(dim, bits)?;
272272
let count = centroids.len();
273273
for idx in 0..count / 2 {
274274
let diff = (centroids[idx] + centroids[count - 1 - idx]).abs();
@@ -287,7 +287,7 @@ mod tests {
287287
#[case(128, 1)]
288288
#[case(128, 4)]
289289
fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
290-
let centroids = get_centroids(dim, bits)?;
290+
let centroids = compute_or_get_centroids(dim, bits)?;
291291
for &val in centroids.iter() {
292292
assert!(
293293
(-1.0..=1.0).contains(&val),
@@ -299,15 +299,15 @@ mod tests {
299299

300300
#[test]
301301
fn centroids_cached() -> VortexResult<()> {
302-
let c1 = get_centroids(128, 2)?;
303-
let c2 = get_centroids(128, 2)?;
302+
let c1 = compute_or_get_centroids(128, 2)?;
303+
let c2 = compute_or_get_centroids(128, 2)?;
304304
assert_eq!(c1, c2);
305305
Ok(())
306306
}
307307

308308
#[test]
309309
fn find_nearest_basic() -> VortexResult<()> {
310-
let centroids = get_centroids(128, 2)?;
310+
let centroids = compute_or_get_centroids(128, 2)?;
311311
let boundaries = compute_centroid_boundaries(&centroids);
312312
assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0);
313313

@@ -324,9 +324,9 @@ mod tests {
324324

325325
#[test]
326326
fn rejects_invalid_params() {
327-
assert!(get_centroids(128, 0).is_err());
328-
assert!(get_centroids(128, 9).is_err());
329-
assert!(get_centroids(1, 2).is_err());
330-
assert!(get_centroids(127, 2).is_err());
327+
assert!(compute_or_get_centroids(128, 0).is_err());
328+
assert!(compute_or_get_centroids(128, 9).is_err());
329+
assert!(compute_or_get_centroids(1, 2).is_err());
330+
assert!(compute_or_get_centroids(127, 2).is_err());
331331
}
332332
}

0 commit comments

Comments
 (0)