Skip to content

Commit 0170156

Browse files
authored
Tensor Matchers (#7300)
## Summary Instead of having weird helper functions extract information out of the FSL dtype, we can extract it into new matcher metadata via `AnyVector` and `AnyFixedShapeTensor` matchers. Note that this is _separate_ from the metadata that needs to be serialized to disk on the array vtable. ## Testing N/A since the logic should be the same, just in a different place. --------- Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent de4eb39 commit 0170156

File tree

17 files changed

+521
-171
lines changed

17 files changed

+521
-171
lines changed

vortex-array/public-api.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8756,9 +8756,9 @@ impl<V: vortex_array::dtype::extension::ExtVTable> vortex_array::dtype::extensio
87568756

87578757
pub type V::Match<'a> = &'a <V as vortex_array::dtype::extension::ExtVTable>::Metadata
87588758

8759-
pub fn V::matches(item: &vortex_array::dtype::extension::ExtDTypeRef) -> bool
8759+
pub fn V::matches(ext_dtype: &vortex_array::dtype::extension::ExtDTypeRef) -> bool
87608760

8761-
pub fn V::try_match<'a>(item: &'a vortex_array::dtype::extension::ExtDTypeRef) -> core::option::Option<<V as vortex_array::dtype::extension::Matcher>::Match>
8761+
pub fn V::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::ExtDTypeRef) -> core::option::Option<<V as vortex_array::dtype::extension::Matcher>::Match>
87628762

87638763
pub type vortex_array::dtype::extension::ExtDTypePluginRef = alloc::sync::Arc<dyn vortex_array::dtype::extension::ExtDTypePlugin>
87648764

