Skip to content

Commit b1979d7

Browse files
committed
feat(hpc): TD-T2 — AMX TDPBUSD tile kernel + matmul_i8_to_i32 wiring
Mirror of the BF16 AMX work (TD-T1 / TD-T1b in PR #182) for the integer operand family. Builds the missing int8 tile kernel from scratch (the BF16 equivalent shipped in PR #104; the int8 one had never been built despite the primitives existing in simd_amx since day one) and wires matmul_i8_to_i32's AMX arm through it. New module `hpc::int8_tile_gemm`: * `int8_tile_gemm_16x16(a_u8, b_i8, c, k)` — public tile kernel, K must be multiple of 64. Mirror shape of `bf16_tile_gemm_16x16` but for the `u8 × i8 → i32` operand family that TDPBUSD natively supports. **One TDPBUSD = 16 384 multiply-accumulates per instruction** (16×16 output tile × 64 K-elements per A row × 4 K-elements per inner-product). That's 256× the VPDPBUSD-zmm throughput per instruction. * Internal `amx_path()` uses the existing primitives in `amx_matmul`: TileConfig::for_dpbusd(64) → tile_loadconfig → tile_zero → K/64 iterations of (tile_load A, tile_load B, tile_dpbusd) → tile_store → tile_release. * `fallback_path()` for non-AMX hosts: scalar u8 × i8 → i32 triple-loop reference. New primitive `amx_matmul::vnni_pack_i8(src, dst, k, n)`: * Packs K × N row-major i8 into K/4 outer rows × (N*4) VNNI quad layout required by TDPBUSD tile 2. * `dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]` * Sibling of `vnni_pack_bf16` (which uses K/2 × (N*2) pair layout for TDPBF16PS — both kernels reach the same 64-byte tile row width via element-width × pack-factor symmetry: BF16 is 2B × 2, INT8 is 1B × 4). Wiring `matmul_i8_to_i32`'s AMX arm (was placebo): Pre-commit the AMX branch shifted i8 → u8 then called the SCALAR `int8_gemm_i32` reference and subtracted the bias — TDPBUSD itself was never reached even on real AMX silicon. Now: 1. Shift A: i8 → u8 via (+128). 2. Tile-loop over M/16 i_tile × N/16 j_tile blocks, calling int8_tile_gemm_16x16 per (i_tile, j_tile). B sub-block extracted into K × 16 scratch once per j_tile, reused across i_tile iterations. 3. Subtract bias: c[i, j] -= 128 × colsum(B[:, j]). The shape requirement is m%16 == 0 && n%16 == 0 && k%64 == 0; misaligned shapes fall back to the scalar reference. Phase-4 work will land mixed AMX-tile + per-axis scalar tail handling for arbitrary shapes (same shape of Phase-4 work TD-T1 deferred). Verification: * Default v3 build: 2092 lib tests pass (was 2087 — adds 5 new tests: 4 in int8_tile_gemm + the existing matmul_i8_to_i32 test now exercises the actual TDPBUSD path because this host has amx_int8 + amx_tile in /proc/cpuinfo; the test continues to pass with bit-identical results to the scalar reference). * `vnni_pack_i8_roundtrip` test verifies the pack layout matches the spec exactly for an 8 × 4 sample. * `fallback_matches_scalar_reference_k64` test verifies the non-AMX path produces the same i32 output as a hand-written reference for a 64-K, pseudo-random u8/i8 matrix pair. * `public_api_diagonal_k128` test asserts a structured pattern (A = identity-like, B = constant 2) gives the expected accumulation through the full dispatch chain. * `cargo clippy --lib -D warnings` clean. * `cargo fmt --all --check` clean. Dropped: `int8_gemm_i32` import in `amx_matmul.rs` since the AMX arm no longer falls back to it (the scalar else-branch uses an inline triple-loop directly). After this commit, the per-CPU dispatch table from PR #180 has the AMX tier wired for BOTH operand families on Sapphire Rapids+: BF16 GEMM: SPR+ → TDPBF16PS (TD-T1 / TD-T1b in PR #182) INT8 GEMM: SPR+ → TDPBUSD (this commit) Out of scope (separate PRs): * VPDPBUSD-zmm arm of matmul_i8_to_i32 for Cooper Lake / Cascade Lake / Zen 4+ (avx512vnni without AMX). The kernel function `vnni_dot_u8_i8` and `vnni_matvec` exist in simd_amx.rs; just need to assemble them into a m×n×k GEMM and wire as the middle dispatch tier (analogous to the VDPBF16PS arm in PR #182's bf16_gemm_dispatch). * AMX tile path for `simd_int_ops::gemm_u8_i8` (the slice-level surface from PR #182) — it's u8 × i8 natively so no sign-shift needed, simpler to wire than matmul_i8_to_i32. https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
1 parent 098c5aa commit b1979d7

3 files changed

Lines changed: 288 additions & 14 deletions

File tree

src/hpc/amx_matmul.rs

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,32 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
193193
}
194194
}
195195

