Skip to content

Commit 1280166

Browse files
committed
add normalized vector type
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 3e00b5a commit 1280166

20 files changed

Lines changed: 1595 additions & 451 deletions

File tree

vortex-array/src/arrays/extension/array.rs

Lines changed: 118 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::fmt::Formatter;
66

77
use vortex_error::VortexExpect;
88
use vortex_error::VortexResult;
9+
use vortex_error::vortex_ensure;
910

1011
use crate::ArrayRef;
1112
use crate::array::Array;
@@ -90,10 +91,18 @@ impl ExtensionData {
9091
pub fn try_new(ext_dtype: ExtDTypeRef, storage_dtype: &DType) -> VortexResult<Self> {
9192
// TODO(connor): Replace these statements once we add `validate_storage_array`.
9293
// ext_dtype.validate_storage_array(&storage_array)?;
93-
assert_eq!(
94+
//
95+
// The storage array's outer nullability is allowed to differ from the extension's declared
96+
// storage outer nullability. Nested storage nullability must still match exactly.
97+
vortex_ensure!(
98+
storage_dtypes_match_ignoring_outer_nullability(
99+
ext_dtype.storage_dtype(),
100+
storage_dtype
101+
),
102+
"ExtensionArray: storage_dtype must match storage array DType (ignoring outer \
103+
nullability only), got extension storage {} and array storage {}",
94104
ext_dtype.storage_dtype(),
95105
storage_dtype,
96-
"ExtensionArray: storage_dtype must match storage array DType",
97106
);
98107

99108
// SAFETY: we validate that the inputs are valid above.
@@ -113,10 +122,18 @@ impl ExtensionData {
113122
// ext_dtype
114123
// .validate_storage_array(&storage_array)
115124
// .vortex_expect("[Debug Assertion]: Invalid storage array for `ExtensionArray`");
116-
debug_assert_eq!(
125+
//
126+
// Match the contract of [`Self::try_new`]: the storage dtype must match the extension's
127+
// declared storage dtype ignoring only outer nullability.
128+
debug_assert!(
129+
storage_dtypes_match_ignoring_outer_nullability(
130+
ext_dtype.storage_dtype(),
131+
storage_dtype
132+
),
133+
"ExtensionArray: storage_dtype must match storage array DType (ignoring outer \
134+
nullability only), got extension storage {} and array storage {}",
117135
ext_dtype.storage_dtype(),
118136
storage_dtype,
119-
"ExtensionArray: storage_dtype must match storage array DType",
120137
);
121138

122139
Self { ext_dtype }
@@ -128,6 +145,13 @@ impl ExtensionData {
128145
}
129146
}
130147

148+
fn storage_dtypes_match_ignoring_outer_nullability(
149+
ext_storage_dtype: &DType,
150+
storage_dtype: &DType,
151+
) -> bool {
152+
ext_storage_dtype.with_nullability(storage_dtype.nullability()) == *storage_dtype
153+
}
154+
131155
pub trait ExtensionArrayExt: TypedArrayRef<Extension> {
132156
fn storage_array(&self) -> &ArrayRef {
133157
self.as_ref().slots()[STORAGE_SLOT]
@@ -179,3 +203,93 @@ impl Array<Extension> {
179203
Self::try_new(ext_dtype, storage_array)
180204
}
181205
}
206+
207+
#[cfg(test)]
208+
mod tests {
209+
use std::sync::Arc;
210+
211+
use vortex_buffer::Buffer;
212+
213+
use super::*;
214+
use crate::IntoArray;
215+
use crate::arrays::ExtensionArray;
216+
use crate::arrays::FixedSizeListArray;
217+
use crate::arrays::PrimitiveArray;
218+
use crate::dtype::Nullability;
219+
use crate::dtype::PType;
220+
use crate::dtype::extension::ExtId;
221+
use crate::extension::EmptyMetadata;
222+
use crate::scalar::ScalarValue;
223+
use crate::validity::Validity;
224+
225+
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
226+
struct TestExt;
227+
228+
impl ExtVTable for TestExt {
229+
type Metadata = EmptyMetadata;
230+
type NativeValue<'a> = &'a ScalarValue;
231+
232+
fn id(&self) -> ExtId {
233+
ExtId::new("vortex.test.extension")
234+
}
235+
236+
fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
237+
Ok(Vec::new())
238+
}
239+
240+
fn deserialize_metadata(&self, _metadata: &[u8]) -> VortexResult<Self::Metadata> {
241+
Ok(EmptyMetadata)
242+
}
243+
244+
fn validate_dtype(_ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
245+
Ok(())
246+
}
247+
248+
fn unpack_native<'a>(
249+
_ext_dtype: &'a ExtDType<Self>,
250+
storage_value: &'a ScalarValue,
251+
) -> VortexResult<Self::NativeValue<'a>> {
252+
Ok(storage_value)
253+
}
254+
}
255+
256+
fn fsl_dtype(element_nullability: Nullability, list_nullability: Nullability) -> DType {
257+
DType::FixedSizeList(
258+
Arc::new(DType::Primitive(PType::F32, element_nullability)),
259+
2,
260+
list_nullability,
261+
)
262+
}
263+
264+
#[test]
265+
fn extension_storage_allows_outer_nullability_mismatch() -> VortexResult<()> {
266+
let ext_dtype = ExtDType::<TestExt>::try_new(
267+
EmptyMetadata,
268+
fsl_dtype(Nullability::NonNullable, Nullability::NonNullable),
269+
)?
270+
.erased();
271+
272+
let elements = PrimitiveArray::from_iter([1.0f32, 0.0]).into_array();
273+
let storage = FixedSizeListArray::try_new(elements, 2, Validity::AllValid, 1)?.into_array();
274+
275+
ExtensionArray::try_new(ext_dtype, storage)?;
276+
Ok(())
277+
}
278+
279+
#[test]
280+
fn extension_storage_rejects_nested_nullability_mismatch() -> VortexResult<()> {
281+
let ext_dtype = ExtDType::<TestExt>::try_new(
282+
EmptyMetadata,
283+
fsl_dtype(Nullability::NonNullable, Nullability::NonNullable),
284+
)?
285+
.erased();
286+
287+
let elements =
288+
PrimitiveArray::new(Buffer::copy_from([1.0f32, 0.0]), Validity::AllValid).into_array();
289+
let storage =
290+
FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 1)?.into_array();
291+
292+
assert!(ExtensionArray::try_new(ext_dtype, storage).is_err());
293+
Ok(())
294+
}
295+
}

