Skip to content

Commit c0f6037

Browse files
committed
fix input dtype
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 45ee193 commit c0f6037

6 files changed

Lines changed: 195 additions & 162 deletions

File tree

vortex-tensor/src/encodings/turboquant/compress.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@ use vortex_array::ArrayView;
1515
use vortex_array::ExecutionCtx;
1616
use vortex_array::IntoArray;
1717
use vortex_array::arrays::Extension;
18+
use vortex_array::arrays::ExtensionArray;
1819
use vortex_array::arrays::FixedSizeListArray;
1920
use vortex_array::arrays::PrimitiveArray;
2021
use vortex_array::arrays::dict::DictArray;
2122
use vortex_array::arrays::extension::ExtensionArrayExt;
2223
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
2324
use vortex_array::dtype::Nullability;
25+
use vortex_array::dtype::extension::ExtDType;
26+
use vortex_array::extension::EmptyMetadata;
2427
use vortex_array::validity::Validity;
2528
use vortex_buffer::BufferMut;
2629
use vortex_error::VortexExpect;
@@ -38,6 +41,7 @@ use crate::scalar_fns::sorf_transform::SorfOptions;
3841
use crate::scalar_fns::sorf_transform::SorfTransform;
3942
use crate::utils::cast_to_f32;
4043
use crate::vector::AnyVector;
44+
use crate::vector::Vector;
4145

4246
/// Configuration for TurboQuant encoding.
4347
#[derive(Clone, Debug)]
@@ -236,6 +240,7 @@ pub unsafe fn turboquant_encode_unchecked(
236240
Validity::NonNullable,
237241
0,
238242
)?;
243+
let empty_padded_vector = wrap_padded_as_vector(empty_fsl.into_array())?;
239244

240245
let sorf_options = SorfOptions {
241246
seed,
@@ -244,19 +249,27 @@ pub unsafe fn turboquant_encode_unchecked(
244249
element_ptype,
245250
};
246251
return Ok(
247-
SorfTransform::try_new_array(&sorf_options, empty_fsl.into_array(), 0)?.into_array(),
252+
SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(),
248253
);
249254
}
250255

251256
let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?;
252257
let quantized_fsl =
253258
build_quantized_fsl(num_rows, core.all_indices, &core.centroids, core.padded_dim)?;
259+
let padded_vector = wrap_padded_as_vector(quantized_fsl)?;
254260

255261
let sorf_options = SorfOptions {
256262
seed,
257263
num_rounds: config.num_rounds,
258264
dimension,
259265
element_ptype,
260266
};
261-
Ok(SorfTransform::try_new_array(&sorf_options, quantized_fsl, num_rows)?.into_array())
267+
Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array())
268+
}
269+
270+
/// Wrap an `FSL<f32, padded_dim>` in a [`Vector`](crate::vector::Vector) extension so it can be
271+
/// passed as the child of [`SorfTransform`], which expects a `Vector<padded_dim>` input.
272+
fn wrap_padded_as_vector(fsl: ArrayRef) -> VortexResult<ArrayRef> {
273+
let ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
274+
Ok(ExtensionArray::new(ext_dtype, fsl).into_array())
262275
}

