Skip to content

Commit 12457aa

Browse files
authored
Clean up vortex-tensor (#7525)
## Summary Tracking issue: #7297 Given our fast velocity on this crate, quite a few things slipped through the cracks. This change cleans up the `vortex-tensor` crate by clearly defining the abstraction points, fixing a few bugs (the only real bug was a Tensor `Display` bug), cleaning up some TODOs, and generally raising the quality. ## API Changes The only relevant change is new helper functions ## Testing Just fixed up existing tests. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent d0a6dba commit 12457aa

File tree

27 files changed

+724
-944
lines changed

27 files changed

+724
-944
lines changed

vortex-array/public-api.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21546,6 +21546,8 @@ pub fn vortex_array::Array<vortex_array::arrays::Extension>::new(ext_dtype: vort
2154621546

2154721547
pub fn vortex_array::Array<vortex_array::arrays::Extension>::try_new(ext_dtype: vortex_array::dtype::extension::ExtDTypeRef, storage_array: vortex_array::ArrayRef) -> vortex_error::VortexResult<Self>
2154821548

21549+
pub fn vortex_array::Array<vortex_array::arrays::Extension>::try_new_from_vtable<V: vortex_array::dtype::extension::ExtVTable>(vtable: V, metadata: <V as vortex_array::dtype::extension::ExtVTable>::Metadata, storage_array: vortex_array::ArrayRef) -> vortex_error::VortexResult<Self>
21550+
2154921551
impl vortex_array::Array<vortex_array::arrays::Filter>
2155021552

2155121553
pub fn vortex_array::Array<vortex_array::arrays::Filter>::new(array: vortex_array::ArrayRef, mask: vortex_mask::Mask) -> Self

vortex-array/src/arrays/extension/array.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use crate::array::ArrayParts;
1313
use crate::array::TypedArrayRef;
1414
use crate::arrays::Extension;
1515
use crate::dtype::DType;
16+
use crate::dtype::extension::ExtDType;
1617
use crate::dtype::extension::ExtDTypeRef;
18+
use crate::dtype::extension::ExtVTable;
1719

1820
/// The backing storage array for this extension array.
1921
pub(super) const STORAGE_SLOT: usize = 0;
@@ -163,4 +165,17 @@ impl Array<Extension> {
163165
)
164166
})
165167
}
168+
169+
/// Creates a new [`ExtensionArray`](crate::arrays::ExtensionArray) from a vtable, metadata, and
170+
/// a storage array.
171+
pub fn try_new_from_vtable<V: ExtVTable>(
172+
vtable: V,
173+
metadata: V::Metadata,
174+
storage_array: ArrayRef,
175+
) -> VortexResult<Self> {
176+
let ext_dtype =
177+
ExtDType::<V>::try_with_vtable(vtable, metadata, storage_array.dtype().clone())?
178+
.erased();
179+
Self::try_new(ext_dtype, storage_array)
180+
}
166181
}

vortex-array/src/dtype/extension/vtable.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use crate::scalar::ScalarValue;
1414

