Skip to content

Commit ec28916

Browse files
committed
clean up rotation docs
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent bbb371c commit ec28916

1 file changed

Lines changed: 40 additions & 33 deletions

File tree

  • vortex-tensor/src/encodings/turboquant/array

vortex-tensor/src/encodings/turboquant/array/rotation.rs

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,25 @@
66
//! The TurboQuant paper analyzes a full random orthogonal rotation. The current implementation
77
//! uses a cheaper structured Walsh-Hadamard-based surrogate instead of a dense d x d matrix.
88
//!
9-
//! Concretely, this applies three rounds of random sign diagonals interleaved with the
10-
//! Walsh-Hadamard Transform: D3 * H * D2 * H * D1 * H, followed by normalization. This is a
11-
//! SORF-style structured approximation to a random orthogonal matrix, chosen for O(d log d)
9+
//! Concretely, this applies three rounds of random sign diagonals interleaved with the Fast
10+
//! Walsh-Hadamard Transform (FWHT): `D3 * H * D2 * H * D1 * H`, followed by normalization. This is
11+
//! a SORF-style structured approximation to a random orthogonal matrix, chosen for O(d log d)
1212
//! encode/decode cost and compact serialized parameters.
1313
//!
14-
//! For dimensions that are not powers of 2, the input is zero-padded to the
15-
//! next power of 2 before the transform and truncated afterward.
14+
//! The FWHT exploits the Kronecker product structure of the Hadamard matrix (`H_n = H_2 (x) H_2
15+
//! (x) ... (x) H_2`, with `log2(n)` factors) to compute the matrix-vector product in O(n log n)
16+
//! time using only in-place 2-element butterfly operations. No row of the full n x n Hadamard
17+
//! matrix is ever materialized.
18+
//!
19+
//! For dimensions that are not powers of 2, the input is zero-padded to the next power of 2 before
20+
//! the transform and truncated afterward.
1621
//!
1722
//! # Sign representation
1823
//!
19-
//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op)
20-
//! and `0x80000000` for -1 (flip IEEE 754 sign bit). The sign application
21-
//! function uses integer XOR instead of floating-point multiply, which avoids
22-
//! FP dependency chains and auto-vectorizes into `vpxor`/`veor`.
24+
//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op) and `0x80000000` for
25+
//! -1 (flip IEEE 754 sign bit). The sign application function uses integer XOR instead of
26+
//! floating-point multiply, which avoids FP dependency chains and auto-vectorizes into
27+
//! `vpxor`/`veor`.
2328
2429
use rand::RngExt;
2530
use rand::SeedableRng;
@@ -32,8 +37,8 @@ const F32_SIGN_BIT: u32 = 0x8000_0000;
3237

