Skip to content

Commit 08c6486

Browse files
committed
Proper F64x8 SIMD: gather + 8-row batch + mul_add finalize
Replace naive SIMD with structured projection: - gather_bf16_x8: explicit 8-lane gather from row offsets - project_8rows_bf16_simd: 17 F64x8 accumulators (1088 bytes stack), halftone odd-bin interpolation in SIMD (normalize→average), mul_add finalization with simd_clamp - project_1row_bf16_strided: scalar fallback matching SIMD algorithm - project_tensor_bf16_simd: dispatches to 8-row batches + scalar tail - 3 new tests: constant agreement, scalar parity, tail handling https://claude.ai/code/session_01HmdXNPit7QsTCfhJFef3Ee
1 parent 763351b commit 08c6486

1 file changed

Lines changed: 216 additions & 61 deletions

File tree

src/hpc/gguf_indexer.rs

Lines changed: 216 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -251,98 +251,202 @@ pub fn project_row_bf16_strided(row: &[u16], octave_stride: usize) -> Base17 {
251251
Base17 { dims }
252252
}
253253

254-
// ── SIMD 8-row-parallel tensor projection ──
254+
// ── F64x8 SIMD: 8 rows → 8 Base17 in parallel ──
255255

256-
/// Project an entire BF16 tensor to Base17 using F64x8 SIMD.
256+
/// Gather 8 BF16 values from 8 rows at the same column, convert to F64x8.
257257
///
258-
/// Processes 8 rows in parallel per SIMD batch. Each of the 9 halftone bins
259-
/// holds an F64x8 accumulator (8 rows × 9 bins = 72 f64 lanes = 9 zmm registers).
258+
/// The gather is scalar (8 indexed loads) but the result is SIMD.
259+
/// At -O2 with AVX-512, rustc may emit vpgatherqd + shift + vcvtps2pd.
260+
#[inline(always)]
261+
fn gather_bf16_x8(buf: &[u16], offsets: &[usize; 8]) -> crate::simd::F64x8 {
262+
crate::simd::F64x8::from_array([
263+
bf16_to_f64(buf[offsets[0]]),
264+
bf16_to_f64(buf[offsets[1]]),
265+
bf16_to_f64(buf[offsets[2]]),
266+
bf16_to_f64(buf[offsets[3]]),
267+
bf16_to_f64(buf[offsets[4]]),
268+
bf16_to_f64(buf[offsets[5]]),
269+
bf16_to_f64(buf[offsets[6]]),
270+
bf16_to_f64(buf[offsets[7]]),
271+
])
272+
}
273+
274+
/// Project 8 BF16 rows simultaneously to 8 Base17 patterns.
260275
///
261-
/// Per sampled octave: 9 halftone positions × 8 bf16_to_f64 gathers → 9 vaddpd.
262-
/// For 5120-col rows at stride=16: 19 octaves × 9 = 171 vaddpd per 8-row batch.
263-
pub fn project_tensor_bf16_simd(
276+
/// Memory: 17 × F64x8 accumulators on stack = 17 × 64 = 1088 bytes.
277+
pub fn project_8rows_bf16_simd(
264278
buf: &[u16],
265-
n_rows: usize,
279+
row_starts: &[usize; 8],
266280
n_cols: usize,
267281
octave_stride: usize,
268-
) -> Vec<Base17> {
282+
) -> [Base17; 8] {
269283
use crate::simd::F64x8;
270284

271285
let n_octaves = (n_cols + BASE_DIM - 1) / BASE_DIM;
272-
let mut result = Vec::with_capacity(n_rows);
286+
let use_halftone = octave_stride > 1;
273287

274-
// Process 8 rows at a time
275-
let full_batches = n_rows / 8;
276-
let remainder = n_rows % 8;
277-
278-
for batch in 0..full_batches {
279-
let base_row = batch * 8;
280-
281-
// 9 halftone bins × F64x8 accumulators (8 rows per lane)
282-
let mut half_sum = [F64x8::splat(0.0); 9];
283-
let mut half_count = [0u32; 9]; // same count for all 8 rows (same n_cols)
288+
let mut sums: [F64x8; BASE_DIM] = [F64x8::splat(0.0); BASE_DIM];
289+
let mut counts: [u32; BASE_DIM] = [0; BASE_DIM];
284290

291+
if use_halftone {
285292
let mut octave = 0;
286293
while octave < n_octaves {
287294
for hi in 0..9 {
288-
let dim = octave * BASE_DIM + HALFTONE_POS[hi] as usize;
289-
if dim < n_cols {
290-
// Gather 8 BF16 values (one per row) at column `dim`
291-
let vals = F64x8::from_array([
292-
bf16_to_f64(buf[(base_row + 0) * n_cols + dim]),
293-
bf16_to_f64(buf[(base_row + 1) * n_cols + dim]),
294-
bf16_to_f64(buf[(base_row + 2) * n_cols + dim]),
295-
bf16_to_f64(buf[(base_row + 3) * n_cols + dim]),
296-
bf16_to_f64(buf[(base_row + 4) * n_cols + dim]),
297-
bf16_to_f64(buf[(base_row + 5) * n_cols + dim]),
298-
bf16_to_f64(buf[(base_row + 6) * n_cols + dim]),
299-
bf16_to_f64(buf[(base_row + 7) * n_cols + dim]),
300-
]);
301-
half_sum[hi] = half_sum[hi] + vals;
302-
if batch == 0 || octave == 0 {
303-
// Count is same for all batches with same n_cols
304-
}
305-
half_count[hi] += 1;
295+
let col = octave * BASE_DIM + HALFTONE_POS[hi] as usize;
296+
if col < n_cols {
297+
let bin = HALFTONE_TO_BIN[hi] as usize;
298+
let offsets: [usize; 8] = [
299+
row_starts[0] + col, row_starts[1] + col,
300+
row_starts[2] + col, row_starts[3] + col,
301+
row_starts[4] + col, row_starts[5] + col,
302+
row_starts[6] + col, row_starts[7] + col,
303+
];
304+
sums[bin] += gather_bf16_x8(buf, &offsets);
305+
counts[bin] += 1;
306306
}
307307
}
308308
octave += octave_stride;
309309
}
310310

311-
// Finalize: convert 9 SIMD accumulators → 8 Base17 results
312-
// Even bins: mean × FP_SCALE, clamped to i16
313-
let mut even_dims = [[0i16; BASE_DIM]; 8];
314-
315-
for hi in 0..9 {
316-
if half_count[hi] > 0 {
317-
let count_v = F64x8::splat(half_count[hi] as f64);
318-
let scale_v = F64x8::splat(FP_SCALE);
319-
let mean_v = half_sum[hi] / count_v;
320-
let scaled = mean_v * scale_v;
321-
let arr = scaled.to_array();
322-
let bin = HALFTONE_TO_BIN[hi] as usize;
323-
for lane in 0..8 {
324-
even_dims[lane][bin] =
325-
arr[lane].round().clamp(-32768.0, 32767.0) as i16;
311+
// Interpolate odd bins from even neighbors (per-lane, still SIMD)
312+
for odd in (1..BASE_DIM).step_by(2) {
313+
let left = sums[odd - 1];
314+
let right = sums[(odd + 1) % BASE_DIM];
315+
let left_c = counts[odd - 1].max(1);
316+
let right_c = counts[(odd + 1) % BASE_DIM].max(1);
317+
let left_mean = left * F64x8::splat(1.0 / left_c as f64);
318+
let right_mean = right * F64x8::splat(1.0 / right_c as f64);
319+
sums[odd] = (left_mean + right_mean) * F64x8::splat(0.5);
320+
counts[odd] = 1;
321+
}
322+
} else {
323+
for octave in 0..n_octaves {
324+
for bi in 0..BASE_DIM {
325+
let col = octave * BASE_DIM + GOLDEN_POS[bi] as usize;
326+
if col < n_cols {
327+
let offsets: [usize; 8] = [
328+
row_starts[0] + col, row_starts[1] + col,
329+
row_starts[2] + col, row_starts[3] + col,
330+
row_starts[4] + col, row_starts[5] + col,
331+
row_starts[6] + col, row_starts[7] + col,
332+
];
333+
sums[bi] += gather_bf16_x8(buf, &offsets);
334+
counts[bi] += 1;
326335
}
327336
}
328337
}
338+
}
329339

330-
// Odd bins: interpolate from neighbors
340+
// Finalize: mean → scale → clamp → i16, all 8 lanes parallel
341+
let lo = F64x8::splat(-32768.0);
342+
let hi = F64x8::splat(32767.0);
343+
344+
let mut dims_x8: [[i16; BASE_DIM]; 8] = [[0i16; BASE_DIM]; 8];
345+
346+
for bin in 0..BASE_DIM {
347+
let c = counts[bin].max(1) as f64;
348+
let scaled = sums[bin].mul_add(
349+
F64x8::splat(FP_SCALE / c),
350+
F64x8::splat(0.0),
351+
);
352+
let clamped = scaled.round().simd_clamp(lo, hi);
353+
let vals = clamped.to_array();
331354
for lane in 0..8 {
332-
for odd in (1..BASE_DIM).step_by(2) {
333-
let left = even_dims[lane][odd - 1] as i32;
334-
let right = even_dims[lane][(odd + 1) % BASE_DIM] as i32;
335-
even_dims[lane][odd] = ((left + right) / 2) as i16;
355+
dims_x8[lane][bin] = vals[lane] as i16;
356+
}
357+
}
358+
359+
[
360+
Base17 { dims: dims_x8[0] }, Base17 { dims: dims_x8[1] },
361+
Base17 { dims: dims_x8[2] }, Base17 { dims: dims_x8[3] },
362+
Base17 { dims: dims_x8[4] }, Base17 { dims: dims_x8[5] },
363+
Base17 { dims: dims_x8[6] }, Base17 { dims: dims_x8[7] },
364+
]
365+
}
366+
367+
/// Scalar fallback for remainder rows (< 8).
368+
pub fn project_1row_bf16_strided(row: &[u16], octave_stride: usize) -> Base17 {
369+
let d = row.len();
370+
let n_octaves = (d + BASE_DIM - 1) / BASE_DIM;
371+
let use_halftone = octave_stride > 1;
372+
373+
let mut sum = [0.0f64; BASE_DIM];
374+
let mut count = [0u32; BASE_DIM];
375+
376+
if use_halftone {
377+
let mut octave = 0;
378+
while octave < n_octaves {
379+
for hi in 0..9 {
380+
let col = octave * BASE_DIM + HALFTONE_POS[hi] as usize;
381+
if col < d {
382+
sum[HALFTONE_TO_BIN[hi] as usize] += bf16_to_f64(row[col]);
383+
count[HALFTONE_TO_BIN[hi] as usize] += 1;
384+
}
385+
}
386+
octave += octave_stride;
387+
}
388+
for odd in (1..BASE_DIM).step_by(2) {
389+
let lc = count[odd - 1].max(1) as f64;
390+
let rc = count[(odd + 1) % BASE_DIM].max(1) as f64;
391+
sum[odd] = (sum[odd - 1] / lc + sum[(odd + 1) % BASE_DIM] / rc) * 0.5;
392+
count[odd] = 1;
393+
}
394+
} else {
395+
for octave in 0..n_octaves {
396+
for bi in 0..BASE_DIM {
397+
let col = octave * BASE_DIM + GOLDEN_POS[bi] as usize;
398+
if col < d {
399+
sum[bi] += bf16_to_f64(row[col]);
400+
count[bi] += 1;
401+
}
336402
}
337-
result.push(Base17 { dims: even_dims[lane] });
338403
}
339404
}
340405

341-
// Scalar tail for remaining rows (< 8)
406+
let mut dims = [0i16; BASE_DIM];
407+
for i in 0..BASE_DIM {
408+
if count[i] > 0 {
409+
let mean = sum[i] / count[i] as f64;
410+
dims[i] = (mean * FP_SCALE).round().clamp(-32768.0, 32767.0) as i16;
411+
}
412+
}
413+
Base17 { dims }
414+
}
415+
416+
/// Project an entire BF16 tensor to Base17 using F64x8 SIMD.
417+
///
418+
/// Processes 8 rows in parallel per SIMD batch. Each of the 9 halftone bins
419+
/// holds an F64x8 accumulator (8 rows × 9 bins = 72 f64 lanes = 9 zmm registers).
420+
///
421+
/// Per sampled octave: 9 halftone positions × 8 bf16_to_f64 gathers → 9 vaddpd.
422+
/// For 5120-col rows at stride=16: 19 octaves × 9 = 171 vaddpd per 8-row batch.
423+
pub fn project_tensor_bf16_simd(
424+
buf: &[u16],
425+
n_rows: usize,
426+
n_cols: usize,
427+
octave_stride: usize,
428+
) -> Vec<Base17> {
429+
let mut result = Vec::with_capacity(n_rows);
430+
431+
let full_batches = n_rows / 8;
432+
433+
for batch in 0..full_batches {
434+
let base_row = batch * 8;
435+
let row_starts: [usize; 8] = [
436+
(base_row + 0) * n_cols, (base_row + 1) * n_cols,
437+
(base_row + 2) * n_cols, (base_row + 3) * n_cols,
438+
(base_row + 4) * n_cols, (base_row + 5) * n_cols,
439+
(base_row + 6) * n_cols, (base_row + 7) * n_cols,
440+
];
441+
let b17s = project_8rows_bf16_simd(buf, &row_starts, n_cols, octave_stride);
442+
result.extend_from_slice(&b17s);
443+
}
444+
445+
// Scalar tail
342446
for r in (full_batches * 8)..n_rows {
343447
let start = r * n_cols;
344448
let end = (start + n_cols).min(buf.len());
345-
result.push(project_row_bf16_strided(&buf[start..end], octave_stride));
449+
result.push(project_1row_bf16_strided(&buf[start..end], octave_stride));
346450
}
347451

348452
result
@@ -1147,6 +1251,57 @@ mod tests {
11471251
}
11481252
}
11491253

1254+
#[test]
1255+
fn test_simd_matches_scalar_constant() {
1256+
let n_cols = 5120;
1257+
let n_rows = 16; // 2 full SIMD batches
1258+
let buf: Vec<u16> = vec![0x3F80; n_rows * n_cols]; // all 1.0 in BF16
1259+
1260+
let simd_results = project_tensor_bf16_simd(&buf, n_rows, n_cols, 1);
1261+
assert_eq!(simd_results.len(), n_rows);
1262+
1263+
for r in 1..n_rows {
1264+
for bin in 0..BASE_DIM {
1265+
let diff = (simd_results[0].dims[bin] as i32 - simd_results[r].dims[bin] as i32).abs();
1266+
assert!(diff == 0, "row {} bin {} differs: {} vs {}",
1267+
r, bin, simd_results[0].dims[bin], simd_results[r].dims[bin]);
1268+
}
1269+
}
1270+
}
1271+
1272+
#[test]
1273+
fn test_simd_matches_scalar_strided() {
1274+
let n_cols = 13824;
1275+
let n_rows = 11; // 1 full batch + 3 remainder
1276+
let mut buf = vec![0x3F80u16; n_rows * n_cols];
1277+
for i in (0..buf.len()).step_by(2) {
1278+
buf[i] = 0xBF80; // -1.0
1279+
}
1280+
1281+
let simd_results = project_tensor_bf16_simd(&buf, n_rows, n_cols, 16);
1282+
assert_eq!(simd_results.len(), n_rows);
1283+
1284+
for r in 0..n_rows {
1285+
let start = r * n_cols;
1286+
let scalar = project_1row_bf16_strided(&buf[start..start + n_cols], 16);
1287+
for bin in 0..BASE_DIM {
1288+
let diff = (simd_results[r].dims[bin] as i32 - scalar.dims[bin] as i32).abs();
1289+
assert!(diff <= 1, "row {} bin {} simd={} scalar={} diff={}",
1290+
r, bin, simd_results[r].dims[bin], scalar.dims[bin], diff);
1291+
}
1292+
}
1293+
}
1294+
1295+
#[test]
1296+
fn test_simd_tail_handling() {
1297+
let n_cols = 256;
1298+
for n_rows in 1..8 {
1299+
let buf: Vec<u16> = vec![0x4000; n_rows * n_cols]; // 2.0 in BF16
1300+
let results = project_tensor_bf16_simd(&buf, n_rows, n_cols, 16);
1301+
assert_eq!(results.len(), n_rows, "wrong count for n_rows={}", n_rows);
1302+
}
1303+
}
1304+
11501305
#[test]
11511306
#[ignore] // Streams ~801 GB from HuggingFace
11521307
fn test_stream_index_llama4_maverick_bf16_all_shards() {

0 commit comments

Comments
 (0)