1515
/// The public API for defining new extension types.
1616
///
17-
/// This is the non-object-safe trait that plugin authors implement to define a new extension
18-
/// type. It specifies the type's identity, metadata, serialization, and validation.
17+
/// This is the non-object-safe trait that plugin authors implement to define a new extension type.
18+
/// It specifies the type's identity, metadata, serialization, and validation.
1919
pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash {
2020
/// Associated type containing the deserialized metadata for this extension type.
2121
type Metadata: 'static + Send + Sync + Clone + Debug + Display + Eq + Hash;
@@ -39,26 +39,27 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash {
3939
/// Validate that the given storage type is compatible with this extension type.
4040
fn validate_dtype(ext_dtype: &ExtDType<Self>) -> VortexResult<()>;
4141

42-
/// Can a value of `other` be implicitly widened into this type?
43-
/// e.g. GeographyType might accept Point, LineString, etc.
42+
/// Can a value of `other` be implicitly widened into this type? (e.g. GeographyType might
43+
/// accept Point, LineString, etc.)
4444
///
45-
/// Implementors only need to override one of `can_coerce_from` or `can_coerce_to` both
46-
/// exist so that either side of the coercion can provide the logic.
45+
/// Implementors only need to override one of `can_coerce_from` or `can_coerce_to`. We have both
46+
/// so that either side of the coercion can provide the logic.
4747
fn can_coerce_from(ext_dtype: &ExtDType<Self>, other: &DType) -> bool {
4848
let _ = (ext_dtype, other);
4949
false
5050
}
5151

5252
/// Can this type be implicitly widened into `other`?
5353
///
54-
/// Implementors only need to override one of `can_coerce_from` or `can_coerce_to` both
55-
/// exist so that either side of the coercion can provide the logic.
54+
/// Implementors only need to override one of `can_coerce_from` or `can_coerce_to`. We have both
55+
/// so that either side of the coercion can provide the logic.
5656
fn can_coerce_to(ext_dtype: &ExtDType<Self>, other: &DType) -> bool {
5757
let _ = (ext_dtype, other);
5858
false
5959
}
6060

6161
/// Given two types in a Uniform context, what is their least supertype?
62+
///
6263
/// Return None if no supertype exists.
6364
fn least_supertype(ext_dtype: &ExtDType<Self>, other: &DType) -> Option<DType> {
6465
let _ = (ext_dtype, other);
@@ -69,7 +70,8 @@ pub trait ExtVTable: 'static + Sized + Send + Sync + Clone + Debug + Eq + Hash {
6970

7071
/// Validate the given storage value is compatible with the extension type.
7172
///
72-
/// By default, this calls [`unpack_native()`](ExtVTable::unpack_native) and discards the result.
73+
/// By default, this calls [`unpack_native()`](ExtVTable::unpack_native) and discards the
74+
/// result.
7375
///
7476
/// # Errors
7577
///

vortex-tensor/public-api.lock

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ pub const vortex_tensor::encodings::turboquant::MIN_DIMENSION: u32
8080

8181
pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_tensor::vector::VectorMatcherMetadata>
8282

83-
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
83+
pub fn vortex_tensor::encodings::turboquant::turboquant_encode(input: vortex_array::array::erased::ArrayRef, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
8484

8585
pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
8686

@@ -142,7 +142,7 @@ impl vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>
142142

143143
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::element_ptype(&self) -> vortex_array::dtype::ptype::PType
144144

145-
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::list_size(&self) -> usize
145+
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::flat_list_size(&self) -> u32
146146

147147
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::metadata(&self) -> &vortex_tensor::fixed_shape::FixedShapeTensorMetadata
148148

@@ -222,7 +222,7 @@ impl vortex_tensor::matcher::TensorMatch<'_>
222222

223223
pub fn vortex_tensor::matcher::TensorMatch<'_>::element_ptype(self) -> vortex_array::dtype::ptype::PType
224224

225-
pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> usize
225+
pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> u32
226226

227227
impl<'a> core::clone::Clone for vortex_tensor::matcher::TensorMatch<'a>
228228

@@ -382,7 +382,7 @@ pub fn vortex_tensor::scalar_fns::l2_denorm::L2Denorm::validity(&self, _options:
382382

383383
pub fn vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm(input: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::scalar_fn::vtable::ScalarFnArray>
384384

385-
pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows(input: &vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>
385+
pub fn vortex_tensor::scalar_fns::l2_denorm::validate_l2_normalized_rows_against_norms(normalized: &vortex_array::array::erased::ArrayRef, norms: core::option::Option<&vortex_array::array::erased::ArrayRef>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<()>
386386

387387
pub mod vortex_tensor::scalar_fns::l2_norm
388388

@@ -502,7 +502,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::child_name(&sel
502502

503503
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
504504

505-
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
505+
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
506506

507507
pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::id(&self) -> vortex_array::scalar_fn::ScalarFnId
508508

@@ -600,12 +600,8 @@ impl core::marker::StructuralPartialEq for vortex_tensor::vector::VectorMatcherM
600600

601601
pub mod vortex_tensor::vector_search
602602

603-
pub fn vortex_tensor::vector_search::build_constant_query_vector<T: vortex_array::dtype::ptype::NativePType + core::convert::Into<vortex_array::scalar::typed_view::primitive::pvalue::PValue>>(query: &[T], num_rows: usize) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
604-
605603
pub fn vortex_tensor::vector_search::build_similarity_search_tree<T: vortex_array::dtype::ptype::NativePType + core::convert::Into<vortex_array::scalar::typed_view::primitive::pvalue::PValue>>(data: vortex_array::array::erased::ArrayRef, query: &[T], threshold: T) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
606604

607-
pub fn vortex_tensor::vector_search::compress_turboquant(data: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::erased::ArrayRef>
608-
609605
pub const vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str
610606

611607
pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession)

vortex-tensor/src/encodings/l2_denorm.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@ use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm;
1919
pub struct L2DenormScheme;
2020

2121
impl Scheme for L2DenormScheme {
22-
// TODO(connor): FIX THIS!!!
2322
fn scheme_name(&self) -> &'static str {
24-
"vortex.tensor.UNSTABLE.l2_denorm"
23+
"vortex.tensor.l2_denorm"
2524
}
2625

2726
fn matches(&self, canonical: &Canonical) -> bool {

vortex-tensor/src/encodings/turboquant/centroids.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
1212
use std::sync::LazyLock;
1313

14+
use vortex_buffer::Buffer;
1415
use vortex_error::VortexResult;
1516
use vortex_error::vortex_ensure;
1617
use vortex_utils::aliases::dash_map::DashMap;
@@ -27,16 +28,15 @@ const CONVERGENCE_EPSILON: f64 = 1e-12;
2728
/// Number of numerical integration points for computing conditional expectations.
2829
const INTEGRATION_POINTS: usize = 1000;
2930

30-
// TODO(connor): Maybe we should just store an `ArrayRef` here?
3131
/// Global centroid cache keyed by (dimension, bit_width).
32-
static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Vec<f32>>> = LazyLock::new(DashMap::default);
32+
static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Buffer<f32>>> = LazyLock::new(DashMap::default);
3333

3434
/// Get or compute cached centroids for the given dimension and bit width.
3535
///
3636
/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar
3737
/// quantization levels for the coordinate distribution after random rotation in
3838
/// `dimension`-dimensional space.
39-
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Vec<f32>> {
39+
pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> {
4040
vortex_ensure!(
4141
(1..=MAX_BIT_WIDTH).contains(&bit_width),
4242
"TurboQuant bit_width must be 1-{}, got {bit_width}",
@@ -92,7 +92,7 @@ impl HalfIntExponent {
9292
/// The probability distribution function is:
9393
/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]`
9494
/// where `C_d` is the normalizing constant.
95-
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec<f32> {
95+
fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer<f32> {
9696
debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width));
9797
let num_centroids = 1usize << bit_width;
9898

@@ -288,7 +288,7 @@ mod tests {
288288
#[case(128, 4)]
289289
fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> {
290290
let centroids = get_centroids(dim, bits)?;
291-
for &val in &centroids {
291+
for &val in centroids.iter() {
292292
assert!(
293293
(-1.0..=1.0).contains(&val),
294294
"centroid out of [-1, 1]: {val}",

0 commit comments

Comments
 (0)