vortex-tensor/src/encodings/turboquant/tests/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,11 @@ fn unwrap_codes_centroids_norms(
128128
ctx: &mut vortex_array::ExecutionCtx,
129129
) -> VortexResult<(PrimitiveArray, PrimitiveArray, PrimitiveArray)> {
130130
let (sorf_child, norms_child) = unwrap_l2denorm(encoded);
131-
let fsl_child = unwrap_sorf(&sorf_child);
131+
let padded_vector_child = unwrap_sorf(&sorf_child);
132132

133-
// FSL(Dict(codes, centroids))
134-
let fsl: FixedSizeListArray = fsl_child.execute(ctx)?;
133+
// Vector<padded_dim> wrapping FSL(Dict(codes, centroids))
134+
let padded_vector: ExtensionArray = padded_vector_child.execute(ctx)?;
135+
let fsl: FixedSizeListArray = padded_vector.storage_array().clone().execute(ctx)?;
135136
let dict = fsl
136137
.elements()
137138
.as_opt::<Dict>()

vortex-tensor/src/encodings/turboquant/tests/structural.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ fn dot_product_quantized_accuracy() -> VortexResult<()> {
225225
fn sorf_transform_roundtrip_isolation() -> VortexResult<()> {
226226
use vortex_array::IntoArray;
227227
use vortex_array::arrays::dict::DictArray;
228+
use vortex_array::dtype::extension::ExtDType;
229+
use vortex_array::extension::EmptyMetadata;
228230
use vortex_array::validity::Validity;
229231
use vortex_buffer::BufferMut;
230232

@@ -234,6 +236,7 @@ fn sorf_transform_roundtrip_isolation() -> VortexResult<()> {
234236
use crate::scalar_fns::sorf_transform::SorfMatrix;
235237
use crate::scalar_fns::sorf_transform::SorfOptions;
236238
use crate::scalar_fns::sorf_transform::SorfTransform;
239+
use crate::vector::Vector;
237240

238241
let dim = 128usize;
239242
let seed = 99u64;
@@ -287,14 +290,20 @@ fn sorf_transform_roundtrip_isolation() -> VortexResult<()> {
287290
num_rows,
288291
)?;
289292

293+
// Wrap the padded FSL in a Vector extension so it can be the SorfTransform child.
294+
let padded_vector_dtype =
295+
ExtDType::<Vector>::try_new(EmptyMetadata, fsl.dtype().clone())?.erased();
296+
let padded_vector = ExtensionArray::new(padded_vector_dtype, fsl.into_array());
297+
290298
// Wrap in SorfTransform and execute.
291299
let sorf_options = SorfOptions {
292300
seed,
293301
num_rounds,
294302
dimension: dim as u32,
295303
element_ptype: vortex_array::dtype::PType::F32,
296304
};
297-
let sorf_array = SorfTransform::try_new_array(&sorf_options, fsl.into_array(), num_rows)?;
305+
let sorf_array =
306+
SorfTransform::try_new_array(&sorf_options, padded_vector.into_array(), num_rows)?;
298307

299308
let mut ctx = SESSION.create_execution_ctx();
300309
let result: ExtensionArray = sorf_array.into_array().execute(&mut ctx)?;

vortex-tensor/src/scalar_fns/sorf_transform/mod.rs

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,35 @@
88
//! Walsh-Hadamard transform to achieve O(d log d) matrix-vector products instead of the O(d^2) cost
99
//! of a dense orthogonal matrix.
1010
//!
11-
//! This module wraps an FSL child (e.g. `FSL(Dict(codes, centroids))`) and applies the inverse SORF
12-
//! transform at execution time, producing a [`Vector`] extension array with the original
13-
//! (pre-padding) dimensionality.
11+
//! This module wraps a [`Vector`] extension array whose dimension is the padded SORF dimension
12+
//! (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the inverse SORF transform
13+
//! at execution time, producing a [`Vector`] extension array with the original (pre-padding)
14+
//! dimensionality.
1415
//!
1516
//! The transform parameters are stored as a deterministic seed in [`SorfOptions`], so the
1617
//! [`SorfMatrix`] is reconstructed cheaply at decode time. Sign diagonals are defined by Vortex's
1718
//! frozen local SplitMix64 stream contract rather than by an external RNG crate.
1819
//!
19-
//! **All SORF computation happens in f32.** Input elements of other float types (f16, f64) are cast
20-
//! to f32 before the transform, and the result is cast back to the target type specified by
21-
//! [`SorfOptions::element_ptype`].
20+
//! # Input element type: `f32` only (TODO(connor): for now...)
21+
//!
22+
//! The child [`Vector`] **must** have `f32` storage elements. This is a hard constraint that is
23+
//! enforced by `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data need
24+
//! to cast to `f32` before wrapping in a [`Vector`] and handing it to SorfTransform.
25+
//!
26+
//! The reason for this constraint is that TurboQuant (the only production caller today) stores its
27+
//! dictionary centroids as `f32`, and the SORF transform itself operates internally in `f32`.
28+
//!
29+
//! Supporting other float storage types would require an implicit up-/down-cast that we do not yet
30+
//! want to bake into SorfTransform. This restriction is intentional and may be relaxed in the
31+
//! future, but today it is load-bearing.
32+
//!
33+
//! # Output element type
34+
//!
35+
//! The output [`Vector`]'s element type is whatever [`SorfOptions::element_ptype`] is set to. It
36+
//! does **not** have to match the child's `f32` storage: we apply an explicit `f32 -> T` cast
37+
//! while materializing the output. This lets SorfTransform hand its result directly to a
38+
//! downstream consumer (e.g. [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)) whose
39+
//! element-type expectation may differ from the `f32` the transform operated on internally.
2240
//!
2341
//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf
2442
//! [`Vector`]: crate::vector::Vector
@@ -41,9 +59,13 @@ mod vtable;
4159

4260
/// Inverse SORF orthogonal transform scalar function.
4361
///
44-
/// Applies the inverse structured Walsh-Hadamard orthogonal transform to an FSL child,
45-
/// truncates from padded dimension to the original dimension, casts to the target element
46-
/// type, and wraps in a [`Vector`](crate::vector::Vector) extension array.
62+
/// Takes a [`Vector`](crate::vector::Vector) extension child at the padded dimension with `f32`
63+
/// storage, applies the inverse structured Walsh-Hadamard orthogonal transform, truncates to the
64+
/// original (pre-padding) dimension, casts element-wise to [`SorfOptions::element_ptype`], and
65+
/// wraps the result in a new [`Vector`](crate::vector::Vector) extension array.
66+
///
67+
/// See the [module-level docs](crate::scalar_fns::sorf_transform) for the rationale behind the
68+
/// `f32`-only input constraint.
4769
#[derive(Clone)]
4870
pub struct SorfTransform;
4971

@@ -57,9 +79,12 @@ pub struct SorfOptions {
5779
pub seed: u64,
5880
/// Number of sign-diagonal + WHT rounds in the structured orthogonal transform.
5981
pub num_rounds: u8,
60-
/// Original vector dimension (before power-of-2 padding).
82+
/// Original vector dimension (before power-of-2 padding). The output
83+
/// [`Vector`](crate::vector::Vector) has this dimension.
6184
pub dimension: u32,
62-
/// Target output element type (e.g. `F16`, `F32`, `F64`).
85+
/// Element type of the output [`Vector`](crate::vector::Vector). The child input must always
86+
/// be `f32`, but the output can be any float type (`F16`, `F32`, `F64`); the final
87+
/// `f32 -> element_ptype` cast happens while building the output.
6388
pub element_ptype: PType,
6489
}
6590

@@ -71,8 +96,16 @@ impl SorfTransform {
7196

7297
/// Constructs a validated [`ScalarFnArray`] that lazily applies the inverse SORF transform.
7398
///
74-
/// The `child` must be a `FixedSizeList` (or array that executes to one) with logical float
75-
/// elements and `list_size == padded_dim` (i.e. `dimension.next_power_of_two()`).
99+
/// The `child` must be a [`Vector`] extension array (or an array that executes to one) with:
100+
///
101+
/// - dimension equal to `padded_dim` (i.e. `options.dimension.next_power_of_two()`), and
102+
/// - `f32` storage elements. This is a hard requirement today; see the
103+
/// [module-level docs](crate::scalar_fns::sorf_transform) for the rationale.
104+
///
105+
/// The output [`Vector`] has dimension `options.dimension` and element type
106+
/// `options.element_ptype`.
107+
///
108+
/// [`Vector`]: crate::vector::Vector
76109
pub fn try_new_array(
77110
options: &SorfOptions,
78111
child: ArrayRef,

0 commit comments

Comments
 (0)