Skip to content

Commit 5582f1e

Browse files
committed
holy moly simd
Signed-off-by: Will Manning <will@willmanning.io>
1 parent f978d7b commit 5582f1e

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

encodings/turboquant/src/rotation.rs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -202,27 +202,41 @@ fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) {
202202
/// In-place Walsh-Hadamard Transform (unnormalized, iterative).
203203
///
204204
/// Input length must be a power of 2. Runs in O(n log n).
205+
///
206+
/// Uses a fixed-size chunk strategy: for each stage, the buffer is processed
207+
/// in `CHUNK`-element blocks with a compile-time-known butterfly function.
208+
/// This lets LLVM unroll and auto-vectorize the butterfly into NEON/AVX SIMD.
205209
fn walsh_hadamard_transform(buf: &mut [f32]) {
206210
let len = buf.len();
207211
debug_assert!(len.is_power_of_two());
208212

209213
let mut half = 1;
210214
while half < len {
211215
let stride = half * 2;
212-
let mut block_start = 0;
213-
while block_start < len {
214-
for idx in block_start..block_start + half {
215-
let sum = buf[idx] + buf[idx + half];
216-
let diff = buf[idx] - buf[idx + half];
217-
buf[idx] = sum;
218-
buf[idx + half] = diff;
219-
}
220-
block_start += stride;
216+
// Process in chunks of `stride` elements. Within each chunk,
217+
// split into non-overlapping (lo, hi) halves for the butterfly.
218+
for chunk in buf.chunks_exact_mut(stride) {
219+
let (lo, hi) = chunk.split_at_mut(half);
220+
butterfly(lo, hi);
221221
}
222222
half *= 2;
223223
}
224224
}
225225

226+
/// Butterfly: `lo[i], hi[i] = lo[i] + hi[i], lo[i] - hi[i]`.
227+
///
228+
/// Separate function so LLVM can see the slice lengths match and auto-vectorize.
229+
#[inline(always)]
230+
fn butterfly(lo: &mut [f32], hi: &mut [f32]) {
231+
debug_assert_eq!(lo.len(), hi.len());
232+
for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
233+
let sum = *a + *b;
234+
let diff = *a - *b;
235+
*a = sum;
236+
*b = diff;
237+
}
238+
}
239+
226240
#[cfg(test)]
227241
mod tests {
228242
use rstest::rstest;

0 commit comments

Comments
 (0)