@@ -251,98 +251,202 @@ pub fn project_row_bf16_strided(row: &[u16], octave_stride: usize) -> Base17 {
251251 Base17 { dims }
252252}
253253
254- // ── SIMD 8-row-parallel tensor projection ──
254+ // ── F64x8 SIMD: 8 rows → 8 Base17 in parallel ──
255255
256- /// Project an entire BF16 tensor to Base17 using F64x8 SIMD .
256+ /// Gather 8 BF16 values from 8 rows at the same column, convert to F64x8 .
257257///
258- /// Processes 8 rows in parallel per SIMD batch. Each of the 9 halftone bins
259- /// holds an F64x8 accumulator (8 rows × 9 bins = 72 f64 lanes = 9 zmm registers).
258+ /// The gather is scalar (8 indexed loads) but the result is SIMD.
259+ /// At -O2 with AVX-512, rustc may emit vpgatherqd + shift + vcvtps2pd.
260+ #[ inline( always) ]
261+ fn gather_bf16_x8 ( buf : & [ u16 ] , offsets : & [ usize ; 8 ] ) -> crate :: simd:: F64x8 {
262+ crate :: simd:: F64x8 :: from_array ( [
263+ bf16_to_f64 ( buf[ offsets[ 0 ] ] ) ,
264+ bf16_to_f64 ( buf[ offsets[ 1 ] ] ) ,
265+ bf16_to_f64 ( buf[ offsets[ 2 ] ] ) ,
266+ bf16_to_f64 ( buf[ offsets[ 3 ] ] ) ,
267+ bf16_to_f64 ( buf[ offsets[ 4 ] ] ) ,
268+ bf16_to_f64 ( buf[ offsets[ 5 ] ] ) ,
269+ bf16_to_f64 ( buf[ offsets[ 6 ] ] ) ,
270+ bf16_to_f64 ( buf[ offsets[ 7 ] ] ) ,
271+ ] )
272+ }
273+
274+ /// Project 8 BF16 rows simultaneously to 8 Base17 patterns.
260275///
261- /// Per sampled octave: 9 halftone positions × 8 bf16_to_f64 gathers → 9 vaddpd.
262- /// For 5120-col rows at stride=16: 19 octaves × 9 = 171 vaddpd per 8-row batch.
263- pub fn project_tensor_bf16_simd (
276+ /// Memory: 17 × F64x8 accumulators on stack = 17 × 64 = 1088 bytes.
277+ pub fn project_8rows_bf16_simd (
264278 buf : & [ u16 ] ,
265- n_rows : usize ,
279+ row_starts : & [ usize ; 8 ] ,
266280 n_cols : usize ,
267281 octave_stride : usize ,
268- ) -> Vec < Base17 > {
282+ ) -> [ Base17 ; 8 ] {
269283 use crate :: simd:: F64x8 ;
270284
271285 let n_octaves = ( n_cols + BASE_DIM - 1 ) / BASE_DIM ;
272- let mut result = Vec :: with_capacity ( n_rows ) ;
286+ let use_halftone = octave_stride > 1 ;
273287
274- // Process 8 rows at a time
275- let full_batches = n_rows / 8 ;
276- let remainder = n_rows % 8 ;
277-
278- for batch in 0 ..full_batches {
279- let base_row = batch * 8 ;
280-
281- // 9 halftone bins × F64x8 accumulators (8 rows per lane)
282- let mut half_sum = [ F64x8 :: splat ( 0.0 ) ; 9 ] ;
283- let mut half_count = [ 0u32 ; 9 ] ; // same count for all 8 rows (same n_cols)
288+ let mut sums: [ F64x8 ; BASE_DIM ] = [ F64x8 :: splat ( 0.0 ) ; BASE_DIM ] ;
289+ let mut counts: [ u32 ; BASE_DIM ] = [ 0 ; BASE_DIM ] ;
284290
291+ if use_halftone {
285292 let mut octave = 0 ;
286293 while octave < n_octaves {
287294 for hi in 0 ..9 {
288- let dim = octave * BASE_DIM + HALFTONE_POS [ hi] as usize ;
289- if dim < n_cols {
290- // Gather 8 BF16 values (one per row) at column `dim`
291- let vals = F64x8 :: from_array ( [
292- bf16_to_f64 ( buf[ ( base_row + 0 ) * n_cols + dim] ) ,
293- bf16_to_f64 ( buf[ ( base_row + 1 ) * n_cols + dim] ) ,
294- bf16_to_f64 ( buf[ ( base_row + 2 ) * n_cols + dim] ) ,
295- bf16_to_f64 ( buf[ ( base_row + 3 ) * n_cols + dim] ) ,
296- bf16_to_f64 ( buf[ ( base_row + 4 ) * n_cols + dim] ) ,
297- bf16_to_f64 ( buf[ ( base_row + 5 ) * n_cols + dim] ) ,
298- bf16_to_f64 ( buf[ ( base_row + 6 ) * n_cols + dim] ) ,
299- bf16_to_f64 ( buf[ ( base_row + 7 ) * n_cols + dim] ) ,
300- ] ) ;
301- half_sum[ hi] = half_sum[ hi] + vals;
302- if batch == 0 || octave == 0 {
303- // Count is same for all batches with same n_cols
304- }
305- half_count[ hi] += 1 ;
295+ let col = octave * BASE_DIM + HALFTONE_POS [ hi] as usize ;
296+ if col < n_cols {
297+ let bin = HALFTONE_TO_BIN [ hi] as usize ;
298+ let offsets: [ usize ; 8 ] = [
299+ row_starts[ 0 ] + col, row_starts[ 1 ] + col,
300+ row_starts[ 2 ] + col, row_starts[ 3 ] + col,
301+ row_starts[ 4 ] + col, row_starts[ 5 ] + col,
302+ row_starts[ 6 ] + col, row_starts[ 7 ] + col,
303+ ] ;
304+ sums[ bin] += gather_bf16_x8 ( buf, & offsets) ;
305+ counts[ bin] += 1 ;
306306 }
307307 }
308308 octave += octave_stride;
309309 }
310310
311- // Finalize: convert 9 SIMD accumulators → 8 Base17 results
312- // Even bins: mean × FP_SCALE, clamped to i16
313- let mut even_dims = [ [ 0i16 ; BASE_DIM ] ; 8 ] ;
314-
315- for hi in 0 ..9 {
316- if half_count[ hi] > 0 {
317- let count_v = F64x8 :: splat ( half_count[ hi] as f64 ) ;
318- let scale_v = F64x8 :: splat ( FP_SCALE ) ;
319- let mean_v = half_sum[ hi] / count_v;
320- let scaled = mean_v * scale_v;
321- let arr = scaled. to_array ( ) ;
322- let bin = HALFTONE_TO_BIN [ hi] as usize ;
323- for lane in 0 ..8 {
324- even_dims[ lane] [ bin] =
325- arr[ lane] . round ( ) . clamp ( -32768.0 , 32767.0 ) as i16 ;
311+ // Interpolate odd bins from even neighbors (per-lane, still SIMD)
312+ for odd in ( 1 ..BASE_DIM ) . step_by ( 2 ) {
313+ let left = sums[ odd - 1 ] ;
314+ let right = sums[ ( odd + 1 ) % BASE_DIM ] ;
315+ let left_c = counts[ odd - 1 ] . max ( 1 ) ;
316+ let right_c = counts[ ( odd + 1 ) % BASE_DIM ] . max ( 1 ) ;
317+ let left_mean = left * F64x8 :: splat ( 1.0 / left_c as f64 ) ;
318+ let right_mean = right * F64x8 :: splat ( 1.0 / right_c as f64 ) ;
319+ sums[ odd] = ( left_mean + right_mean) * F64x8 :: splat ( 0.5 ) ;
320+ counts[ odd] = 1 ;
321+ }
322+ } else {
323+ for octave in 0 ..n_octaves {
324+ for bi in 0 ..BASE_DIM {
325+ let col = octave * BASE_DIM + GOLDEN_POS [ bi] as usize ;
326+ if col < n_cols {
327+ let offsets: [ usize ; 8 ] = [
328+ row_starts[ 0 ] + col, row_starts[ 1 ] + col,
329+ row_starts[ 2 ] + col, row_starts[ 3 ] + col,
330+ row_starts[ 4 ] + col, row_starts[ 5 ] + col,
331+ row_starts[ 6 ] + col, row_starts[ 7 ] + col,
332+ ] ;
333+ sums[ bi] += gather_bf16_x8 ( buf, & offsets) ;
334+ counts[ bi] += 1 ;
326335 }
327336 }
328337 }
338+ }
329339
330- // Odd bins: interpolate from neighbors
340+ // Finalize: mean → scale → clamp → i16, all 8 lanes parallel
341+ let lo = F64x8 :: splat ( -32768.0 ) ;
342+ let hi = F64x8 :: splat ( 32767.0 ) ;
343+
344+ let mut dims_x8: [ [ i16 ; BASE_DIM ] ; 8 ] = [ [ 0i16 ; BASE_DIM ] ; 8 ] ;
345+
346+ for bin in 0 ..BASE_DIM {
347+ let c = counts[ bin] . max ( 1 ) as f64 ;
348+ let scaled = sums[ bin] . mul_add (
349+ F64x8 :: splat ( FP_SCALE / c) ,
350+ F64x8 :: splat ( 0.0 ) ,
351+ ) ;
352+ let clamped = scaled. round ( ) . simd_clamp ( lo, hi) ;
353+ let vals = clamped. to_array ( ) ;
331354 for lane in 0 ..8 {
332- for odd in ( 1 ..BASE_DIM ) . step_by ( 2 ) {
333- let left = even_dims[ lane] [ odd - 1 ] as i32 ;
334- let right = even_dims[ lane] [ ( odd + 1 ) % BASE_DIM ] as i32 ;
335- even_dims[ lane] [ odd] = ( ( left + right) / 2 ) as i16 ;
355+ dims_x8[ lane] [ bin] = vals[ lane] as i16 ;
356+ }
357+ }
358+
359+ [
360+ Base17 { dims : dims_x8[ 0 ] } , Base17 { dims : dims_x8[ 1 ] } ,
361+ Base17 { dims : dims_x8[ 2 ] } , Base17 { dims : dims_x8[ 3 ] } ,
362+ Base17 { dims : dims_x8[ 4 ] } , Base17 { dims : dims_x8[ 5 ] } ,
363+ Base17 { dims : dims_x8[ 6 ] } , Base17 { dims : dims_x8[ 7 ] } ,
364+ ]
365+ }
366+
367+ /// Scalar fallback for remainder rows (< 8).
368+ pub fn project_1row_bf16_strided ( row : & [ u16 ] , octave_stride : usize ) -> Base17 {
369+ let d = row. len ( ) ;
370+ let n_octaves = ( d + BASE_DIM - 1 ) / BASE_DIM ;
371+ let use_halftone = octave_stride > 1 ;
372+
373+ let mut sum = [ 0.0f64 ; BASE_DIM ] ;
374+ let mut count = [ 0u32 ; BASE_DIM ] ;
375+
376+ if use_halftone {
377+ let mut octave = 0 ;
378+ while octave < n_octaves {
379+ for hi in 0 ..9 {
380+ let col = octave * BASE_DIM + HALFTONE_POS [ hi] as usize ;
381+ if col < d {
382+ sum[ HALFTONE_TO_BIN [ hi] as usize ] += bf16_to_f64 ( row[ col] ) ;
383+ count[ HALFTONE_TO_BIN [ hi] as usize ] += 1 ;
384+ }
385+ }
386+ octave += octave_stride;
387+ }
388+ for odd in ( 1 ..BASE_DIM ) . step_by ( 2 ) {
389+ let lc = count[ odd - 1 ] . max ( 1 ) as f64 ;
390+ let rc = count[ ( odd + 1 ) % BASE_DIM ] . max ( 1 ) as f64 ;
391+ sum[ odd] = ( sum[ odd - 1 ] / lc + sum[ ( odd + 1 ) % BASE_DIM ] / rc) * 0.5 ;
392+ count[ odd] = 1 ;
393+ }
394+ } else {
395+ for octave in 0 ..n_octaves {
396+ for bi in 0 ..BASE_DIM {
397+ let col = octave * BASE_DIM + GOLDEN_POS [ bi] as usize ;
398+ if col < d {
399+ sum[ bi] += bf16_to_f64 ( row[ col] ) ;
400+ count[ bi] += 1 ;
401+ }
336402 }
337- result. push ( Base17 { dims : even_dims[ lane] } ) ;
338403 }
339404 }
340405
341- // Scalar tail for remaining rows (< 8)
406+ let mut dims = [ 0i16 ; BASE_DIM ] ;
407+ for i in 0 ..BASE_DIM {
408+ if count[ i] > 0 {
409+ let mean = sum[ i] / count[ i] as f64 ;
410+ dims[ i] = ( mean * FP_SCALE ) . round ( ) . clamp ( -32768.0 , 32767.0 ) as i16 ;
411+ }
412+ }
413+ Base17 { dims }
414+ }
415+
416+ /// Project an entire BF16 tensor to Base17 using F64x8 SIMD.
417+ ///
418+ /// Processes 8 rows in parallel per SIMD batch. Each of the 9 halftone bins
419+ /// holds an F64x8 accumulator (8 rows × 9 bins = 72 f64 lanes = 9 zmm registers).
420+ ///
421+ /// Per sampled octave: 9 halftone positions × 8 bf16_to_f64 gathers → 9 vaddpd.
422+ /// For 5120-col rows at stride=16: 19 octaves × 9 = 171 vaddpd per 8-row batch.
423+ pub fn project_tensor_bf16_simd (
424+ buf : & [ u16 ] ,
425+ n_rows : usize ,
426+ n_cols : usize ,
427+ octave_stride : usize ,
428+ ) -> Vec < Base17 > {
429+ let mut result = Vec :: with_capacity ( n_rows) ;
430+
431+ let full_batches = n_rows / 8 ;
432+
433+ for batch in 0 ..full_batches {
434+ let base_row = batch * 8 ;
435+ let row_starts: [ usize ; 8 ] = [
436+ ( base_row + 0 ) * n_cols, ( base_row + 1 ) * n_cols,
437+ ( base_row + 2 ) * n_cols, ( base_row + 3 ) * n_cols,
438+ ( base_row + 4 ) * n_cols, ( base_row + 5 ) * n_cols,
439+ ( base_row + 6 ) * n_cols, ( base_row + 7 ) * n_cols,
440+ ] ;
441+ let b17s = project_8rows_bf16_simd ( buf, & row_starts, n_cols, octave_stride) ;
442+ result. extend_from_slice ( & b17s) ;
443+ }
444+
445+ // Scalar tail
342446 for r in ( full_batches * 8 ) ..n_rows {
343447 let start = r * n_cols;
344448 let end = ( start + n_cols) . min ( buf. len ( ) ) ;
345- result. push ( project_row_bf16_strided ( & buf[ start..end] , octave_stride) ) ;
449+ result. push ( project_1row_bf16_strided ( & buf[ start..end] , octave_stride) ) ;
346450 }
347451
348452 result
@@ -1147,6 +1251,57 @@ mod tests {
11471251 }
11481252 }
11491253
1254+ #[ test]
1255+ fn test_simd_matches_scalar_constant ( ) {
1256+ let n_cols = 5120 ;
1257+ let n_rows = 16 ; // 2 full SIMD batches
1258+ let buf: Vec < u16 > = vec ! [ 0x3F80 ; n_rows * n_cols] ; // all 1.0 in BF16
1259+
1260+ let simd_results = project_tensor_bf16_simd ( & buf, n_rows, n_cols, 1 ) ;
1261+ assert_eq ! ( simd_results. len( ) , n_rows) ;
1262+
1263+ for r in 1 ..n_rows {
1264+ for bin in 0 ..BASE_DIM {
1265+ let diff = ( simd_results[ 0 ] . dims [ bin] as i32 - simd_results[ r] . dims [ bin] as i32 ) . abs ( ) ;
1266+ assert ! ( diff == 0 , "row {} bin {} differs: {} vs {}" ,
1267+ r, bin, simd_results[ 0 ] . dims[ bin] , simd_results[ r] . dims[ bin] ) ;
1268+ }
1269+ }
1270+ }
1271+
1272+ #[ test]
1273+ fn test_simd_matches_scalar_strided ( ) {
1274+ let n_cols = 13824 ;
1275+ let n_rows = 11 ; // 1 full batch + 3 remainder
1276+ let mut buf = vec ! [ 0x3F80u16 ; n_rows * n_cols] ;
1277+ for i in ( 0 ..buf. len ( ) ) . step_by ( 2 ) {
1278+ buf[ i] = 0xBF80 ; // -1.0
1279+ }
1280+
1281+ let simd_results = project_tensor_bf16_simd ( & buf, n_rows, n_cols, 16 ) ;
1282+ assert_eq ! ( simd_results. len( ) , n_rows) ;
1283+
1284+ for r in 0 ..n_rows {
1285+ let start = r * n_cols;
1286+ let scalar = project_1row_bf16_strided ( & buf[ start..start + n_cols] , 16 ) ;
1287+ for bin in 0 ..BASE_DIM {
1288+ let diff = ( simd_results[ r] . dims [ bin] as i32 - scalar. dims [ bin] as i32 ) . abs ( ) ;
1289+ assert ! ( diff <= 1 , "row {} bin {} simd={} scalar={} diff={}" ,
1290+ r, bin, simd_results[ r] . dims[ bin] , scalar. dims[ bin] , diff) ;
1291+ }
1292+ }
1293+ }
1294+
1295+ #[ test]
1296+ fn test_simd_tail_handling ( ) {
1297+ let n_cols = 256 ;
1298+ for n_rows in 1 ..8 {
1299+ let buf: Vec < u16 > = vec ! [ 0x4000 ; n_rows * n_cols] ; // 2.0 in BF16
1300+ let results = project_tensor_bf16_simd ( & buf, n_rows, n_cols, 16 ) ;
1301+ assert_eq ! ( results. len( ) , n_rows, "wrong count for n_rows={}" , n_rows) ;
1302+ }
1303+ }
1304+
11501305 #[ test]
11511306 #[ ignore] // Streams ~801 GB from HuggingFace
11521307 fn test_stream_index_llama4_maverick_bf16_all_shards ( ) {
0 commit comments