Commit 333e96a

feat(hpc/amx_matmul): VDPBF16PS arm — AVX-512BF16 BF16 GEMM tier
Extends the BF16 GEMM dispatch chain from PR #180's per-tier table.
Until this commit, the dispatcher was two-tier: AMX TDPBF16PS (SPR,
GNR) → scalar bf16_gemm_f32 (everything else, including Cooper Lake
and Zen 4+, which have avx512bf16 hardware but no AMX).

Adds a middle tier using _mm512_dpbf16_ps (VDPBF16PS): one instruction
does 32 BF16×BF16 multiplies + 16 f32 accumulates, single-rounded. The
intrinsic is stable on Rust 1.95, so no asm-byte is needed (unlike
AMX, which is nightly-only per issue #126622 and must be raw-byte
encoded).

Three-tier dispatch in bf16_gemm_dispatch (renamed from
bf16_gemm_with_amx now that AMX isn't the only hw path):

1. amx_available() && 16/16/32-aligned shapes
   → bf16_tile_gemm_16x16 → TDPBF16PS via asm-byte
   (8 192 MACs/instr, highest throughput)
2. is_x86_feature_detected!("avx512bf16")
   → bf16_gemm_vdpbf16ps via the stable _mm512_dpbf16_ps intrinsic
   (32 MACs/instr, arbitrary shapes, K-tail handled scalar, N-tail
   handled by per-iteration j_count trim)
3. scalar bf16_gemm_f32 reference

Kernel pattern (slow-but-correct first cut):

* One VDPBF16PS produces 16 f32 accumulator lanes, mapped to 16
  columns of one output row, processing 2 K-elements per call.
* B columns for the current j-block of 16 are pre-packed into a
  pair-interleaved u32 layout once per j-block (B[2k_pair, j+jj] in
  the low 16 bits, B[2k_pair+1, j+jj] in the high 16 bits), then
  reused across all m i-iterations to amortize the column-gather
  cost.
* A row pair (A[i, 2k_pair], A[i, 2k_pair+1]) is broadcast across
  16 lanes via _mm512_set1_epi32 every K-iter, so every output column
  sees the same pair.
* After the K-pairs loop, the K-tail (k odd) is handled via one scalar
  BF16 multiply per output cell; the N-tail (j_count < 16) is handled
  by trimming the store width. The padding lanes still receive
  VDPBF16PS updates but aren't written back.

Performance shape (rough): the kernel is correctness-optimized, not
peak-throughput-optimized. A production GEMM with VDPBF16PS would
pre-pack B once per outer GEMM call (not per j-block iter) and tile
the M dim 16-wide via unrolled accumulators. That is Phase-4 work.
For Cooper Lake / Zen 4 today, this is still expected to beat the
scalar baseline by roughly 10× (an estimate, not yet measured; see
the open verifications below), because the inner k_pairs loop issues
one VDPBF16PS per 2 K-elements across 16 columns, versus the scalar
path's separate multiply+add per element.

Verification:

* Default v3 build: 11 amx_matmul tests pass (this host shows only
  avx512_vnni in /proc/cpuinfo, with no avx512bf16, so the new arm
  falls through to scalar; behaviour identical to pre-commit).
* cargo clippy --lib -D warnings clean.
* cargo fmt --all --check clean.
* Existing K-tail test (matmul_bf16_k_tail_16x65_65x16, k=65,
  k_pairs=32, k_tail=1) and strided test will exercise the new arm
  on Cooper Lake / Zen 4 silicon.

Open verifications (need real avx512bf16 silicon):

* Numerical parity vs scalar bf16_gemm_f32 across the test suite.
* Throughput vs scalar baseline.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
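To make the per-lane arithmetic and the pair-interleaved layout from the commit message concrete, here is a minimal portable sketch in plain Rust. The helper names (bf16_bits_to_f32, pack_pair, dpbf16_lane) are hypothetical and not part of the commit, and the model uses ordinary f32 arithmetic, so it illustrates the data flow only and is not claimed to reproduce the hardware's exact rounding or denormal handling.

```rust
/// Widen raw BF16 bits to f32. BF16 is the upper 16 bits of an
/// IEEE-754 binary32 value, so widening is a 16-bit left shift.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Pack two BF16 values into one u32 lane, mirroring the layout the
/// commit describes: even-K element in the low 16 bits, odd-K element
/// in the high 16 bits.
fn pack_pair(lo: u16, hi: u16) -> u32 {
    ((hi as u32) << 16) | (lo as u32)
}

/// Scalar model of one output lane of a VDPBF16PS-style update:
/// acc + a.lo * b.lo + a.hi * b.hi.
/// Assumption: plain f32 math, which may round differently from hardware.
fn dpbf16_lane(acc: f32, a_pair: u32, b_pair: u32) -> f32 {
    let a_lo = bf16_bits_to_f32(a_pair as u16);
    let a_hi = bf16_bits_to_f32((a_pair >> 16) as u16);
    let b_lo = bf16_bits_to_f32(b_pair as u16);
    let b_hi = bf16_bits_to_f32((b_pair >> 16) as u16);
    acc + a_lo * b_lo + a_hi * b_hi
}
```

In the kernel, the same A pair is broadcast to all 16 lanes while each lane holds a different B pair, so calling dpbf16_lane sixteen times with a fixed a_pair models one instruction.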
1 parent 9ed521c commit 333e96a

1 file changed

Lines changed: 129 additions & 16 deletions

src/hpc/amx_matmul.rs

@@ -317,29 +317,38 @@ pub fn matmul_bf16_to_f32(
     let b = pack_contig(&rhs);
     let mut c = vec![0.0f32; m * n];
 
-    bf16_gemm_with_amx(&a, &b, &mut c, m, n, k);
+    bf16_gemm_dispatch(&a, &b, &mut c, m, n, k);
 
     write_contig(&mut out, &c);
     Ok(())
 }
 
-/// BF16 × BF16 → f32 GEMM with AMX `TDPBF16PS` tile path when available.
+/// BF16 × BF16 → f32 GEMM with three-tier dispatch (AMX → VDPBF16PS → scalar).
 ///
 /// Inputs are packed row-major (`a` is M × K, `b` is K × N). Output `c`
 /// is M × N row-major and is overwritten (not accumulated).
 ///
-/// Aligned shapes (M, N multiples of 16 and K a multiple of 32) dispatch
-/// through the 16×16 tile kernel in
-/// [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] which emits
-/// `TDPBF16PS` via the asm-byte path in
-/// [`crate::simd_amx::tile_dpbf16ps`] — 8 192 BF16×BF16 multiply-
-/// accumulates per instruction (16×16×32 = 256 MAC outer-product
-/// matmul tile) into f32 accumulator registers, single-rounded.
+/// Tier selection:
 ///
-/// Mis-aligned shapes (or non-AMX hosts) fall back to the validated
-/// scalar [`bf16_gemm_f32`] reference. Phase-4 work will land mixed
-/// AMX tile + scalar tail dispatch for arbitrary shapes.
-fn bf16_gemm_with_amx(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) {
+/// 1. **AMX `TDPBF16PS`** (Sapphire Rapids+, Granite Rapids) when
+///    `amx_available()` is true AND shapes are 16/16/32-aligned.
+///    Dispatches through
+///    [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] →
+///    `simd_amx::tile_dpbf16ps` via asm-byte (`TDPBF16PS` intrinsic is
+///    nightly-only on Rust 1.95). 8 192 BF16×BF16 multiplies + 256 f32
+///    accumulates per instruction.
+/// 2. **`VDPBF16PS`** (AVX-512BF16: Cooper Lake, Zen 4+)
+///    when `is_x86_feature_detected!("avx512bf16")` is true. The
+///    intrinsic `_mm512_dpbf16_ps` is stable on Rust 1.95 (no asm-byte
+///    needed). Per instruction: 32 BF16×BF16 multiplies + 16 f32
+///    accumulates, single-rounded. Handles arbitrary shapes: the N
+///    tail is handled by per-iteration j-block trimming; the K-tail
+///    (odd K) is handled with a final scalar multiply per cell.
+/// 3. **Scalar reference** [`bf16_gemm_f32`] for hosts without either
+///    extension or for shapes the AMX arm rejects.
+///
+/// The per-tier dispatch table comes from PR #180's BF16 GEMM column.
+fn bf16_gemm_dispatch(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) {
     if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 32 == 0 {
         // SAFETY: BF16 is `#[repr(transparent)] struct BF16(pub u16)`
         // (per `hpc::quantized::BF16`). Reinterpreting `&[BF16]` as
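The tier-selection rules in the doc comment above can be restated as a pure function, which makes the fall-through order easy to unit-test on hosts without AMX or AVX-512BF16. This is a sketch, not code from the commit: the Tier enum and select_tier are hypothetical names, and the two booleans stand in for amx_available() and is_x86_feature_detected!("avx512bf16").

```rust
/// Hypothetical label for the three dispatch arms described above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Tier {
    Amx,
    Vdpbf16ps,
    Scalar,
}

/// Pure restatement of the dispatch order: AMX only for 16/16/32-aligned
/// shapes, then VDPBF16PS for any shape, then the scalar reference.
/// `has_amx` and `has_avx512bf16` are assumed to come from runtime
/// feature detection performed by the caller.
fn select_tier(has_amx: bool, has_avx512bf16: bool, m: usize, n: usize, k: usize) -> Tier {
    if has_amx && m % 16 == 0 && n % 16 == 0 && k % 32 == 0 {
        Tier::Amx
    } else if has_avx512bf16 {
        Tier::Vdpbf16ps
    } else {
        Tier::Scalar
    }
}
```

Note that an AMX host with a mis-aligned shape (for example k = 65) lands on the VDPBF16PS arm only if it also reports avx512bf16; otherwise it falls through to scalar.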
@@ -373,8 +382,112 @@ fn bf16_gemm_with_amx(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize,
                 }
             }
         }
-    } else {
-        bf16_gemm_f32(a, b, c, m, n, k, 1.0, 0.0);
+        return;
+    }
+
+    #[cfg(target_arch = "x86_64")]
+    {
+        if std::is_x86_feature_detected!("avx512bf16") {
+            // SAFETY: feature-detected at runtime; the kernel is
+            // `#[target_feature(enable = "avx512bf16,avx512f")]`.
+            unsafe {
+                bf16_gemm_vdpbf16ps(a, b, c, m, n, k);
+            }
+            return;
+        }
+    }
+
+    bf16_gemm_f32(a, b, c, m, n, k, 1.0, 0.0);
+}
+
+/// AVX-512BF16 BF16 GEMM using `_mm512_dpbf16_ps` (`VDPBF16PS`).
+///
+/// One VDPBF16PS instruction: 16 f32 accumulator lanes each receive
+/// `acc[j] += a.bf16[2j] * b.bf16[2j] + a.bf16[2j+1] * b.bf16[2j+1]`,
+/// single-rounded. The kernel maps the 16 output lanes to a row of 16
+/// j-columns of C[i, ·], with one i row processed at a time and a K-pair
+/// inner loop accumulating into the same 16 f32 lanes across iterations.
+///
+/// B-column packing: VDPBF16PS wants the 32 B BF16s per call laid out
+/// as 16 lane-pairs (lane j contains `B[2k_pair, j_base+j]` followed by
+/// `B[2k_pair+1, j_base+j]`, packed into one u32). We pre-pack B for
+/// the current j-block into `b_col_pairs[k_pair * 16 + j] = u32` once
+/// per j_block and reuse across all i — amortizes the gather cost.
+///
+/// K-tail (when K is odd) is handled with a final scalar BF16 multiply
+/// per output cell; N-tail (when the j-block has < 16 valid columns)
+/// is handled by trimming the store after the VDPBF16PS chain.
+///
+/// # Safety
+/// Caller must have feature-detected `avx512bf16` at runtime.
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx512bf16,avx512f")]
+unsafe fn bf16_gemm_vdpbf16ps(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) {
+    use core::arch::x86_64::{
+        __m512bh, __m512i, _mm512_dpbf16_ps, _mm512_loadu_si512, _mm512_set1_epi32, _mm512_setzero_ps, _mm512_storeu_ps,
+    };
+
+    let k_pairs = k / 2;
+    let k_tail = k % 2;
+
+    // SAFETY: BF16 is repr(transparent) over u16.
+    let a_u16: &[u16] = core::slice::from_raw_parts(a.as_ptr() as *const u16, a.len());
+    let b_u16: &[u16] = core::slice::from_raw_parts(b.as_ptr() as *const u16, b.len());
+
+    // Pre-pack scratch: 16 u32 lanes per k_pair, holding (b_lo | b_hi << 16).
+    let mut b_col_pairs = vec![0u32; k_pairs.max(1) * 16];
+    // Scratch for the 16-wide store + N-tail trim.
+    let mut out_buf = [0.0f32; 16];
+
+    for j_base in (0..n).step_by(16) {
+        let j_count = 16.min(n - j_base);
+
+        // Pack B columns [j_base..j_base+j_count] in pair-interleaved layout.
+        // For lanes j >= j_count (the N-tail of this j_block), pad with 0 —
+        // they're not stored back, but the VDPBF16PS still touches them.
+        for k_pair in 0..k_pairs {
+            let row_lo = 2 * k_pair * n;
+            let row_hi = (2 * k_pair + 1) * n;
+            for jj in 0..j_count {
+                let b_lo = b_u16[row_lo + j_base + jj] as u32;
+                let b_hi = b_u16[row_hi + j_base + jj] as u32;
+                b_col_pairs[k_pair * 16 + jj] = (b_hi << 16) | b_lo;
+            }
+            for jj in j_count..16 {
+                b_col_pairs[k_pair * 16 + jj] = 0;
+            }
+        }
+
+        for i in 0..m {
+            let mut acc = _mm512_setzero_ps();
+            let a_row_off = i * k;
+            for k_pair in 0..k_pairs {
+                // Broadcast A[i, 2k_pair..2k_pair+2] as the (BF16 lo, BF16 hi)
+                // pair across all 16 lanes.
+                let a_lo = a_u16[a_row_off + 2 * k_pair] as u32;
+                let a_hi = a_u16[a_row_off + 2 * k_pair + 1] as u32;
+                let pair = (a_hi << 16) | a_lo;
+                let a_bh: __m512bh = core::mem::transmute(_mm512_set1_epi32(pair as i32));
+                let b_bh: __m512bh =
+                    core::mem::transmute(_mm512_loadu_si512(b_col_pairs.as_ptr().add(k_pair * 16) as *const __m512i));
+                acc = _mm512_dpbf16_ps(acc, a_bh, b_bh);
+            }
+            _mm512_storeu_ps(out_buf.as_mut_ptr(), acc);
+
+            // K-tail: one extra scalar BF16 multiply for k = k_pairs*2.
+            if k_tail == 1 {
+                let a_last_f32 = BF16(a_u16[a_row_off + k - 1]).to_f32();
+                let tail_row = (k - 1) * n;
+                for jj in 0..j_count {
+                    let b_last_f32 = BF16(b_u16[tail_row + j_base + jj]).to_f32();
+                    out_buf[jj] += a_last_f32 * b_last_f32;
+                }
+            }
+
+            // Store the j_count valid lanes (drops N-tail padding lanes).
+            let dst_off = i * n + j_base;
+            c[dst_off..dst_off + j_count].copy_from_slice(&out_buf[..j_count]);
+        }
     }
 }
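Since the commit lists numerical parity as an open verification, one way to check the kernel's indexing (pair packing, K-tail, N-tail trim) without avx512bf16 silicon is a portable model of the same loop structure compared against a naive triple loop. The sketch below is an assumption-laden stand-in, not the commit's code: gemm_pairwise_model and gemm_naive are hypothetical names, inputs are raw BF16 bits as u16, and the 16 lanes are emulated with ordinary f32 math, so it validates indexing rather than hardware rounding.

```rust
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Naive reference: C[i, j] = sum over p of A[i, p] * B[p, j].
fn gemm_naive(a: &[u16], b: &[u16], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0f32;
            for p in 0..k {
                sum += bf16_to_f32(a[i * k + p]) * bf16_to_f32(b[p * n + j]);
            }
            c[i * n + j] = sum;
        }
    }
    c
}

/// Portable model of the kernel's loop structure: 16-column j-blocks,
/// pair-interleaved B packing, scalar K-tail, trimmed N-tail store.
fn gemm_pairwise_model(a: &[u16], b: &[u16], m: usize, n: usize, k: usize) -> Vec<f32> {
    let k_pairs = k / 2;
    let mut c = vec![0.0f32; m * n];
    let mut packed = vec![0u32; k_pairs.max(1) * 16];
    for j_base in (0..n).step_by(16) {
        let j_count = 16.min(n - j_base);
        // Pack (B[2kp, j], B[2kp+1, j]) into one u32 per lane; pad with 0.
        for kp in 0..k_pairs {
            for jj in 0..16 {
                packed[kp * 16 + jj] = if jj < j_count {
                    let lo = b[2 * kp * n + j_base + jj] as u32;
                    let hi = b[(2 * kp + 1) * n + j_base + jj] as u32;
                    (hi << 16) | lo
                } else {
                    0
                };
            }
        }
        for i in 0..m {
            let mut acc = [0.0f32; 16];
            for kp in 0..k_pairs {
                // The same A pair is applied to every lane (the broadcast).
                let a_lo = bf16_to_f32(a[i * k + 2 * kp]);
                let a_hi = bf16_to_f32(a[i * k + 2 * kp + 1]);
                for jj in 0..16 {
                    let w = packed[kp * 16 + jj];
                    acc[jj] += a_lo * bf16_to_f32(w as u16) + a_hi * bf16_to_f32((w >> 16) as u16);
                }
            }
            if k % 2 == 1 {
                // K-tail: one scalar multiply per valid output cell.
                let a_last = bf16_to_f32(a[i * k + k - 1]);
                for jj in 0..j_count {
                    acc[jj] += a_last * bf16_to_f32(b[(k - 1) * n + j_base + jj]);
                }
            }
            // N-tail: only the j_count valid lanes are written back.
            let dst = i * n + j_base;
            c[dst..dst + j_count].copy_from_slice(&acc[..j_count]);
        }
    }
    c
}
```

With small integer inputs every product and partial sum is exactly representable, so the two functions should agree bit for bit; with general inputs the different summation order can produce small differences, and a tolerance-based comparison would be the safer choice.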

@@ -403,7 +516,7 @@ pub fn matmul_f32(
         // shapes and the scalar `bf16_gemm_f32` reference otherwise.
         let a_bf16: Vec<BF16> = a_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect();
         let b_bf16: Vec<BF16> = b_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect();
-        bf16_gemm_with_amx(&a_bf16, &b_bf16, &mut c, m, n, k);
+        bf16_gemm_dispatch(&a_bf16, &b_bf16, &mut c, m, n, k);
     } else {
         // Pure f32 reference path.
         for i in 0..m {
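The matmul_f32 hunk above converts its f32 inputs with BF16::from_f32_rounded, whose implementation is not shown in this diff. For readers unfamiliar with the conversion, here is a minimal sketch of a common approach, round-to-nearest-even on the raw bits. It is an assumption that the crate does something equivalent; f32_to_bf16_rne and bf16_bits_to_f32 are hypothetical names used only for illustration.

```rust
/// Widen raw BF16 bits back to f32 (exact: BF16 is the top half of f32).
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// f32 -> BF16 bits with round-to-nearest, ties-to-even.
/// Assumption: this mirrors what a "rounded" conversion usually means;
/// the crate's actual `BF16::from_f32_rounded` may differ in details.
fn f32_to_bf16_rne(v: f32) -> u16 {
    let bits = v.to_bits();
    if v.is_nan() {
        // Keep the sign and exponent, and force a quiet-NaN mantissa bit
        // so truncation cannot turn the NaN into an infinity.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit: exact halves round to the
    // neighbour whose lowest bit is 0, everything else rounds to nearest.
    let lsb = (bits >> 16) & 1;
    let rounded = bits.wrapping_add(0x7FFF + lsb);
    (rounded >> 16) as u16
}
```

Compared with plain truncation (bits >> 16), rounding halves the worst-case conversion error, which matters when the converted operands feed long K-dimension accumulations.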