vortex-tensor/public-api.lock

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,64 @@ pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::
250250

251251
pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
252252

253+
pub mod vortex_tensor::normalized_vector
254+
255+
pub struct vortex_tensor::normalized_vector::AnyNormalizedVector
256+
257+
impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::normalized_vector::AnyNormalizedVector
258+
259+
pub type vortex_tensor::normalized_vector::AnyNormalizedVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata
260+
261+
pub fn vortex_tensor::normalized_vector::AnyNormalizedVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
262+
263+
pub struct vortex_tensor::normalized_vector::NormalizedVector
264+
265+
impl vortex_tensor::normalized_vector::NormalizedVector
266+
267+
pub unsafe fn vortex_tensor::normalized_vector::NormalizedVector::new_unchecked(storage: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
268+
269+
pub fn vortex_tensor::normalized_vector::NormalizedVector::try_new(storage: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
270+
271+
impl core::clone::Clone for vortex_tensor::normalized_vector::NormalizedVector
272+
273+
pub fn vortex_tensor::normalized_vector::NormalizedVector::clone(&self) -> vortex_tensor::normalized_vector::NormalizedVector
274+
275+
impl core::cmp::Eq for vortex_tensor::normalized_vector::NormalizedVector
276+
277+
impl core::cmp::PartialEq for vortex_tensor::normalized_vector::NormalizedVector
278+
279+
pub fn vortex_tensor::normalized_vector::NormalizedVector::eq(&self, other: &vortex_tensor::normalized_vector::NormalizedVector) -> bool
280+
281+
impl core::default::Default for vortex_tensor::normalized_vector::NormalizedVector
282+
283+
pub fn vortex_tensor::normalized_vector::NormalizedVector::default() -> vortex_tensor::normalized_vector::NormalizedVector
284+
285+
impl core::fmt::Debug for vortex_tensor::normalized_vector::NormalizedVector
286+
287+
pub fn vortex_tensor::normalized_vector::NormalizedVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
288+
289+
impl core::hash::Hash for vortex_tensor::normalized_vector::NormalizedVector
290+
291+
pub fn vortex_tensor::normalized_vector::NormalizedVector::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
292+
293+
impl core::marker::StructuralPartialEq for vortex_tensor::normalized_vector::NormalizedVector
294+
295+
impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_tensor::normalized_vector::NormalizedVector
296+
297+
pub type vortex_tensor::normalized_vector::NormalizedVector::Metadata = vortex_array::extension::EmptyMetadata
298+
299+
pub type vortex_tensor::normalized_vector::NormalizedVector::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue
300+
301+
pub fn vortex_tensor::normalized_vector::NormalizedVector::deserialize_metadata(&self, _metadata: &[u8]) -> vortex_error::VortexResult<Self::Metadata>
302+
303+
pub fn vortex_tensor::normalized_vector::NormalizedVector::id(&self) -> vortex_array::dtype::extension::ExtId
304+
305+
pub fn vortex_tensor::normalized_vector::NormalizedVector::serialize_metadata(&self, _metadata: &Self::Metadata) -> vortex_error::VortexResult<alloc::vec::Vec<u8>>
306+
307+
pub fn vortex_tensor::normalized_vector::NormalizedVector::unpack_native<'a>(_ext_dtype: &'a vortex_array::dtype::extension::typed::ExtDType<Self>, storage_value: &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult<Self::NativeValue>
308+
309+
pub fn vortex_tensor::normalized_vector::NormalizedVector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>
310+
253311
pub mod vortex_tensor::scalar_fns
254312

255313
pub mod vortex_tensor::scalar_fns::cosine_similarity
@@ -382,8 +440,6 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options:
382440

383441
pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
384442

385-
pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>
386-
387443
pub mod vortex_tensor::scalar_fns::l2_norm
388444

389445
pub struct vortex_tensor::scalar_fns::l2_norm::L2Norm
@@ -574,7 +630,9 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32
574630

575631
pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType
576632

577-
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult<Self>
633+
pub fn vortex_tensor::vector::VectorMatcherMetadata::is_normalized(&self) -> bool
634+
635+
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32, is_normalized: bool) -> vortex_error::VortexResult<Self>
578636

579637
impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata
580638

vortex-tensor/src/encodings/l2_denorm.rs

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ use vortex_compressor::scheme::Scheme;
1313
use vortex_compressor::stats::ArrayAndStats;
1414
use vortex_error::VortexResult;
1515

16-
use crate::matcher::AnyTensor;
16+
use crate::normalized_vector::AnyNormalizedVector;
1717
use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
18+
use crate::types::vector::AnyVector;
1819

1920
#[derive(Debug)]
2021
pub struct L2DenormScheme;
@@ -25,10 +26,11 @@ impl Scheme for L2DenormScheme {
2526
}
2627

2728
fn matches(&self, canonical: &Canonical) -> bool {
28-
matches!(
29-
canonical,
30-
Canonical::Extension(ext) if ext.ext_dtype().is::<AnyTensor>()
31-
)
29+
let Canonical::Extension(ext) = canonical else {
30+
return false;
31+
};
32+
33+
ext.ext_dtype().is::<AnyVector>() && !ext.ext_dtype().is::<AnyNormalizedVector>()
3234
}
3335

3436
fn expected_compression_ratio(
@@ -37,6 +39,7 @@ impl Scheme for L2DenormScheme {
3739
_compress_ctx: CompressorContext,
3840
_exec_ctx: &mut ExecutionCtx,
3941
) -> CompressionEstimate {
42+
// We almost always want to pre-normalize our data if the vector is not already normalized.
4043
CompressionEstimate::Verdict(EstimateVerdict::AlwaysUse)
4144
}
4245

@@ -51,3 +54,62 @@ impl Scheme for L2DenormScheme {
5154
Ok(l2_denorm.into_array())
5255
}
5356
}
57+
58+
#[cfg(test)]
59+
mod tests {
60+
use std::sync::Arc;
61+
62+
use vortex_array::Canonical;
63+
use vortex_array::IntoArray;
64+
use vortex_array::arrays::ExtensionArray;
65+
use vortex_array::arrays::FixedSizeListArray;
66+
use vortex_array::arrays::PrimitiveArray;
67+
use vortex_array::dtype::DType;
68+
use vortex_array::dtype::Nullability;
69+
use vortex_array::dtype::PType;
70+
use vortex_array::dtype::extension::ExtDType;
71+
use vortex_array::extension::EmptyMetadata;
72+
use vortex_array::validity::Validity;
73+
use vortex_compressor::scheme::Scheme;
74+
use vortex_error::VortexResult;
75+
76+
use super::L2DenormScheme;
77+
use crate::types::fixed_shape::FixedShapeTensor;
78+
use crate::types::fixed_shape::FixedShapeTensorMetadata;
79+
use crate::types::vector::Vector;
80+
81+
fn fsl_storage(elements: &[f32], list_size: u32) -> VortexResult<FixedSizeListArray> {
82+
let len = elements.len() / list_size as usize;
83+
let elements = PrimitiveArray::from_iter(elements.iter().copied()).into_array();
84+
FixedSizeListArray::try_new(elements, list_size, Validity::NonNullable, len)
85+
}
86+
87+
#[test]
88+
fn matches_vector() -> VortexResult<()> {
89+
let fsl = fsl_storage(&[1.0, 0.0], 2)?;
90+
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
91+
let canonical = Canonical::Extension(ExtensionArray::new(ext_dtype, fsl.into_array()));
92+
93+
assert!(L2DenormScheme.matches(&canonical));
94+
Ok(())
95+
}
96+
97+
#[test]
98+
fn rejects_fixed_shape_tensor() -> VortexResult<()> {
99+
let fsl = fsl_storage(&[1.0, 0.0, 0.0, 1.0], 4)?;
100+
let storage_dtype = DType::FixedSizeList(
101+
Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)),
102+
4,
103+
Nullability::NonNullable,
104+
);
105+
let ext_dtype = ExtDType::<FixedShapeTensor>::try_new(
106+
FixedShapeTensorMetadata::new(vec![2, 2]),
107+
storage_dtype,
108+
)?
109+
.erased();
110+
let canonical = Canonical::Extension(ExtensionArray::new(ext_dtype, fsl.into_array()));
111+
112+
assert!(!L2DenormScheme.matches(&canonical));
113+
Ok(())
114+
}
115+
}

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
3939
use crate::scalar_fns::sorf_transform::SorfMatrix;
4040
use crate::scalar_fns::sorf_transform::SorfOptions;
4141
use crate::scalar_fns::sorf_transform::SorfTransform;
42+
use crate::types::normalized_vector::NormalizedVector;
4243
use crate::types::vector::AnyVector;
43-
use crate::types::vector::Vector;
4444
use crate::utils::cast_to_f32;
4545

