Commit e8f9ce0

fix(hpc): validate int8_gemm_amx_tiled slice lengths (codex P1)
Per codex review on PR #185: `int8_gemm_amx_tiled` is a safe public function (no `unsafe` in the signature) but its inner loop read `b_i8` via `core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16)` without any length check. Callers passing mismatched (m, n, k) vs slice lengths could trigger out-of-bounds reads / UB instead of a panic. Before PR #185 this logic lived only in `matmul_i8_to_i32`'s private AMX arm (where the public `pack_contig` preceded it and bounded everything), but the factored helper is now reachable from `gemm_u8_i8` and any future caller.

Fix:

1. Add three boundary assertions at function entry matching `gemm_u8_i8`'s contract:

       a_u8.len() >= m * k
       b_i8.len() >= k * n
       c.len()    >= m * n

   These panic with descriptive messages on undersized input; the safety contract is now enforced at the public function boundary, not at the unsafe pointer-arithmetic site inside the hot loop.

2. Replace the `unsafe { core::slice::from_raw_parts(...) }` B-pack line with safe `b_tile[..].copy_from_slice(&b_i8[row..row + 16])`. The bounds-check inside the loop is now redundant given the function-entry assertions, but the compiler should elide it once the invariant is proven; either way the code becomes panicking-safe instead of UB-on-misuse.

3. Update the doc-comment `# Panics` section to list the boundary panics alongside the existing debug-only AMX / alignment assertions.

New regression test `amx_tiled_panics_on_undersized_b`:

* Constructs `b: Vec<i8>` half-a-j_tile shorter than the claimed `k * n`.
* Calls `int8_gemm_amx_tiled` and asserts the expected panic fires before any unsafe slice arithmetic.
* `#[should_panic(expected = "b_i8.len()")]` catches the exact assertion message; works on any host (the boundary check fires before the `debug_assert!(amx_available())` so the test passes on AMX-less CI runners too).

Verification:

* 2097 lib tests pass (was 2096; +1 new regression test).
* cargo clippy --lib --tests --features rayon,native -- -D warnings clean.
* cargo fmt --all --check clean.

The matmul_i8_to_i32 path that delegates to int8_gemm_amx_tiled inherits the assertions transparently via the call chain. No behavior change for valid input: only mismatched-shape callers that would have hit UB now get a clean panic instead.

https://claude.ai/code/session_01HbqooFZHAjaUtFEzhA1R2u
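As an illustration of fix items 1 and 2, here is a minimal sketch of the pattern: validate lengths once at function entry, then use ordinary slice indexing inside the loop. `pack_b_strip` and its parameters are hypothetical names chosen for this example, not functions from the crate, and the AMX tile kernel that the real `int8_gemm_amx_tiled` drives is omitted.

```rust
/// Hypothetical helper (not part of the crate): copies the 16-column strip
/// starting at column `j_tile` of a row-major `k x n` matrix `b` into `out`,
/// producing `k` contiguous rows of 16 values each.
pub fn pack_b_strip(b: &[i8], out: &mut [i8], n: usize, k: usize, j_tile: usize) {
    // Entry checks: once these hold, every `row..row + 16` range below
    // is inside `b`, so the loop needs no unsafe pointer arithmetic.
    assert!(b.len() >= k * n, "pack_b_strip: b.len()={} < k*n={}", b.len(), k * n);
    assert!(out.len() >= k * 16, "pack_b_strip: out.len()={} < k*16={}", out.len(), k * 16);
    assert!(j_tile + 16 <= n, "pack_b_strip: j_tile+16={} > n={}", j_tile + 16, n);

    for kk in 0..k {
        let row = kk * n + j_tile;
        // Safe slicing: an out-of-range index would panic rather than
        // read past the end of `b`.
        out[kk * 16..(kk + 1) * 16].copy_from_slice(&b[row..row + 16]);
    }
}
```

With an undersized `b`, the first assertion panics before the loop runs, which is the kind of behavior the regression test in this commit checks for.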
1 parent 38d4800 commit e8f9ce0

1 file changed

Lines changed: 40 additions & 7 deletions

src/hpc/int8_tile_gemm.rs

@@ -341,10 +341,22 @@ fn fallback_path(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], k: usize) {
 /// is pure overwrite.)
 ///
 /// # Panics
-/// Debug-asserts AMX availability and the 16/16/64 shape constraints.
-/// Production builds rely on the caller's runtime check
-/// (`crate::hpc::amx_matmul::amx_available()`).
+/// Panics if `a_u8`, `b_i8`, or `c` are too small for the requested
+/// `(m, n, k)`, mirroring the boundary contract from `gemm_u8_i8`. Also
+/// panics in debug builds when AMX isn't OS-enabled or when the shape
+/// alignment constraints aren't met (production builds skip those for
+/// performance — callers must runtime-check
+/// `crate::hpc::amx_matmul::amx_available()` and the 16/16/64
+/// alignment themselves).
 pub fn int8_gemm_amx_tiled(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) {
+    // Length assertions (codex P1 from PR #185 — the function reads
+    // `b_i8` via a 16-wide window per (kk, j_tile) iteration and a_u8
+    // via a 16-row slice per i_tile, so mismatched shapes would
+    // trigger out-of-bounds reads without these gates).
+    assert!(a_u8.len() >= m * k, "int8_gemm_amx_tiled: a_u8.len()={} < m*k={}", a_u8.len(), m * k);
+    assert!(b_i8.len() >= k * n, "int8_gemm_amx_tiled: b_i8.len()={} < k*n={}", b_i8.len(), k * n);
+    assert!(c.len() >= m * n, "int8_gemm_amx_tiled: c.len()={} < m*n={}", c.len(), m * n);
+
     debug_assert!(crate::hpc::amx_matmul::amx_available());
     debug_assert_eq!(m % 16, 0, "int8_gemm_amx_tiled: M must be multiple of 16");
     debug_assert_eq!(n % 16, 0, "int8_gemm_amx_tiled: N must be multiple of 16");
@@ -354,12 +366,13 @@ pub fn int8_gemm_amx_tiled(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n:
     let mut tile_c = vec![0i32; 256];
 
     for j_tile in (0..n).step_by(16) {
-        // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows (contiguous
-        // memory for int8_tile_gemm_16x16's input shape).
+        // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows
+        // (contiguous memory for int8_tile_gemm_16x16's input shape).
+        // Safe slicing — the row..row+16 range is bounded by
+        // `b_i8.len() >= k * n` asserted at function entry.
         for kk in 0..k {
             let row = kk * n + j_tile;
-            b_tile[kk * 16..(kk + 1) * 16]
-                .copy_from_slice(unsafe { core::slice::from_raw_parts(b_i8.as_ptr().add(row), 16) });
+            b_tile[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[row..row + 16]);
         }
         for i_tile in (0..m).step_by(16) {
             let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k];
@@ -513,6 +526,26 @@ mod tests {
         }
     }
 
+    /// Codex P1 regression on PR #185: `int8_gemm_amx_tiled` is a
+    /// safe public function — mismatched (m, n, k) vs slice lengths
+    /// must panic at the function boundary, not trigger UB inside
+    /// the unsafe slice/pointer arithmetic in the inner loop. This
+    /// test passes deliberately-undersized buffers and expects a
+    /// panic (which `#[should_panic]` catches).
+    #[test]
+    #[should_panic(expected = "b_i8.len()")]
+    fn amx_tiled_panics_on_undersized_b() {
+        let m = 16;
+        let n = 32;
+        let k = 64;
+        let a = vec![0u8; m * k];
+        let b = vec![0i8; k * (n - 16)]; // half a j_tile short of what's claimed
+        let mut c = vec![0i32; m * n];
+        // Even on non-AMX hosts the assertion fires before reaching
+        // the (debug-asserted) amx_available() check.
+        int8_gemm_amx_tiled(&a, &b, &mut c, m, n, k);
+    }
+
     /// Direct test for the VPDPBUSD-ymm arm (AVX-VNNI tier of
     /// `matmul_i8_to_i32`). Same shape / bit-exactness contract as
     /// the zmm version's test, just on the narrower 8-wide kernel.
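
The regression test in the last hunk depends on `#[should_panic(expected = ...)]` matching a substring of the panic message. The sketch below shows that mechanism in isolation. It assumes two hypothetical helpers that do not exist in the crate: `check_gemm_lengths`, which mirrors the three entry assertions, and `panic_message`, which is built on `std::panic::catch_unwind`.

```rust
use std::panic;

/// Hypothetical stand-in for the entry checks (not the crate's function).
/// Takes lengths instead of slices so the example stays small.
pub fn check_gemm_lengths(a_len: usize, b_len: usize, c_len: usize, m: usize, n: usize, k: usize) {
    assert!(a_len >= m * k, "a_u8.len()={} < m*k={}", a_len, m * k);
    assert!(b_len >= k * n, "b_i8.len()={} < k*n={}", b_len, k * n);
    assert!(c_len >= m * n, "c.len()={} < m*n={}", c_len, m * n);
}

/// Hypothetical helper: runs `f` and returns the panic message if it
/// panicked, or `None` if it returned normally.
pub fn panic_message<F: FnOnce() + panic::UnwindSafe>(f: F) -> Option<String> {
    match panic::catch_unwind(f) {
        Ok(()) => None,
        Err(payload) => {
            // Formatted panics carry a `String`; literal ones carry `&str`.
            if let Some(s) = payload.downcast_ref::<String>() {
                Some(s.clone())
            } else if let Some(s) = payload.downcast_ref::<&str>() {
                Some(s.to_string())
            } else {
                Some(String::new())
            }
        }
    }
}

#[cfg(test)]
mod regression {
    use super::*;

    // `expected` is a substring match, so "b_i8.len()" is enough to pin
    // the panic to the second assertion rather than the first or third.
    #[test]
    #[should_panic(expected = "b_i8.len()")]
    fn undersized_b_panics() {
        // Same shape as the commit's test: m=16, n=32, k=64, with `b`
        // sized as k * (n - 16) instead of k * n.
        check_gemm_lengths(16 * 64, 64 * 16, 16 * 32, 16, 32, 64);
    }
}
```

Because the check runs on lengths alone, a test of this shape needs no AMX hardware, which matches the commit's note that the boundary assertion fires before the `debug_assert!(amx_available())`.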