3338
/// A Walsh-Hadamard-based structured surrogate for a random orthogonal rotation.
3439
pub struct RotationMatrix {
35-
/// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`.
36-
/// `0x00000000` = multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit).
40+
/// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`. `0x00000000` =
41+
/// multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit).
3742
sign_masks: [Vec<u32>; 3],
3843
/// The padded dimension (next power of 2 >= dimension).
3944
padded_dim: usize,
@@ -59,8 +64,8 @@ impl RotationMatrix {
5964

6065
/// Apply forward rotation: `output = R(input)`.
6166
///
62-
/// Both `input` and `output` must have length `padded_dim()`. The caller
63-
/// is responsible for zero-padding input beyond `dim` positions.
67+
/// Both `input` and `output` must have length [`padded_dim()`](Self::padded_dim). The caller is
68+
/// responsible for zero-padding input beyond `dim` positions.
6469
pub fn rotate(&self, input: &[f32], output: &mut [f32]) {
6570
debug_assert_eq!(input.len(), self.padded_dim);
6671
debug_assert_eq!(output.len(), self.padded_dim);
@@ -120,12 +125,11 @@ impl RotationMatrix {
120125
buf.iter_mut().for_each(|val| *val *= norm);
121126
}
122127

123-
/// Export the 3 sign vectors as a flat `Vec<u8>` of 0/1 values in inverse
124-
/// application order `[D₃ | D₂ | D₁]`.
128+
/// Export the 3 sign vectors as a flat `Vec<u8>` of 0/1 values in inverse application order
129+
/// `[D₃ | D₂ | D₁]`.
125130
///
126-
/// Convention: `1` = positive (+1), `0` = negative (-1).
127-
/// The output has length `3 * padded_dim` and is suitable for bitpacking
128-
/// via FastLanes `bitpack_encode(..., 1, None)`.
131+
/// Convention: `1` = positive (+1), `0` = negative (-1). The output has length `3 * padded_dim`
132+
/// and is suitable for bitpacking via FastLanes `bitpack_encode(..., 1, None)`.
129133
pub fn export_inverse_signs_u8(&self) -> Vec<u8> {
130134
let total = 3 * self.padded_dim;
131135
let mut out = Vec::with_capacity(total);
@@ -139,14 +143,14 @@ impl RotationMatrix {
139143
out
140144
}
141145

142-
/// Reconstruct a `RotationMatrix` from unpacked `u8` 0/1 values.
146+
/// Reconstruct a [`RotationMatrix`] from unpacked `u8` 0/1 values.
143147
///
144-
/// The input must have length `3 * padded_dim` with signs in inverse
145-
/// application order `[D₃ | D₂ | D₁]` (as produced by [`export_inverse_signs_u8`]).
146-
/// Convention: `1` = positive, `0` = negative.
148+
/// The input must have length `3 * padded_dim` with signs in inverse application order
149+
/// `[D₃ | D₂ | D₁]` (as produced by [`export_inverse_signs_u8`]). Convention: `1` = positive,
150+
/// `0` = negative.
147151
///
148-
/// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the
149-
/// stored `BitPackedArray` into `&[u8]`, which is passed here.
152+
/// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the stored
153+
/// [`BitPackedArray`] into `&[u8]`, which is passed here.
150154
pub fn from_u8_slice(signs_u8: &[u8], dimension: usize) -> VortexResult<Self> {
151155
let padded_dim = dimension.next_power_of_two();
152156
vortex_ensure!(
@@ -192,22 +196,23 @@ fn gen_random_sign_masks(rng: &mut StdRng, len: usize) -> Vec<u32> {
192196

193197
/// Apply sign masks via XOR on the IEEE 754 sign bit.
194198
///
195-
/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM).
196-
/// Equivalent to multiplying each element by ±1.0, but avoids FP dependency chains.
199+
/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to
200+
/// multiplying each element by +/-1.0, but avoids FP dependency chains.
197201
#[inline]
198202
fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) {
199203
for (val, &mask) in buf.iter_mut().zip(masks.iter()) {
200204
*val = f32::from_bits(val.to_bits() ^ mask);
201205
}
202206
}
203207

204-
/// In-place Walsh-Hadamard Transform (unnormalized, iterative).
208+
/// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative.
205209
///
206-
/// Input length must be a power of 2. Runs in O(n log n).
210+
/// Input length must be a power of 2. Runs in O(n log n) via `log2(n)` stages of `n / 2`
211+
/// [`butterfly`] operations each. See the [module-level docs](self) for why this avoids
212+
/// materializing the full Hadamard matrix.
207213
///
208-
/// Uses a fixed-size chunk strategy: for each stage, the buffer is processed
209-
/// in `CHUNK`-element blocks with a compile-time-known butterfly function.
210-
/// This lets LLVM unroll and auto-vectorize the butterfly into NEON/AVX SIMD.
214+
/// The chunk-based iteration gives LLVM enough structure to auto-vectorize each butterfly call
215+
/// into NEON/AVX SIMD instructions.
211216
fn walsh_hadamard_transform(buf: &mut [f32]) {
212217
let len = buf.len();
213218
debug_assert!(len.is_power_of_two());
@@ -225,9 +230,11 @@ fn walsh_hadamard_transform(buf: &mut [f32]) {
225230
}
226231
}
227232

228-
/// Butterfly: `lo[i], hi[i] = lo[i] + hi[i], lo[i] - hi[i]`.
233+
/// Butterfly: `(lo[i], hi[i]) -> (lo[i] + hi[i], lo[i] - hi[i])`.
229234
///
230-
/// Separate function so LLVM can see the slice lengths match and auto-vectorize.
235+
/// This is multiplication by the 2x2 Hadamard kernel `H_2 = [[1, 1], [1, -1]]` on each element
236+
/// pair. Factored into a separate function so LLVM can see the slice lengths match and
237+
/// auto-vectorize.
231238
#[inline(always)]
232239
fn butterfly(lo: &mut [f32], hi: &mut [f32]) {
233240
debug_assert_eq!(lo.len(), hi.len());

0 commit comments

Comments
 (0)