66//! The TurboQuant paper analyzes a full random orthogonal rotation. The current implementation
77//! uses a cheaper structured Walsh-Hadamard-based surrogate instead of a dense d x d matrix.
88//!
9- //! Concretely, this applies three rounds of random sign diagonals interleaved with the
10- //! Walsh-Hadamard Transform: D3 * H * D2 * H * D1 * H, followed by normalization. This is a
11- //! SORF-style structured approximation to a random orthogonal matrix, chosen for O(d log d)
9+ //! Concretely, this applies three rounds of random sign diagonals interleaved with the Fast
10+ //! Walsh-Hadamard Transform (FWHT): `D3 * H * D2 * H * D1 * H`, followed by normalization. This is
11+ //! a SORF-style structured approximation to a random orthogonal matrix, chosen for O(d log d)
1212//! encode/decode cost and compact serialized parameters.
1313//!
14- //! For dimensions that are not powers of 2, the input is zero-padded to the
15- //! next power of 2 before the transform and truncated afterward.
14+ //! The FWHT exploits the Kronecker product structure of the Hadamard matrix (`H_n = H_2 (x) H_2
15+ //! (x) ... (x) H_2`, with `log2(n)` factors) to compute the matrix-vector product in O(n log n)
16+ //! time using only in-place 2-element butterfly operations. No row of the full n x n Hadamard
17+ //! matrix is ever materialized.
18+ //!
19+ //! For dimensions that are not powers of 2, the input is zero-padded to the next power of 2 before
20+ //! the transform and truncated afterward.
1621//!
1722//! # Sign representation
1823//!
19- //! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op)
20- //! and `0x80000000` for -1 (flip IEEE 754 sign bit). The sign application
21- //! function uses integer XOR instead of floating-point multiply, which avoids
22- //! FP dependency chains and auto-vectorizes into `vpxor`/`veor`.
24+ //! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op) and `0x80000000` for
25+ //! -1 (flip IEEE 754 sign bit). The sign application function uses integer XOR instead of
26+ //! floating-point multiply, which avoids FP dependency chains and auto-vectorizes into
27+ //! `vpxor`/`veor`.
2328
2429use rand:: RngExt ;
2530use rand:: SeedableRng ;
@@ -32,8 +37,8 @@ const F32_SIGN_BIT: u32 = 0x8000_0000;
3237
3338/// A Walsh-Hadamard-based structured surrogate for a random orthogonal rotation.
3439pub struct RotationMatrix {
35- /// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`.
36- /// `0x00000000` = multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit).
40+ /// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`. `0x00000000` =
41+ /// multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit).
3742 sign_masks : [ Vec < u32 > ; 3 ] ,
3843 /// The padded dimension (next power of 2 >= dimension).
3944 padded_dim : usize ,
@@ -59,8 +64,8 @@ impl RotationMatrix {
5964
6065 /// Apply forward rotation: `output = R(input)`.
6166 ///
62- /// Both `input` and `output` must have length `padded_dim()`. The caller
63- /// is responsible for zero-padding input beyond `dim` positions.
67+ /// Both `input` and `output` must have length [`padded_dim()`](Self::padded_dim). The caller is
68+ /// responsible for zero-padding input beyond `dim` positions.
6469 pub fn rotate ( & self , input : & [ f32 ] , output : & mut [ f32 ] ) {
6570 debug_assert_eq ! ( input. len( ) , self . padded_dim) ;
6671 debug_assert_eq ! ( output. len( ) , self . padded_dim) ;
@@ -120,12 +125,11 @@ impl RotationMatrix {
120125 buf. iter_mut ( ) . for_each ( |val| * val *= norm) ;
121126 }
122127
123- /// Export the 3 sign vectors as a flat `Vec<u8>` of 0/1 values in inverse
124- /// application order `[D₃ | D₂ | D₁]`.
128+ /// Export the 3 sign vectors as a flat `Vec<u8>` of 0/1 values in inverse application order
129+ /// `[D₃ | D₂ | D₁]`.
125130 ///
126- /// Convention: `1` = positive (+1), `0` = negative (-1).
127- /// The output has length `3 * padded_dim` and is suitable for bitpacking
128- /// via FastLanes `bitpack_encode(..., 1, None)`.
131+ /// Convention: `1` = positive (+1), `0` = negative (-1). The output has length `3 * padded_dim`
132+ /// and is suitable for bitpacking via FastLanes `bitpack_encode(..., 1, None)`.
129133 pub fn export_inverse_signs_u8 ( & self ) -> Vec < u8 > {
130134 let total = 3 * self . padded_dim ;
131135 let mut out = Vec :: with_capacity ( total) ;
@@ -139,14 +143,14 @@ impl RotationMatrix {
139143 out
140144 }
141145
142- /// Reconstruct a `RotationMatrix` from unpacked `u8` 0/1 values.
146+ /// Reconstruct a [`RotationMatrix`] from unpacked `u8` 0/1 values.
143147 ///
144- /// The input must have length `3 * padded_dim` with signs in inverse
145- /// application order `[D₃ | D₂ | D₁]` (as produced by [`export_inverse_signs_u8`]).
146- /// Convention: `1` = positive, `0` = negative.
148+ /// The input must have length `3 * padded_dim` with signs in inverse application order
149+ /// `[D₃ | D₂ | D₁]` (as produced by [`Self::export_inverse_signs_u8`]). Convention: `1` =
150+ /// positive, `0` = negative.
147151 ///
148- /// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the
149- /// stored `BitPackedArray` into `&[u8]`, which is passed here.
152+ /// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the stored
153+ /// [`BitPackedArray`] into `&[u8]`, which is passed here.
150154 pub fn from_u8_slice ( signs_u8 : & [ u8 ] , dimension : usize ) -> VortexResult < Self > {
151155 let padded_dim = dimension. next_power_of_two ( ) ;
152156 vortex_ensure ! (
@@ -192,22 +196,23 @@ fn gen_random_sign_masks(rng: &mut StdRng, len: usize) -> Vec<u32> {
192196
193197/// Apply sign masks via XOR on the IEEE 754 sign bit.
194198///
195- /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM).
196- /// Equivalent to multiplying each element by ± 1.0, but avoids FP dependency chains.
199+ /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to
200+ /// multiplying each element by +/- 1.0, but avoids FP dependency chains.
///
/// `buf` and `masks` are expected to have the same length. This is checked in debug builds only;
/// in release builds a length mismatch leaves the unpaired tail of `buf` untouched, because the
/// element-wise pairing stops at the shorter slice.
#[inline]
fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) {
    // Guard the length invariant in debug builds so a mismatch is not silently truncated.
    debug_assert_eq!(buf.len(), masks.len());
    for (val, &mask) in buf.iter_mut().zip(masks) {
        // XOR on the raw bit pattern: a mask of `0x8000_0000` toggles only the sign bit, and a
        // zero mask leaves the value unchanged.
        *val = f32::from_bits(val.to_bits() ^ mask);
    }
}
203207
204- /// In-place Walsh-Hadamard Transform (unnormalized, iterative) .
208+ /// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative.
205209///
206- /// Input length must be a power of 2. Runs in O(n log n).
210+ /// Input length must be a power of 2. Runs in O(n log n) via `log2(n)` stages of `n / 2`
211+ /// [`butterfly`] operations each. See the [module-level docs](self) for why this avoids
212+ /// materializing the full Hadamard matrix.
207213///
208- /// Uses a fixed-size chunk strategy: for each stage, the buffer is processed
209- /// in `CHUNK`-element blocks with a compile-time-known butterfly function.
210- /// This lets LLVM unroll and auto-vectorize the butterfly into NEON/AVX SIMD.
214+ /// The chunk-based iteration gives LLVM enough structure to auto-vectorize each butterfly call
215+ /// into NEON/AVX SIMD instructions.
211216fn walsh_hadamard_transform ( buf : & mut [ f32 ] ) {
212217 let len = buf. len ( ) ;
213218 debug_assert ! ( len. is_power_of_two( ) ) ;
@@ -225,9 +230,11 @@ fn walsh_hadamard_transform(buf: &mut [f32]) {
225230 }
226231}
227232
228- /// Butterfly: `lo[i], hi[i] = lo[i] + hi[i], lo[i] - hi[i]`.
233+ /// Butterfly: `(lo[i], hi[i]) -> (lo[i] + hi[i], lo[i] - hi[i])`.
229234///
230- /// Separate function so LLVM can see the slice lengths match and auto-vectorize.
235+ /// This is multiplication by the 2x2 Hadamard kernel `H_2 = [[1, 1], [1, -1]]` on each element
236+ /// pair. Factored into a separate function so LLVM can see the slice lengths match and
237+ /// auto-vectorize.
231238#[ inline( always) ]
232239fn butterfly ( lo : & mut [ f32 ] , hi : & mut [ f32 ] ) {
233240 debug_assert_eq ! ( lo. len( ) , hi. len( ) ) ;
0 commit comments