Commit 0eaa3ac

feat(hpc/amx_matmul): TD-T1b — matmul_f32 AMX arm routes through tile kernel
Follow-up to TD-T1 (fe334de). Before this commit, `matmul_f32`'s AMX branch was the same kind of placebo that `matmul_bf16_to_f32` was before TD-T1: it down-cast f32 → BF16, then called the scalar `bf16_gemm_f32` reference, never reaching `TDPBF16PS` even on real AMX silicon.

Factored the BF16 AMX-tile dispatch logic out of `matmul_bf16_to_f32` into a private `bf16_gemm_with_amx(a, b, c, m, n, k)` helper. Both public entry points now route through it:

    matmul_bf16_to_f32 → bf16_gemm_with_amx (direct BF16 inputs)
    matmul_f32 → RNE down-cast → bf16_gemm_with_amx (f32 in, BF16 compute, f32 accumulator out)

The helper's behaviour is unchanged from what TD-T1 shipped: 16/16/32-aligned shapes hit `bf16_tile_gemm_16x16` (TDPBF16PS via asm-byte, 8 192 BF16×BF16 multiplies + 256 f32 accumulates per instruction); mis-aligned shapes or non-AMX hosts fall back to scalar `bf16_gemm_f32`. Single source of truth: future Phase-4 mixed-tile-plus-tail dispatch only needs to land in one place.

Verification:
* 11 amx_matmul tests pass (default v3, no AMX on this host → scalar fallback exercised; behaviour identical to pre-commit).
* cargo clippy --lib -D warnings clean.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
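The routing rule described in the message can be sketched in isolation. This is only an illustration: `GemmPath`, `tile_aligned`, `choose_path`, and `tile_op_counts` are hypothetical names that do not appear in the commit, and the AMX check is passed in as a plain `bool` rather than probing the CPU.

```rust
/// Which kernel a GEMM shape would be routed to (illustrative only).
#[derive(Debug, PartialEq, Eq)]
enum GemmPath {
    AmxTile,
    ScalarFallback,
}

/// Hypothetical mirror of the alignment guard in the diff:
/// M and N must be multiples of 16, K a multiple of 32.
fn tile_aligned(m: usize, n: usize, k: usize) -> bool {
    m % 16 == 0 && n % 16 == 0 && k % 32 == 0
}

/// Pick the tile kernel only when AMX is present AND the shape is aligned;
/// everything else takes the scalar reference path.
fn choose_path(amx_present: bool, m: usize, n: usize, k: usize) -> GemmPath {
    if amx_present && tile_aligned(m, n, k) {
        GemmPath::AmxTile
    } else {
        GemmPath::ScalarFallback
    }
}

/// (BF16 multiplies, f32 accumulator elements) covered by one
/// 16×16×32 tile operation.
fn tile_op_counts() -> (usize, usize) {
    (16 * 16 * 32, 16 * 16)
}
```

Under these assumptions, `choose_path(true, 64, 64, 128)` selects the tile path, while `choose_path(true, 64, 64, 100)` falls back because K is not a multiple of 32.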
1 parent fe334de commit 0eaa3ac

1 file changed

Lines changed: 29 additions & 10 deletions

src/hpc/amx_matmul.rs

@@ -317,10 +317,29 @@ pub fn matmul_bf16_to_f32(
     let b = pack_contig(&rhs);
     let mut c = vec![0.0f32; m * n];
 
-    // AMX TDPBF16PS tile path: requires m, n multiples of 16 and k a
-    // multiple of 32 (the tile shape `bf16_tile_gemm_16x16` enforces).
-    // For mis-aligned shapes fall back to scalar — Phase-4 work will
-    // add mixed-tile / tail handling.
+    bf16_gemm_with_amx(&a, &b, &mut c, m, n, k);
+
+    write_contig(&mut out, &c);
+    Ok(())
+}
+
+/// BF16 × BF16 → f32 GEMM with AMX `TDPBF16PS` tile path when available.
+///
+/// Inputs are packed row-major (`a` is M × K, `b` is K × N). Output `c`
+/// is M × N row-major and is overwritten (not accumulated).
+///
+/// Aligned shapes (M, N multiples of 16 and K a multiple of 32) dispatch
+/// through the 16×16 tile kernel in
+/// [`crate::hpc::bf16_tile_gemm::bf16_tile_gemm_16x16`] which emits
+/// `TDPBF16PS` via the asm-byte path in
+/// [`crate::simd_amx::tile_dpbf16ps`]: 16×16×32 = 8 192 BF16×BF16
+/// multiply-accumulates per instruction into a 16×16 = 256-element
+/// f32 accumulator tile, single-rounded.
+///
+/// Mis-aligned shapes (or non-AMX hosts) fall back to the validated
+/// scalar [`bf16_gemm_f32`] reference. Phase-4 work will land mixed
+/// AMX tile + scalar tail dispatch for arbitrary shapes.
+fn bf16_gemm_with_amx(a: &[BF16], b: &[BF16], c: &mut [f32], m: usize, n: usize, k: usize) {
     if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 32 == 0 {
         // SAFETY: BF16 is `#[repr(transparent)] struct BF16(pub u16)`
         // (per `hpc::quantized::BF16`). Reinterpreting `&[BF16]` as
@@ -355,11 +374,8 @@ pub fn matmul_bf16_to_f32(
             }
         }
     } else {
-        bf16_gemm_f32(&a, &b, &mut c, m, n, k, 1.0, 0.0);
+        bf16_gemm_f32(a, b, c, m, n, k, 1.0, 0.0);
     }
-
-    write_contig(&mut out, &c);
-    Ok(())
 }
 
 // ── f32 → f32 (BF16 compute on AMX) ────────────────────────────────────────
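For readers without the rest of the file, the scalar fallback can be pictured roughly as follows. This is a sketch, not the crate's `bf16_gemm_f32`: the `Bf16` newtype and `scalar_bf16_gemm` are stand-ins, and reading the trailing `1.0, 0.0` arguments as BLAS-style `alpha` and `beta` is an assumption (it is consistent with the doc comment saying `c` is overwritten rather than accumulated).

```rust
/// Stand-in for the crate's BF16 type: the upper 16 bits of an f32.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(transparent)]
struct Bf16(u16);

impl Bf16 {
    /// Widening BF16 → f32 is exact: shift the bits into the high half.
    fn to_f32(self) -> f32 {
        f32::from_bits((self.0 as u32) << 16)
    }
}

/// Row-major reference GEMM: c = alpha * (a × b) + beta * c,
/// with BF16 inputs and f32 accumulation.
/// `alpha`/`beta` semantics are ASSUMED, not taken from the commit.
fn scalar_bf16_gemm(
    a: &[Bf16], b: &[Bf16], c: &mut [f32],
    m: usize, n: usize, k: usize,
    alpha: f32, beta: f32,
) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(c.len(), m * n);
    for i in 0..m {
        for j in 0..n {
            // Accumulate the dot product of row i of `a` and column j of `b` in f32.
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p].to_f32() * b[p * n + j].to_f32();
            }
            let idx = i * n + j;
            // With beta == 0 the old contents of `c` are ignored entirely,
            // so stale NaNs in the output buffer cannot leak through.
            c[idx] = if beta == 0.0 { alpha * acc } else { alpha * acc + beta * c[idx] };
        }
    }
}
```

Calling it with `alpha = 1.0, beta = 0.0` gives plain overwrite semantics, which is how the helper in this diff invokes the real reference kernel.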
@@ -381,10 +397,13 @@ pub fn matmul_f32(
     let mut c = vec![0.0f32; m * n];
 
     if amx_available() {
-        // AMX path: down-cast to BF16, run BF16 GEMM, accumulate in f32.
+        // AMX path: down-cast to BF16 (RNE, ~1 ULP at BF16 mantissa
+        // precision), then dispatch through the shared BF16 helper
+        // which picks `TDPBF16PS` tile kernel for 16/16/32-aligned
+        // shapes and the scalar `bf16_gemm_f32` reference otherwise.
         let a_bf16: Vec<BF16> = a_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect();
         let b_bf16: Vec<BF16> = b_f32.iter().map(|&v| BF16::from_f32_rounded(v)).collect();
-        bf16_gemm_f32(&a_bf16, &b_bf16, &mut c, m, n, k, 1.0, 0.0);
+        bf16_gemm_with_amx(&a_bf16, &b_bf16, &mut c, m, n, k);
     } else {
         // Pure f32 reference path.
         for i in 0..m {
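The RNE (round-to-nearest-even) down-cast that `matmul_f32` applies before dispatch can be illustrated with a common bit-level trick. The body of `BF16::from_f32_rounded` is not shown in this diff, so the functions below (`f32_to_bf16_rne`, `bf16_to_f32`) are hypothetical and may differ from the crate's implementation, for example in how NaN payloads are handled.

```rust
/// Round an f32 to the nearest BF16 bit pattern, ties to even.
/// Hypothetical helper; not the crate's `BF16::from_f32_rounded`.
fn f32_to_bf16_rne(v: f32) -> u16 {
    let bits = v.to_bits();
    if v.is_nan() {
        // Truncate, then force a mantissa bit so the result stays a NaN.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Adding 0x7FFF plus the lowest kept bit rounds halfway cases
    // toward the value whose last kept bit is 0 (ties to even).
    let round = 0x7FFF + ((bits >> 16) & 1);
    (bits.wrapping_add(round) >> 16) as u16
}

/// Widen a BF16 bit pattern back to f32 (exact).
fn bf16_to_f32(h: u16) -> f32 {
    f32::from_bits((h as u32) << 16)
}
```

Because BF16 keeps 7 explicit mantissa bits, rounding this way bounds the relative error of a normal value by 2^-8, which is the precision loss the AMX path accepts in exchange for the tile kernel.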