4646
/// Configuration for TurboQuant encoding.
@@ -80,8 +80,8 @@ impl Default for TurboQuantConfig {
8080
///
8181
/// # Errors
8282
///
83-
/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or
84-
/// if [`turboquant_encode_unchecked`] rejects the input shape.
83+
/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or if
84+
/// [`turboquant_encode_unchecked`] rejects the input shape.
8585
pub fn turboquant_encode(
8686
input: ArrayRef,
8787
config: &TurboQuantConfig,
@@ -101,7 +101,7 @@ pub fn turboquant_encode(
101101
let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?;
102102

103103
// SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally
104-
// bypass the strict normalized-row validation when reattaching the stored norms.
104+
// bypass the strict normalized-row and zero-row validation when reattaching the stored norms.
105105
Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array())
106106
}
107107

@@ -156,7 +156,10 @@ pub unsafe fn turboquant_encode_unchecked(
156156
Validity::NonNullable,
157157
0,
158158
)?;
159-
let empty_padded_vector = Vector::try_new_vector_array(empty_fsl.into_array())?;
159+
// SAFETY: An empty FSL contains no rows, so the unit-norm-or-zero invariant holds
160+
// vacuously.
161+
let empty_padded_vector =
162+
unsafe { NormalizedVector::new_unchecked(empty_fsl.into_array()) }?;
160163

161164
let sorf_options = SorfOptions {
162165
seed,
@@ -172,7 +175,11 @@ pub unsafe fn turboquant_encode_unchecked(
172175
let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
173176
let quantized_fsl =
174177
build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?;
175-
let padded_vector = Vector::try_new_vector_array(quantized_fsl)?;
178+
// SAFETY: TurboQuant is a lossy approximation of the already-unit-norm input. The
179+
// quantized rows are approximately unit-norm by construction; downstream callers
180+
// (notably the enclosing `L2Denorm` wrapper) treat the stored-norm + NormalizedVector
181+
// claim as authoritative rather than decode-verified.
182+
let padded_vector = unsafe { NormalizedVector::new_unchecked(quantized_fsl) }?;
176183

177184
let sorf_options = SorfOptions {
178185
seed,

0 commit comments

Comments
 (0)