196+
/// Pack B[K, N] i8 row-major into K/4 × (N*4) VNNI quads for `TDPBUSD`.
197+
///
198+
/// Output layout required by `TDPBUSD` tile 2 (16 rows × 64 bytes):
199+
/// dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]
200+
///
201+
/// For N=16 (AMX tile width), each output "row" holds 16 i8 quads = 64
202+
/// bytes (matches the 64-byte tile row width). K must be a multiple of
203+
/// 4. The same layout is used for `u8` operands (just bit-cast through
204+
/// — VNNI doesn't care about sign at the packing layer; sign
205+
/// interpretation happens inside TDPBUSD which treats A as u8 and B
206+
/// as i8 for the multiply).
207+
#[inline]
208+
pub fn vnni_pack_i8(src: &[i8], dst: &mut [i8], k: usize, n: usize) {
209+
debug_assert_eq!(src.len(), k * n);
210+
debug_assert_eq!(dst.len(), k * n);
211+
debug_assert_eq!(k % 4, 0, "K must be multiple of 4 for VNNI INT8 quads");
212+
for kb in 0..(k / 4) {
213+
let dst_row = kb * n * 4;
214+
for j in 0..n {
215+
for p in 0..4 {
216+
dst[dst_row + j * 4 + p] = src[(4 * kb + p) * n + j];
217+
}
218+
}
219+
}
220+
}
221+
196222
// ═══════════════════════════════════════════════════════════════════════════
197223
// Public ndarray-typed matmul API (sprint A4 / Burn parity item 6)
198224
// ═══════════════════════════════════════════════════════════════════════════
@@ -207,7 +233,7 @@ pub fn vnni_pack_bf16(src: &[u16], dst: &mut [u16], k: usize, n: usize) {
207233
// strided (e.g. `view.slice(s![.., ..;2])`). Strided inputs are repacked
208234
// into contiguous staging buffers before the kernel runs.
209235

210-
use crate::hpc::quantized::{bf16_gemm_f32, int8_gemm_i32, BF16};
236+
use crate::hpc::quantized::{bf16_gemm_f32, BF16};
211237
use crate::{ArrayView2, ArrayViewMut2};
212238

213239
/// Errors returned by the public AMX matmul API.
@@ -537,14 +563,17 @@ pub fn matmul_f32(
537563

538564
/// Matrix multiply i8 × i8 → i32: `out = lhs · rhs`.
539565
///
540-
/// On AMX hosts uses `TDPBUSD` (256 MACs/instr); otherwise falls back to
541-
/// the scalar `int8_gemm_i32`.
566+
/// On AMX hosts with 16/16/64-aligned shapes uses `TDPBUSD` via the
567+
/// 16×16 tile kernel in [`crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16`]
568+
/// — 16 384 MACs per instruction. Mis-aligned shapes (or non-AMX hosts)
569+
/// fall back to the scalar i8×i8 → i32 reference.
542570
///
543-
/// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For the
544-
/// signed-by-signed surface required here, the LHS is shifted into the
545-
/// unsigned domain and the bias subtracted from the accumulator (only on
546-
/// the AMX path; the scalar path operates directly in i8). The public
547-
/// result is identical.
571+
/// Note: `TDPBUSD` natively expects unsigned-by-signed (u8 × i8). For
572+
/// the signed-by-signed surface required here, the LHS is shifted into
573+
/// the unsigned domain (i8 + 128 → u8) and the bias `128 · sum(B[:, j]
574+
/// over k)` is subtracted from the accumulator. The public result is
575+
/// bit-identical to the scalar reference because all arithmetic stays
576+
/// in i32 (no float rounding).
548577
///
549578
/// `out` must be row-contiguous; inputs may be strided.
550579
pub fn matmul_i8_to_i32(
@@ -556,13 +585,39 @@ pub fn matmul_i8_to_i32(
556585
let b_i8 = pack_contig(&rhs);
557586
let mut c = vec![0i32; m * n];
558587

559-
if amx_available() {
560-
// AMX TDPBUSD path: shift LHS i8 → u8 via (+128) and subtract the
561-
// bias 128·sum(B[:, j] over k) afterwards. This keeps numerics exact.
588+
if amx_available() && m % 16 == 0 && n % 16 == 0 && k % 64 == 0 {
589+
// AMX TDPBUSD path: shift LHS i8 → u8 via (+128), tile-GEMM into
590+
// i32, subtract bias 128·colsum(B). The tile kernel zeroes its
591+
// internal accumulator (TILEZERO + TDPBUSD accumulate); we need
592+
// fresh per-tile output here so we tile manually over M/N and
593+
// call int8_tile_gemm_16x16 per (i, j) block.
562594
let a_u8: Vec<u8> = a_i8.iter().map(|&v| (v as i32 + 128) as u8).collect();
563595

564-
// Compute C' = A_u8 · B_i8 in i32, then subtract 128 · colsum(B).
565-
int8_gemm_i32(&a_u8, &b_i8, &mut c, m, n, k);
596+
// B sub-block extraction per j-tile (B is row-major K × N; the
597+
// tile kernel wants K × 16 contiguous). Reused across i-tiles.
598+
let mut b_tile = vec![0i8; k * 16];
599+
let mut tile_c = vec![0i32; 256];
600+
601+
for j_tile in (0..n).step_by(16) {
602+
// Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows.
603+
for kk in 0..k {
604+
let row = kk * n + j_tile;
605+
b_tile[kk * 16..(kk + 1) * 16]
606+
.copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) });
607+
}
608+
for i_tile in (0..m).step_by(16) {
609+
let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k];
610+
tile_c.fill(0);
611+
crate::hpc::int8_tile_gemm::int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k);
612+
// Write tile_c (16 × 16) into c at (i_tile, j_tile).
613+
for ii in 0..16 {
614+
let dst_off = (i_tile + ii) * n + j_tile;
615+
c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]);
616+
}
617+
}
618+
}
619+
620+
// Subtract bias: c[i, j] -= 128 · colsum(B[:, j]).
566621
let mut colsum = vec![0i32; n];
567622
for p in 0..k {
568623
for j in 0..n {
@@ -575,7 +630,8 @@ pub fn matmul_i8_to_i32(
575630
}
576631
}
577632
} else {
578-
// Scalar i8×i8 → i32 reference.
633+
// Scalar i8×i8 → i32 reference — used for non-AMX hosts and for
634+
// shapes that don't fit the 16/16/64 tile alignment.
579635
for i in 0..m {
580636
for p in 0..k {
581637
let av = a_i8[i * k + p] as i32;

src/hpc/int8_tile_gemm.rs

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
//! INT8 tile GEMM polyfill — AMX (TDPBUSD) tile kernel.
2+
//!
3+
//! Mirror of `hpc::bf16_tile_gemm` for the `u8 × i8 → i32` shape, the
4+
//! native TDPBUSD operand type. One TDPBUSD: 16×16 output tile × 64
5+
//! K-elements per A row × 4 K-elements per inner product = **16 384
6+
//! multiply-accumulates per instruction**. That's 256× the VPDPBUSD
7+
//! zmm throughput per instruction (which does 16 × 4 = 64 MACs).
8+
//!
9+
//! Public surface:
10+
//! * [`int8_tile_gemm_16x16`] — the 16×16 tile kernel; M=16, N=16,
11+
//! K a multiple of 64. AMX path requires runtime feature
12+
//! detection (`amx_available()`); falls back to a scalar reference
13+
//! when AMX isn't OS-enabled.
14+
//!
15+
//! Caller responsibility:
16+
//! * B comes in row-major K × 16 i8; the kernel pre-packs it into
17+
//! VNNI quad layout via [`super::amx_matmul::vnni_pack_i8`].
18+
//! * A is row-major 16 × K u8 (TDPBUSD's unsigned operand).
19+
//! * C accumulates into the caller's i32 buffer (16 × 16 = 256 i32).
20+
//!
21+
//! Same shape as `bf16_tile_gemm::bf16_tile_gemm_16x16`. The two kernels
22+
//! together cover the SPR/GNR AMX dispatch tier for both `BF16 × BF16
23+
//! → f32` and `u8 × i8 → i32` — the two operand families that AMX
24+
//! supports natively.
25+
26+
use crate::hpc::amx_matmul::{
27+
amx_available, tile_dpbusd, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_i8,
28+
TileConfig,
29+
};
30+
31+
// ═════════════════════════════════════════════════════════════════════
32+
// Public API — safe dispatching wrapper
33+
// ═════════════════════════════════════════════════════════════════════
34+
35+
/// Compute C[16, 16] += A[16, K] × B[K, 16] where A is u8 row-major,
36+
/// B is i8 row-major, C is i32 row-major. K must be a multiple of 64.
37+
///
38+
/// Tier dispatch (runtime):
39+
/// AMX available → TDPBUSD tile GEMM (16×16 × K/64 tile iterations,
40+
/// 16 384 MACs per instruction)
41+
/// AMX unavailable → scalar u8 × i8 → i32 reference
42+
///
43+
/// Output behavior: this function **accumulates** into `c` (does NOT
44+
/// zero it first). Callers wanting fresh `C = A·B` semantics should
45+
/// zero `c` before calling, the same convention `bf16_tile_gemm_16x16`
46+
/// uses.
47+
pub fn int8_tile_gemm_16x16(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) {
48+
assert_eq!(k % 64, 0, "K must be multiple of 64 for TDPBUSD tiles");
49+
assert_eq!(a_u8.len(), 16 * k);
50+
assert_eq!(b_i8.len(), k * 16);
51+
assert_eq!(c.len(), 16 * 16);
52+
53+
if amx_available() {
54+
// AMX path: pack B into VNNI quad layout, call tile GEMM.
55+
let mut b_vnni = vec![0i8; k * 16];
56+
vnni_pack_i8(b_i8, &mut b_vnni, k, 16);
57+
// SAFETY: amx_available() just confirmed CPUID + XCR0 + prctl.
58+
unsafe {
59+
amx_path(a_u8, &b_vnni, c, k);
60+
}
61+
} else {
62+
fallback_path(a_u8, b_i8, c, k);
63+
}
64+
}
65+
66+
// ═════════════════════════════════════════════════════════════════════
67+
// AMX path (TDPBUSD)
68+
// ═════════════════════════════════════════════════════════════════════
69+
70+
/// AMX tile GEMM. B must be pre-VNNI-packed (see `vnni_pack_i8`).
71+
/// # Safety
72+
/// Caller must have verified `amx_available() == true`.
73+
#[inline]
74+
unsafe fn amx_path(a_u8: &[u8], b_vnni: &[i8], c: &mut [i32], k: usize) {
75+
// Tile config: 16×64-byte tiles, identical shape to the BF16 tile
76+
// (BF16 is 32 elements × 2 bytes per row, INT8 is 64 elements × 1
77+
// byte — same 64-byte row width either way).
78+
let cfg = TileConfig::for_dpbusd(64);
79+
tile_loadconfig(&cfg);
80+
tile_zero(0);
81+
82+
// Accumulate over K/64 tile blocks. Each TDPBUSD consumes 64
83+
// K-elements per A row × 4 K-elements per inner-product = 256 MACs
84+
// per output cell × 16 × 16 = 16 384 MACs per instruction.
85+
let k_blocks = k / 64;
86+
let a_stride = k; // bytes per A row (u8 = 1 byte each)
87+
let b_stride = 64usize; // VNNI: 16 columns × 4 bytes per row
88+
89+
for kb in 0..k_blocks {
90+
let a_ptr = a_u8.as_ptr().add(kb * 64);
91+
// B sits in VNNI layout: K/4 outer rows × 64 bytes. Each
92+
// 64-K-element block spans 16 outer rows × 64 bytes = 1024
93+
// bytes.
94+
let b_ptr = b_vnni.as_ptr().add(kb * 16 * 64) as *const u8;
95+
tile_load(1, a_ptr, a_stride);
96+
tile_load(2, b_ptr, b_stride);
97+
tile_dpbusd();
98+
}
99+
100+
tile_store(0, c.as_mut_ptr() as *mut u8, 64);
101+
tile_release();
102+
}
103+
104+
// ═════════════════════════════════════════════════════════════════════
105+
// Scalar fallback (i32 reference)
106+
// ═════════════════════════════════════════════════════════════════════
107+
108+
/// Direct scalar u8 × i8 → i32 reference. Accumulates into `c`.
109+
fn fallback_path(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) {
110+
for i in 0..16 {
111+
for kk in 0..k {
112+
let a_val = a_u8[i * k + kk] as i32;
113+
for j in 0..16 {
114+
c[i * 16 + j] += a_val * b_i8[kk * 16 + j] as i32;
115+
}
116+
}
117+
}
118+
}
119+
120+
// ═════════════════════════════════════════════════════════════════════
121+
// Tests
122+
// ═════════════════════════════════════════════════════════════════════
123+
124+
#[cfg(test)]
125+
mod tests {
126+
use super::*;
127+
128+
/// Reference: scalar u8 × i8 → i32 (matches `fallback_path`).
129+
fn ref_gemm(a: &[u8], b: &[i8], c: &mut [i32], k: usize) {
130+
for i in 0..16 {
131+
for j in 0..16 {
132+
let mut s = 0i32;
133+
for kk in 0..k {
134+
s += a[i * k + kk] as i32 * b[kk * 16 + j] as i32;
135+
}
136+
c[i * 16 + j] = s;
137+
}
138+
}
139+
}
140+
141+
#[test]
142+
fn fallback_matches_scalar_reference_k64() {
143+
let k = 64;
144+
// Deterministic pseudo-random inputs covering the u8 / i8 ranges.
145+
let a: Vec<u8> = (0..16 * k).map(|i| ((i * 7 + 3) % 256) as u8).collect();
146+
let b: Vec<i8> = (0..k * 16)
147+
.map(|i| (((i * 11 + 5) % 256) as u8 as i8))
148+
.collect();
149+
150+
let mut c_ref = vec![0i32; 256];
151+
ref_gemm(&a, &b, &mut c_ref, k);
152+
153+
let mut c_fb = vec![0i32; 256];
154+
fallback_path(&a, &b, &mut c_fb, k);
155+
156+
for i in 0..256 {
157+
assert_eq!(c_fb[i], c_ref[i], "fallback mismatch at {}", i);
158+
}
159+
}
160+
161+
#[test]
162+
fn public_api_runs_on_any_hardware_k64() {
163+
let k = 64;
164+
let a = vec![0u8; 16 * k];
165+
let b = vec![0i8; k * 16];
166+
let mut c = vec![0i32; 256];
167+
int8_tile_gemm_16x16(&a, &b, &mut c, k);
168+
for v in c.iter() {
169+
assert_eq!(*v, 0, "zero × zero must be 0");
170+
}
171+
}
172+
173+
#[test]
174+
fn public_api_diagonal_k128() {
175+
// A = identity-like (only A[i, i] = 1, but we need 16 × 128), so
176+
// pick A[i, i*8..i*8+8] = 1 (8 ones per i-row). B = constant 2.
177+
// Expected: C[i, j] = sum_{kk in i*8..i*8+8}(1 × 2) = 16.
178+
let k = 128;
179+
let mut a = vec![0u8; 16 * k];
180+
for i in 0..16 {
181+
for off in 0..8 {
182+
a[i * k + i * 8 + off] = 1;
183+
}
184+
}
185+
let b = vec![2i8; k * 16];
186+
let mut c = vec![0i32; 256];
187+
int8_tile_gemm_16x16(&a, &b, &mut c, k);
188+
for i in 0..16 {
189+
for j in 0..16 {
190+
assert_eq!(c[i * 16 + j], 16, "diagonal accumulator at ({}, {})", i, j);
191+
}
192+
}
193+
}
194+
195+
#[test]
196+
fn vnni_pack_i8_roundtrip() {
197+
// Pack then verify the VNNI layout matches the spec:
198+
// dst[kb*N*4 + j*4 + p] = src[(4*kb + p) * N + j]
199+
let k = 8usize;
200+
let n = 4usize;
201+
let src: Vec<i8> = (0..(k * n) as i8).collect();
202+
let mut dst = vec![0i8; k * n];
203+
vnni_pack_i8(&src, &mut dst, k, n);
204+
for kb in 0..(k / 4) {
205+
for j in 0..n {
206+
for p in 0..4 {
207+
let dst_idx = kb * n * 4 + j * 4 + p;
208+
let expected = src[(4 * kb + p) * n + j];
209+
assert_eq!(dst[dst_idx], expected, "vnni quad mismatch at kb={} j={} p={}", kb, j, p);
210+
}
211+
}
212+
}
213+
}
214+
}

src/hpc/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ pub mod heel_f64x8;
6666
pub mod amx_matmul;
6767
#[cfg(target_arch = "x86_64")]
6868
pub mod bf16_tile_gemm;
69+
/// INT8 (`u8 × i8 → i32`) tile GEMM via AMX `TDPBUSD` — mirror of
70+
/// `bf16_tile_gemm` for the integer operand family.
71+
#[cfg(target_arch = "x86_64")]
72+
pub mod int8_tile_gemm;
6973
#[allow(missing_docs)]
7074
pub mod bf16_truth;
7175
#[allow(missing_docs)]

0 commit comments

Comments
 (0)