88//! Walsh-Hadamard transform to achieve O(d log d) matrix-vector products instead of the O(d^2) cost
99//! of a dense orthogonal matrix.
1010//!
11- //! This module wraps an FSL child (e.g. `FSL(Dict(codes, centroids))`) and applies the inverse SORF
12- //! transform at execution time, producing a [`Vector`] extension array with the original
13- //! (pre-padding) dimensionality.
11+ //! This module wraps a [`Vector`] extension array whose dimension is the padded SORF dimension
12+ //! (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the inverse SORF transform
13+ //! at execution time, producing a [`Vector`] extension array with the original (pre-padding)
14+ //! dimensionality.
1415//!
1516//! The transform parameters are stored as a deterministic seed in [`SorfOptions`], so the
1617//! [`SorfMatrix`] is reconstructed cheaply at decode time. Sign diagonals are defined by Vortex's
1718//! frozen local SplitMix64 stream contract rather than by an external RNG crate.
1819//!
19- //! **All SORF computation happens in f32.** Input elements of other float types (f16, f64) are cast
20- //! to f32 before the transform, and the result is cast back to the target type specified by
21- //! [`SorfOptions::element_ptype`].
20+ //! # Input element type: `f32` only (for now; TODO(connor))
21+ //!
22+ //! The child [`Vector`] **must** have `f32` storage elements. This is a hard constraint that is
23+ //! enforced by `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data need
24+ //! to cast to `f32` before wrapping in a [`Vector`] and handing it to `SorfTransform`.
25+ //!
26+ //! The reason for this constraint is that TurboQuant (the only production caller today) stores its
27+ //! dictionary centroids as `f32`, and the SORF transform itself operates internally in `f32`.
28+ //!
29+ //! Supporting other float storage types would require an implicit up-/down-cast that we do not yet
30+ //! want to bake into `SorfTransform`. This restriction is intentional and may be relaxed in the
31+ //! future, but today it is load-bearing.
32+ //!
33+ //! # Output element type
34+ //!
35+ //! The output [`Vector`]'s element type is whatever [`SorfOptions::element_ptype`] is set to. It
36+ //! does **not** have to match the child's `f32` storage: we apply an explicit `f32 -> T` cast
37+ //! while materializing the output. This lets `SorfTransform` hand its result directly to a
38+ //! downstream consumer (e.g. [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)) whose
39+ //! element-type expectation may differ from the `f32` the transform operated on internally.
2240//!
2341//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf
2442//! [`Vector`]: crate::vector::Vector
@@ -41,9 +59,13 @@ mod vtable;
4159
4260/// Inverse SORF orthogonal transform scalar function.
4361///
44- /// Applies the inverse structured Walsh-Hadamard orthogonal transform to an FSL child,
45- /// truncates from padded dimension to the original dimension, casts to the target element
46- /// type, and wraps in a [`Vector`](crate::vector::Vector) extension array.
62+ /// Takes a [`Vector`](crate::vector::Vector) extension child at the padded dimension with `f32`
63+ /// storage, applies the inverse structured Walsh-Hadamard orthogonal transform, truncates to the
64+ /// original (pre-padding) dimension, casts element-wise to [`SorfOptions::element_ptype`], and
65+ /// wraps the result in a new [`Vector`](crate::vector::Vector) extension array.
66+ ///
67+ /// See the [module-level docs](crate::scalar_fns::sorf_transform) for the rationale behind the
68+ /// `f32`-only input constraint.
4769#[ derive( Clone ) ]
4870pub struct SorfTransform ;
4971
@@ -57,9 +79,12 @@ pub struct SorfOptions {
5779 pub seed : u64 ,
5880 /// Number of sign-diagonal + WHT rounds in the structured orthogonal transform.
5981 pub num_rounds : u8 ,
60- /// Original vector dimension (before power-of-2 padding).
82+ /// Original vector dimension (before power-of-2 padding). The output
83+ /// [`Vector`](crate::vector::Vector) has this dimension.
6184 pub dimension : u32 ,
62- /// Target output element type (e.g. `F16`, `F32`, `F64`).
85+ /// Element type of the output [`Vector`](crate::vector::Vector). The child input must always
86+ /// be `f32`, but the output can be any float type (`F16`, `F32`, `F64`); the final
87+ /// `f32 -> element_ptype` cast happens while building the output.
6388 pub element_ptype : PType ,
6489}
6590
@@ -71,8 +96,16 @@ impl SorfTransform {
7196
7297 /// Constructs a validated [`ScalarFnArray`] that lazily applies the inverse SORF transform.
7398 ///
74- /// The `child` must be a `FixedSizeList` (or array that executes to one) with logical float
75- /// elements and `list_size == padded_dim` (i.e. `dimension.next_power_of_two()`).
99+ /// The `child` must be a [`Vector`] extension array (or an array that executes to one) with:
100+ ///
101+ /// - dimension equal to `padded_dim` (i.e. `options.dimension.next_power_of_two()`), and
102+ /// - `f32` storage elements. This is a hard requirement today; see the
103+ /// [module-level docs](crate::scalar_fns::sorf_transform) for the rationale.
104+ ///
105+ /// The output [`Vector`] has dimension `options.dimension` and element type
106+ /// `options.element_ptype`.
107+ ///
108+ /// [`Vector`]: crate::vector::Vector
76109 pub fn try_new_array (
77110 options : & SorfOptions ,
78111 child : ArrayRef ,
0 commit comments