vortex-array/src/dtype/extension/matcher.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ pub trait Matcher {
2222
impl<V: ExtVTable> Matcher for V {
2323
type Match<'a> = &'a V::Metadata;
2424

25-
fn matches(item: &ExtDTypeRef) -> bool {
26-
item.0.as_any().is::<ExtDType<V>>()
25+
fn matches(ext_dtype: &ExtDTypeRef) -> bool {
26+
ext_dtype.0.as_any().is::<ExtDType<V>>()
2727
}
2828

29-
fn try_match<'a>(item: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
30-
item.0
29+
fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
30+
ext_dtype
31+
.0
3132
.as_any()
3233
.downcast_ref::<ExtDType<V>>()
3334
.map(|inner| inner.metadata())

vortex-tensor/public-api.lock

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub const vortex_tensor::encodings::turboquant::TurboQuant::MIN_DIMENSION: u32
1414

1515
pub fn vortex_tensor::encodings::turboquant::TurboQuant::try_new_array(dtype: vortex_array::dtype::DType, codes: vortex_array::array::erased::ArrayRef, norms: vortex_array::array::erased::ArrayRef, centroids: vortex_array::array::erased::ArrayRef, rotation_signs: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult<vortex_tensor::encodings::turboquant::TurboQuantArray>
1616

17-
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<&vortex_array::dtype::extension::erased::ExtDTypeRef>
17+
pub fn vortex_tensor::encodings::turboquant::TurboQuant::validate_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_tensor::vector::VectorMatcherMetadata>
1818

1919
impl core::clone::Clone for vortex_tensor::encodings::turboquant::TurboQuant
2020

@@ -188,6 +188,14 @@ pub type vortex_tensor::encodings::turboquant::TurboQuantArray = vortex_array::a
188188

189189
pub mod vortex_tensor::fixed_shape
190190

191+
pub struct vortex_tensor::fixed_shape::AnyFixedShapeTensor
192+
193+
impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::fixed_shape::AnyFixedShapeTensor
194+
195+
pub type vortex_tensor::fixed_shape::AnyFixedShapeTensor::Match<'a> = vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
196+
197+
pub fn vortex_tensor::fixed_shape::AnyFixedShapeTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
198+
191199
pub struct vortex_tensor::fixed_shape::FixedShapeTensor
192200

193201
impl core::clone::Clone for vortex_tensor::fixed_shape::FixedShapeTensor
@@ -230,6 +238,34 @@ pub fn vortex_tensor::fixed_shape::FixedShapeTensor::unpack_native<'a>(_ext_dtyp
230238

231239
pub fn vortex_tensor::fixed_shape::FixedShapeTensor::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>
232240

241+
pub struct vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
242+
243+
impl vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>
244+
245+
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::element_ptype(&self) -> vortex_array::dtype::ptype::PType
246+
247+
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::list_size(&self) -> usize
248+
249+
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'_>::metadata(&self) -> &vortex_tensor::fixed_shape::FixedShapeTensorMetadata
250+
251+
impl<'a> core::clone::Clone for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
252+
253+
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>::clone(&self) -> vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
254+
255+
impl<'a> core::cmp::Eq for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
256+
257+
impl<'a> core::cmp::PartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
258+
259+
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>::eq(&self, other: &vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>) -> bool
260+
261+
impl<'a> core::fmt::Debug for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
262+
263+
pub fn vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
264+
265+
impl<'a> core::marker::Copy for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
266+
267+
impl<'a> core::marker::StructuralPartialEq for vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>
268+
233269
pub struct vortex_tensor::fixed_shape::FixedShapeTensorMetadata
234270

235271
impl vortex_tensor::fixed_shape::FixedShapeTensorMetadata
@@ -280,9 +316,19 @@ pub mod vortex_tensor::matcher
280316

281317
pub enum vortex_tensor::matcher::TensorMatch<'a>
282318

283-
pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(&'a vortex_tensor::fixed_shape::FixedShapeTensorMetadata)
319+
pub vortex_tensor::matcher::TensorMatch::FixedShapeTensor(vortex_tensor::fixed_shape::FixedShapeTensorMatcherMetadata<'a>)
320+
321+
pub vortex_tensor::matcher::TensorMatch::Vector(vortex_tensor::vector::VectorMatcherMetadata)
322+
323+
impl vortex_tensor::matcher::TensorMatch<'_>
324+
325+
pub fn vortex_tensor::matcher::TensorMatch<'_>::element_ptype(self) -> vortex_array::dtype::ptype::PType
326+
327+
pub fn vortex_tensor::matcher::TensorMatch<'_>::list_size(self) -> usize
284328

285-
pub vortex_tensor::matcher::TensorMatch::Vector
329+
impl<'a> core::clone::Clone for vortex_tensor::matcher::TensorMatch<'a>
330+
331+
pub fn vortex_tensor::matcher::TensorMatch<'a>::clone(&self) -> vortex_tensor::matcher::TensorMatch<'a>
286332

287333
impl<'a> core::cmp::Eq for vortex_tensor::matcher::TensorMatch<'a>
288334

@@ -294,6 +340,8 @@ impl<'a> core::fmt::Debug for vortex_tensor::matcher::TensorMatch<'a>
294340

295341
pub fn vortex_tensor::matcher::TensorMatch<'a>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
296342

343+
impl<'a> core::marker::Copy for vortex_tensor::matcher::TensorMatch<'a>
344+
297345
impl<'a> core::marker::StructuralPartialEq for vortex_tensor::matcher::TensorMatch<'a>
298346

299347
pub struct vortex_tensor::matcher::AnyTensor
@@ -302,7 +350,7 @@ impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::matcher
302350

303351
pub type vortex_tensor::matcher::AnyTensor::Match<'a> = vortex_tensor::matcher::TensorMatch<'a>
304352

305-
pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(item: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
353+
pub fn vortex_tensor::matcher::AnyTensor::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
306354

307355
pub mod vortex_tensor::scalar_fns
308356

@@ -456,6 +504,14 @@ impl core::marker::StructuralPartialEq for vortex_tensor::scalar_fns::ApproxOpti
456504

457505
pub mod vortex_tensor::vector
458506

507+
pub struct vortex_tensor::vector::AnyVector
508+
509+
impl vortex_array::dtype::extension::matcher::Matcher for vortex_tensor::vector::AnyVector
510+
511+
pub type vortex_tensor::vector::AnyVector::Match<'a> = vortex_tensor::vector::VectorMatcherMetadata
512+
513+
pub fn vortex_tensor::vector::AnyVector::try_match<'a>(ext_dtype: &'a vortex_array::dtype::extension::erased::ExtDTypeRef) -> core::option::Option<Self::Match>
514+
459515
pub struct vortex_tensor::vector::Vector
460516

461517
impl core::clone::Clone for vortex_tensor::vector::Vector
@@ -498,4 +554,36 @@ pub fn vortex_tensor::vector::Vector::unpack_native<'a>(_ext_dtype: &'a vortex_a
498554

499555
pub fn vortex_tensor::vector::Vector::validate_dtype(ext_dtype: &vortex_array::dtype::extension::typed::ExtDType<Self>) -> vortex_error::VortexResult<()>
500556

557+
pub struct vortex_tensor::vector::VectorMatcherMetadata
558+
559+
impl vortex_tensor::vector::VectorMatcherMetadata
560+
561+
pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32
562+
563+
pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType
564+
565+
pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult<Self>
566+
567+
impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata
568+
569+
pub fn vortex_tensor::vector::VectorMatcherMetadata::clone(&self) -> vortex_tensor::vector::VectorMatcherMetadata
570+
571+
impl core::cmp::Eq for vortex_tensor::vector::VectorMatcherMetadata
572+
573+
impl core::cmp::PartialEq for vortex_tensor::vector::VectorMatcherMetadata
574+
575+
pub fn vortex_tensor::vector::VectorMatcherMetadata::eq(&self, other: &vortex_tensor::vector::VectorMatcherMetadata) -> bool
576+
577+
impl core::fmt::Debug for vortex_tensor::vector::VectorMatcherMetadata
578+
579+
pub fn vortex_tensor::vector::VectorMatcherMetadata::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
580+
581+
impl core::hash::Hash for vortex_tensor::vector::VectorMatcherMetadata
582+
583+
pub fn vortex_tensor::vector::VectorMatcherMetadata::hash<__H: core::hash::Hasher>(&self, state: &mut __H)
584+
585+
impl core::marker::Copy for vortex_tensor::vector::VectorMatcherMetadata
586+
587+
impl core::marker::StructuralPartialEq for vortex_tensor::vector::VectorMatcherMetadata
588+
501589
pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession)

vortex-tensor/src/encodings/turboquant/array/data.rs

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ use vortex_error::vortex_ensure_eq;
1515

1616
use crate::encodings::turboquant::array::slots::Slot;
1717
use crate::encodings::turboquant::vtable::TurboQuant;
18-
use crate::utils::tensor_element_ptype;
19-
use crate::utils::tensor_list_size;
2018

2119
/// TurboQuant array data.
2220
///
@@ -41,16 +39,13 @@ pub struct TurboQuantData {
4139
}
4240

4341
impl TurboQuantData {
44-
/// Build a TurboQuant array with validation.
45-
///
46-
/// The `dimension` and `bit_width` are derived from the inputs:
47-
/// - `dimension` from the `dtype`'s `FixedSizeList` storage list size.
48-
/// - `bit_width` from `log2(centroids.len())` (0 for degenerate empty arrays).
42+
/// Build a `TurboQuantData` with validation.
4943
///
5044
/// # Errors
5145
///
52-
/// Returns an error if the provided components do not satisfy the invariants documented
53-
/// in [`new_unchecked`](Self::new_unchecked).
46+
/// Returns an error if:
47+
/// - `dimension` is less than [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
48+
/// - `bit_width` is greater than 8.
5449
pub fn try_new(dimension: u32, bit_width: u8) -> VortexResult<Self> {
5550
vortex_ensure!(
5651
dimension >= TurboQuant::MIN_DIMENSION,
@@ -67,23 +62,14 @@ impl TurboQuantData {
6762
})
6863
}
6964

70-
/// Build a TurboQuant array without validation.
65+
/// Build a `TurboQuantData` without validation.
7166
///
7267
/// # Safety
7368
///
7469
/// The caller must ensure:
7570
///
76-
/// - `dtype` is a [`Vector`](crate::vector::Vector) extension type whose storage list size
77-
/// is >= [`MIN_DIMENSION`](crate::encodings::turboquant::TurboQuant::MIN_DIMENSION).
78-
/// - `codes` is a non-nullable `FixedSizeListArray<u8>` with `list_size == padded_dim` and
79-
/// `codes.len() == norms.len()`. Null vectors are represented by all-zero codes.
80-
/// - `norms` is a primitive array whose ptype matches the element type of the Vector's storage
81-
/// dtype. The nullability must match `dtype.nullability()`. Norms carry the validity of the
82-
/// entire array, since null vectors have null norms.
83-
/// - `centroids` is a non-nullable `PrimitiveArray<f32>` whose length is a power of 2 in
84-
/// `[2, 256]` (i.e., `2^bit_width` for bit_width 1-8), or empty for degenerate arrays.
85-
/// - `rotation_signs` has `3 * padded_dim` elements, or is empty for degenerate arrays.
86-
/// - For degenerate (empty) arrays: all children must be empty.
71+
/// - `dimension` is >= [`MIN_DIMENSION`](TurboQuant::MIN_DIMENSION).
72+
/// - `bit_width` is in the range `[0, 8]`.
8773
///
8874
/// Violating these invariants may produce incorrect results during decompression.
8975
pub unsafe fn new_unchecked(dimension: u32, bit_width: u8) -> Self {
@@ -103,8 +89,8 @@ impl TurboQuantData {
10389
centroids: &ArrayRef,
10490
rotation_signs: &ArrayRef,
10591
) -> VortexResult<()> {
106-
let ext = TurboQuant::validate_dtype(dtype)?;
107-
let dimension = tensor_list_size(ext)?;
92+
let vector_metadata = TurboQuant::validate_dtype(dtype)?;
93+
let dimension = vector_metadata.dimensions();
10894
let padded_dim = dimension.next_power_of_two();
10995

11096
// Codes must be a non-nullable FixedSizeList<u8> with list_size == padded_dim.
@@ -159,7 +145,7 @@ impl TurboQuantData {
159145

160146
// Norms dtype must match the element ptype of the Vector, with the parent's nullability.
161147
// Norms carry the validity of the entire TurboQuant array.
162-
let element_ptype = tensor_element_ptype(ext)?;
148+
let element_ptype = vector_metadata.element_ptype();
163149
let expected_norms_dtype = DType::Primitive(element_ptype, dtype.nullability());
164150
vortex_ensure_eq!(
165151
*norms.dtype(),

vortex-tensor/src/encodings/turboquant/array/scheme.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@ use vortex_compressor::CascadingCompressor;
99
use vortex_compressor::ctx::CompressorContext;
1010
use vortex_compressor::scheme::Scheme;
1111
use vortex_compressor::stats::ArrayAndStats;
12+
use vortex_error::VortexExpect;
1213
use vortex_error::VortexResult;
1314

1415
use crate::encodings::turboquant::TurboQuant;
1516
use crate::encodings::turboquant::TurboQuantConfig;
1617
use crate::encodings::turboquant::turboquant_encode;
17-
use crate::utils::tensor_element_ptype;
18-
use crate::utils::tensor_list_size;
1918

2019
/// TurboQuant compression scheme for [`Vector`] extension types.
2120
///
@@ -58,15 +57,16 @@ impl Scheme for TurboQuantScheme {
5857
let dtype = data.array().dtype();
5958
let len = data.array().len();
6059

61-
let ext = TurboQuant::validate_dtype(dtype)?;
62-
let element_ptype = tensor_element_ptype(ext)?;
63-
let dimension = tensor_list_size(ext)?;
60+
let vector_metadata =
61+
TurboQuant::validate_dtype(dtype).vortex_expect("invalid dtype for TurboQuant");
62+
let element_ptype = vector_metadata.element_ptype();
63+
let bit_width: u8 = element_ptype
64+
.bit_width()
65+
.try_into()
66+
.vortex_expect("invalid bit width for TurboQuant");
67+
let dimension = vector_metadata.dimensions();
6468

65-
Ok(estimate_compression_ratio(
66-
element_ptype.bit_width(),
67-
dimension,
68-
len,
69-
))
69+
Ok(estimate_compression_ratio(bit_width, dimension, len))
7070
}
7171

7272
fn compress(
@@ -84,7 +84,7 @@ impl Scheme for TurboQuantScheme {
8484
}
8585

8686
/// Estimate the compression ratio for TurboQuant MSE encoding with the default config.
87-
fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vectors: usize) -> f64 {
87+
fn estimate_compression_ratio(bits_per_element: u8, dimensions: u32, num_vectors: usize) -> f64 {
8888
let config = TurboQuantConfig::default();
8989
let padded_dim = dimensions.next_power_of_two() as usize;
9090

@@ -99,7 +99,7 @@ fn estimate_compression_ratio(bits_per_element: usize, dimensions: u32, num_vect
9999
+ 3 * padded_dim; // rotation signs, 1 bit each
100100

101101
let compressed_size_bits = compressed_bits_per_vector * num_vectors + overhead_bits;
102-
let uncompressed_size_bits = bits_per_element * num_vectors * dimensions as usize;
102+
let uncompressed_size_bits = bits_per_element as usize * dimensions as usize * num_vectors;
103103
uncompressed_size_bits as f64 / compressed_size_bits as f64
104104
}
105105

@@ -121,7 +121,7 @@ mod tests {
121121
#[case::f64_768d(64, 768, 1000, 5.0, 7.0)]
122122
#[case::f16_768d(16, 768, 1000, 1.2, 2.0)]
123123
fn compression_ratio_in_expected_range(
124-
#[case] bits_per_element: usize,
124+
#[case] bits_per_element: u8,
125125
#[case] dim: u32,
126126
#[case] num_vectors: usize,
127127
#[case] min_ratio: f64,
@@ -142,7 +142,7 @@ mod tests {
142142
#[case(32, 768, 10)]
143143
#[case(64, 256, 50)]
144144
fn ratio_always_greater_than_one(
145-
#[case] bits_per_element: usize,
145+
#[case] bits_per_element: u8,
146146
#[case] dim: u32,
147147
#[case] num_vectors: usize,
148148
) {

vortex-tensor/src/encodings/turboquant/compute/cosine_similarity.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use vortex_error::vortex_ensure_eq;
4545
use crate::encodings::turboquant::TurboQuant;
4646
use crate::encodings::turboquant::TurboQuantArrayExt;
4747
use crate::encodings::turboquant::array::float_from_f32;
48-
use crate::utils::tensor_element_ptype;
48+
use crate::vector::AnyVector;
4949

5050
/// Compute the per-row unit-norm dot products in f32 (centroids are always f32).
5151
///
@@ -109,7 +109,11 @@ pub fn cosine_similarity_quantized_column(
109109
"TurboQuant quantized dot product requires matching dimensions",
110110
);
111111

112-
let element_ptype = tensor_element_ptype(lhs.dtype().as_extension())?;
112+
let element_ptype = lhs
113+
.dtype()
114+
.as_extension()
115+
.metadata::<AnyVector>()
116+
.element_ptype();
113117
let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
114118
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
115119

@@ -147,7 +151,11 @@ pub fn dot_product_quantized_column(
147151
"TurboQuant quantized dot product requires matching dimensions",
148152
);
149153

150-
let element_ptype = tensor_element_ptype(lhs.dtype().as_extension())?;
154+
let element_ptype = lhs
155+
.dtype()
156+
.as_extension()
157+
.metadata::<AnyVector>()
158+
.element_ptype();
151159
let validity = lhs.norms().validity()?.and(rhs.norms().validity()?)?;
152160
let dots = compute_unit_dots(&lhs, &rhs, ctx)?;
153161
let num_rows = lhs.norms().len();

vortex-tensor/src/encodings/turboquant/decompress.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ use crate::encodings::turboquant::TurboQuant;
2424
use crate::encodings::turboquant::TurboQuantArrayExt;
2525
use crate::encodings::turboquant::array::float_from_f32;
2626
use crate::encodings::turboquant::array::rotation::RotationMatrix;
27-
use crate::utils::tensor_element_ptype;
27+
use crate::vector::AnyVector;
2828

2929
/// Decompress a `TurboQuantArray` into a [`Vector`] extension array.
3030
///
3131
/// The returned array is an [`ExtensionArray`] with the original Vector dtype wrapping a
32-
/// `FixedSizeListArray` of f32 elements.
32+
/// `FixedSizeListArray` of the original vector element type.
3333
///
3434
/// [`Vector`]: crate::vector::Vector
3535
pub fn execute_decompress(
@@ -40,7 +40,7 @@ pub fn execute_decompress(
4040
let padded_dim = array.padded_dim() as usize;
4141
let num_rows = array.norms().len();
4242
let ext_dtype = array.dtype().as_extension().clone();
43-
let element_ptype = tensor_element_ptype(&ext_dtype)?;
43+
let element_ptype = ext_dtype.metadata::<AnyVector>().element_ptype();
4444

4545
if num_rows == 0 {
4646
let fsl_validity = Validity::from(ext_dtype.storage_dtype().nullability());

0 commit comments

Comments
 (0)