Commit b6bea8a

Add fp8/bf8 dtype support to the Stream-K GEMM bridge runner
Extend the Tile-Engine -> Dispatcher Stream-K bridge (PR #8136) beyond
fp16/bf16 to the FNUZ fp8 (E4M3) and bf8 (E5M2) formats used by gfx942/MI300.

GpuGemmRunner (dispatcher/python/gemm_utils.py):
- Port the tested FNUZ codecs from the sibling fp8 bridge (PR #8887):
  bit-exact decode tables + nearest-representable/saturating encode, carried
  as uint8 bit patterns (sizeof fp8_t/bf8_t == 1). Encode preserves operand
  C/F contiguity so the layout-generic _to_buf path holds for the new dtypes.
- run() now sizes the C buffer per get_output_dtype: fp8/bf8 -> fp16 store,
  int8 -> int32; bf16 still carried as raw uint16. fp16/bf16 paths unchanged.
- Arch guard: fp8/bf8 raise a clear error on a non-gfx942 GPU (gfx950/MI350
  uses OCP fp8, a different bit layout) rather than silently mis-decoding.
- An int8 codec is included for when the engine supports it (see below).

Reference + surface:
- run_one_streamk_gemm_kernel.py verify reference is now dtype-aware
  (decode(encode(x)) per dtype; int8 = exact int32 matmul).
- streamk_gemm_full_benchmark.py SUPPORTED_DTYPES += fp8, bf8.

int8 is intentionally left OUT of SUPPORTED_DTYPES: it is blocked at the
ck_tile engine, not the bridge. The int8 kernel codegens but fails to compile
for every reduction strategy -- warp_gemm_dispatcher has no
Dispatcher<int8,int8,float,32,32,16,...> specialization for the streamk CompV3
path, so the BlockUniversalGemmAsBsCr WarpGemm static_asserts fail. Matches
the PR #8094 decision to leave int8 out.

GPU-validated on gfx942 (MI300X), 2048^3, both reduction + layout variants:
  fp8 atomic/linear/tree rcr:    PASS (192/180/183 TFLOPS, max_rel <= 9.4e-4)
  bf8 atomic/linear/tree rcr:    PASS (192/181/181 TFLOPS, max_rel <= 7.8e-4)
  fp8 ccr / bf8 crr (col-major): PASS (245/210 TFLOPS)
1 parent 3985cfd commit b6bea8a
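
As a reading aid for the FNUZ rules named in the commit message (bias = 2^(exp_bits-1), no Inf, 0x80 as the sole NaN, no negative zero), here is a minimal scalar sketch of the byte -> value mapping. The helper name `fnuz_decode_byte` is hypothetical and not part of the commit; the commit itself builds a 256-entry NumPy table rather than decoding one byte at a time.

```python
import math


def fnuz_decode_byte(b: int, exp_bits: int, mant_bits: int) -> float:
    """Decode one 8-bit FNUZ pattern (hypothetical helper, illustration only)."""
    bias = 1 << (exp_bits - 1)            # 8 for E4M3, 16 for E5M2
    mant_max = 1 << mant_bits
    sign = (b >> (exp_bits + mant_bits)) & 1
    exp = (b >> mant_bits) & ((1 << exp_bits) - 1)
    mant = b & (mant_max - 1)
    if exp == 0 and mant == 0:
        # 0x00 is +0; the would-be negative zero (0x80) is the only NaN.
        return math.nan if sign else 0.0
    if exp == 0:
        mag = (mant / mant_max) * 2.0 ** (1 - bias)        # subnormal
    else:
        mag = (1.0 + mant / mant_max) * 2.0 ** (exp - bias)  # normal, incl. all-1s exponent
    return -mag if sign else mag


print(fnuz_decode_byte(0x7F, 4, 3))  # largest fp8 (E4M3) magnitude
print(fnuz_decode_byte(0x7F, 5, 2))  # largest bf8 (E5M2) magnitude
print(fnuz_decode_byte(0x80, 4, 3))  # the lone NaN byte
```

Because the all-1s exponent encodes ordinary numbers, the top byte 0x7F is a finite maximum that an encoder can saturate to, which is what the "saturating encode" in the message relies on.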

3 files changed

Lines changed: 207 additions & 13 deletions

projects/composablekernel/dispatcher/python/gemm_utils.py

Lines changed: 178 additions & 6 deletions
@@ -391,6 +391,141 @@ def cleanup(self) -> None:
 # ============================================================================
 
 
+# ---------------------------------------------------------------------------
+# fp8 (E4M3) / bf8 (E5M2) -- FNUZ ("NANOO") encoding used by gfx942/MI300.
+#
+# numpy has no native 8-bit float, and the C ABI only cares about the 1-byte
+# memory layout (sizeof(fp8_t) == sizeof(bf8_t) == 1). We carry the value as a
+# uint8 bit pattern. As with bf16, the DECODE is the load-bearing half: it must
+# return the exact value the device's fp8_t/bf8_t represents for a byte, so the
+# NumPy reference multiplies bit-for-bit what the GPU multiplies. The ENCODE only
+# needs to land on the nearest representable byte.
+#
+# FNUZ format (gfx942): bias = 2^(exp_bits-1); the all-1s exponent is a normal
+# number (no Inf), the sole NaN is the sign=1/exp=0/mant=0 byte (0x80), and there
+# is no negative zero. gfx950/MI350 uses the OCP fp8 format instead; this codec
+# targets the gfx942 default and the OCP path needs separate handling (the runner
+# raises a clear error for fp8/bf8 on a non-gfx942 arch).
+# ---------------------------------------------------------------------------
+
+
+def _fnuz_decode_table(exp_bits: int, mant_bits: int) -> np.ndarray:
+    """Build the 256-entry byte -> fp32 value table for an 8-bit FNUZ float."""
+    bias = (1 << (exp_bits - 1))
+    mant_max = 1 << mant_bits
+    sign_shift = exp_bits + mant_bits
+    exp_mask = (1 << exp_bits) - 1
+    table = np.zeros(256, dtype=np.float32)
+    for b in range(256):
+        sign = (b >> sign_shift) & 1
+        exp = (b >> mant_bits) & exp_mask
+        mant = b & (mant_max - 1)
+        if exp == 0 and mant == 0:
+            # +0 (0x00); the negative-zero slot (0x80) is the lone NaN.
+            table[b] = np.float32(np.nan) if sign else np.float32(0.0)
+            continue
+        if exp == 0:
+            val = (mant / mant_max) * (2.0 ** (1 - bias))  # subnormal
+        else:
+            val = (1.0 + mant / mant_max) * (2.0 ** (exp - bias))  # normal
+        table[b] = np.float32(-val if sign else val)
+    return table
+
+
+def _fnuz_encode(x: np.ndarray, exp_bits: int, mant_bits: int) -> np.ndarray:
+    """Encode fp32 -> nearest 8-bit FNUZ float, returned as a uint8 bit pattern.
+
+    PRESERVES the input's memory order (C or F) so a column-major operand stays
+    column-major after encoding.
+    """
+    table = _fnuz_decode_table(exp_bits, mant_bits)
+    sign_byte = np.uint8(1 << (exp_bits + mant_bits))  # 0x80
+
+    # Positive half (bytes 0..127) holds every non-negative magnitude, sorted.
+    # Compare in float64: for very large inputs the gap between the two top
+    # magnitudes is below fp32 resolution, which would tie and mis-saturate.
+    pos_mag = table[: int(sign_byte)].astype(np.float64)
+    order = np.argsort(pos_mag)
+    sorted_mag = pos_mag[order]
+    sorted_byte = order.astype(np.uint8)
+
+    xf = np.asarray(x, dtype=np.float32)
+    if not (xf.flags["C_CONTIGUOUS"] or xf.flags["F_CONTIGUOUS"]):
+        xf = np.ascontiguousarray(xf)
+    ax = np.abs(xf).astype(np.float64)
+    # Both neighbours come from the raw insertion point: raw==size saturates to
+    # the top magnitude (lo==hi), raw==0 pins to zero, otherwise compare the two.
+    raw = np.searchsorted(sorted_mag, ax)
+    hi = np.clip(raw, 0, sorted_mag.size - 1)
+    lo = np.clip(raw - 1, 0, sorted_mag.size - 1)
+    pick_lo = np.abs(sorted_mag[lo] - ax) <= np.abs(sorted_mag[hi] - ax)
+    chosen = np.where(pick_lo, lo, hi)
+    out = sorted_byte[chosen]
+
+    # Apply sign, but never the 0x80 (-0 == NaN) slot: zeros stay +0.
+    is_zero = sorted_mag[chosen] == 0
+    out = np.where((xf < 0) & ~is_zero, out | sign_byte, out)
+    out = np.where(np.isnan(xf), sign_byte, out)  # NaN inputs -> NaN byte
+    # np.where collapses memory order; restore the operand's contiguity.
+    out = out.astype(np.uint8)
+    return np.asfortranarray(out) if xf.flags["F_CONTIGUOUS"] else np.ascontiguousarray(out)
+
+
+def _fp32_to_fp8_u8(x: np.ndarray) -> np.ndarray:
+    """Encode fp32 -> fp8 E4M3 (FNUZ) bit pattern in a uint8 array."""
+    return _fnuz_encode(x, exp_bits=4, mant_bits=3)
+
+
+def _fp8_u8_to_fp32(u8: np.ndarray) -> np.ndarray:
+    """Decode an fp8 E4M3 (FNUZ) bit pattern back to fp32."""
+    return _fnuz_decode_table(4, 3)[u8.astype(np.intp)]
+
+
+def _fp32_to_bf8_u8(x: np.ndarray) -> np.ndarray:
+    """Encode fp32 -> bf8 E5M2 (FNUZ) bit pattern in a uint8 array."""
+    return _fnuz_encode(x, exp_bits=5, mant_bits=2)
+
+
+def _bf8_u8_to_fp32(u8: np.ndarray) -> np.ndarray:
+    """Decode a bf8 E5M2 (FNUZ) bit pattern back to fp32."""
+    return _fnuz_decode_table(5, 2)[u8.astype(np.intp)]
+
+
+# Output (C) element dtype for an A/B element dtype, mirroring the codegen's
+# CommonTypeMappings.get_output_dtype: fp8/bf8 accumulate into fp16, int8 into
+# int32, everything else stores in its own dtype.
+_OUTPUT_DTYPE = {"fp8": "fp16", "bf8": "fp16", "int8": "int32"}
+
+
+def _output_dtype(dtype: str) -> str:
+    return _OUTPUT_DTYPE.get(dtype, dtype)
+
+
+# numpy carrier dtype for each output (C) element type. fp8/bf8 -> fp16 store,
+# int8 -> int32 accumulate, bf16 carried as raw uint16 bits.
+_C_NP = {"fp16": np.float16, "bf16": np.uint16, "int32": np.int32}
+
+
+def _detect_gpu_arch() -> Optional[str]:
+    """Best-effort detection of the active GPU's gcnArchName (e.g. 'gfx942').
+
+    Parses ``rocminfo`` for the first ``gfx*`` Name line. Returns ``None`` if it
+    cannot be determined; callers treat that as "cannot verify arch" rather than
+    a hard failure for non-fp8 dtypes.
+    """
+    import re
+    import subprocess
+
+    try:
+        out = subprocess.run(
+            ["rocminfo"], capture_output=True, text=True, timeout=30
+        ).stdout
+    except Exception:
+        return None
+    m = re.search(r"^\s*Name:\s*(gfx[0-9a-fA-F]+)\s*$", out, re.MULTILINE)
+    return m.group(1) if m else None
+
+
 class GpuGemmRunner:
     """High-level runner: construct from a .so path, call run(A, B, problem).
 
@@ -434,32 +569,69 @@ def _bf16_encode(x: np.ndarray) -> np.ndarray:
     def _bf16_decode(u16: np.ndarray) -> np.ndarray:
         return (u16.astype(np.uint32) << 16).view(np.float32)
 
+    # fp8/bf8 codecs are bit-exact to the device fp8_t/bf8_t (FNUZ on gfx942);
+    # re-exposed as static methods so references (smoke test, run_one) can build
+    # decode(encode(x)) quantized inputs without reaching into module functions.
+    _fp8_encode = staticmethod(_fp32_to_fp8_u8)
+    _fp8_decode = staticmethod(_fp8_u8_to_fp32)
+    _bf8_encode = staticmethod(_fp32_to_bf8_u8)
+    _bf8_decode = staticmethod(_bf8_u8_to_fp32)
+
+    def _check_arch_for_dtype(self) -> None:
+        """fp8/bf8 use the gfx942 FNUZ format. gfx950/MI350 uses OCP fp8, a
+        different bit layout, so refuse rather than silently mis-decode."""
+        if self._dtype not in ("fp8", "bf8"):
+            return
+        arch = _detect_gpu_arch()
+        if arch is not None and arch != "gfx942":
+            raise RuntimeError(
+                f"fp8/bf8 bridge codec is FNUZ (gfx942/MI300) only; detected "
+                f"GPU arch {arch!r}. gfx950/MI350 uses OCP fp8 (different bit "
+                f"layout) -- an OCP codec is required for that arch."
+            )
+
     def _to_buf(self, X: np.ndarray, major: str) -> np.ndarray:
         """Lay out an operand in the order its layout implies: RowMajor ->
         C-contiguous, ColumnMajor -> F-contiguous. The .so reads a flat buffer
-        with the matching stride, so the raw byte order is what matters."""
+        with the matching stride, so the raw byte order is what matters. The
+        encode helpers (bf16/fp8/bf8) preserve that contiguity; int8/fp16 keep
+        the requested order via astype(order='K')."""
         arr = np.ascontiguousarray(X) if major == "r" else np.asfortranarray(X)
         if self._dtype == "bf16":
             return self._bf16_encode(arr)
+        if self._dtype == "fp8":
+            return _fp32_to_fp8_u8(arr)
+        if self._dtype == "bf8":
+            return _fp32_to_bf8_u8(arr)
+        if self._dtype == "int8":
+            return arr.astype(np.int8, order="K")
         return arr.astype(np.float16, order="K")
 
     def run(
         self, A: np.ndarray, B: np.ndarray, problem: GemmProblem
     ) -> GemmResult:
         M, N, K = problem.M, problem.N, problem.K
+        self._check_arch_for_dtype()
 
-        # Arrange A (MxK), B (KxN), C (MxN) per the kernel's actual layout. bf16 is
-        # passed as raw uint16 bits (the ctypes ABI is void*+sizeof, so 2-byte bf16
-        # and fp16 share the path; only the bit pattern differs).
+        # Arrange A (MxK), B (KxN), C (MxN) per the kernel's actual layout. The
+        # ctypes ABI is void*+sizeof, so each dtype just needs the right bit
+        # pattern: bf16 -> uint16, fp8/bf8 -> uint8, int8 -> int8, fp16 -> fp16.
         la, lb, lc = self._layout[0], self._layout[1], self._layout[2]
         A_h = self._to_buf(A, la)
         B_h = self._to_buf(B, lb)
-        cdt = np.uint16 if self._dtype == "bf16" else np.float16
+
+        # The C buffer's element size must equal sizeof(CDataType): fp8/bf8
+        # accumulate into fp16, int8 into int32, otherwise the input dtype (bf16
+        # carried as raw uint16 bits).
+        out_dtype = _output_dtype(self._dtype)
+        cdt = _C_NP.get(out_dtype, np.float16)
         C_h = np.zeros((M, N), dtype=cdt, order=("C" if lc == "r" else "F"))
 
         status, time_ms = self.lib.run(A_h, B_h, C_h, M, N, K)
 
-        out = self._bf16_decode(C_h) if self._dtype == "bf16" else C_h
+        # Decode the output to a comparable numeric array. fp16/fp8/bf8 store fp16
+        # (already comparable); int8 stores int32; only bf16 needs bit-decode.
+        out = self._bf16_decode(C_h) if out_dtype == "bf16" else C_h
         tflops = (problem.flops / (time_ms * 1e-3)) / 1e12 if time_ms > 0 else 0.0
         return GemmResult(
             output=out,
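
To illustrate the encode behaviour this file's diff describes (nearest representable byte, saturation at the top magnitude, zeros never landing on 0x80, NaN mapped to 0x80), here is a compact standalone sketch. It is not the committed code: `decode_table` and `encode_fnuz` are hypothetical names, the table is built vectorised in float64, and memory-order preservation is left out for brevity.

```python
import numpy as np


def decode_table(exp_bits: int, mant_bits: int) -> np.ndarray:
    """256-entry byte -> value table for an 8-bit FNUZ float (sketch)."""
    bias = 1 << (exp_bits - 1)
    mant_max = 1 << mant_bits
    b = np.arange(256)
    sign = (b >> (exp_bits + mant_bits)) & 1
    exp = (b >> mant_bits) & ((1 << exp_bits) - 1)
    mant = b & (mant_max - 1)
    subnormal = (mant / mant_max) * 2.0 ** (1 - bias)
    normal = (1.0 + mant / mant_max) * 2.0 ** (exp.astype(np.float64) - bias)
    val = np.where(exp == 0, subnormal, normal)
    val = np.where(sign == 1, -val, val)
    val[0x80] = np.nan  # the negative-zero slot is the sole NaN
    return val


def encode_fnuz(x, exp_bits: int, mant_bits: int) -> np.ndarray:
    """Nearest-representable, saturating encode to uint8 bit patterns (sketch)."""
    mags = decode_table(exp_bits, mant_bits)[:128]  # bytes 0..127, ascending
    xs = np.asarray(x, dtype=np.float64)
    ax = np.abs(xs)
    raw = np.searchsorted(mags, ax)       # insertion point among magnitudes
    hi = np.clip(raw, 0, 127)             # raw == 128 saturates to the top byte
    lo = np.clip(raw - 1, 0, 127)
    chosen = np.where(np.abs(mags[lo] - ax) <= np.abs(mags[hi] - ax), lo, hi)
    out = chosen.astype(np.uint8)
    negative = (xs < 0) & (mags[chosen] != 0)  # zeros stay +0, never 0x80
    out = np.where(negative, out | np.uint8(0x80), out)
    return np.where(np.isnan(xs), np.uint8(0x80), out).astype(np.uint8)


x = np.array([0.0, 1.0, 0.3, -0.3, 1000.0, -1000.0, -1e-9, np.nan])
codes = encode_fnuz(x, 4, 3)
quantized = decode_table(4, 3)[codes.astype(np.intp)]
print([hex(c) for c in codes])
print(quantized)
```

Re-encoding `quantized` yields the same bytes, which is the property the verify reference depends on: decode(encode(x)) gives exactly the values the device multiplies.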

projects/composablekernel/tile_engine/ops/gemm/run_one_streamk_gemm_kernel.py

Lines changed: 19 additions & 4 deletions
@@ -79,15 +79,30 @@ def _run_one(idx, so_path, prob_dict, kernel_name, verify=False, verify_tol=2e-2
     }
     if verify:
         # Reference uses the SAME quantized inputs the device sees, per the
-        # kernel's dtype (bf16 bit-truncation vs fp16), so the metric isolates
-        # compute error from input quantization.
-        if getattr(runner, "_dtype", "fp16") == "bf16":
+        # kernel's dtype (bf16/fp8/bf8 bit-quantization vs fp16), so the
+        # metric isolates compute error from input quantization. int8 is
+        # exact: the device multiplies the int8 values directly.
+        kdt = getattr(runner, "_dtype", "fp16")
+        if kdt == "bf16":
             Aq = GpuGemmRunner._bf16_decode(GpuGemmRunner._bf16_encode(A))
             Bq = GpuGemmRunner._bf16_decode(GpuGemmRunner._bf16_encode(B))
+            ref = Aq @ Bq
+        elif kdt == "fp8":
+            Aq = GpuGemmRunner._fp8_decode(GpuGemmRunner._fp8_encode(A))
+            Bq = GpuGemmRunner._fp8_decode(GpuGemmRunner._fp8_encode(B))
+            ref = Aq @ Bq
+        elif kdt == "bf8":
+            Aq = GpuGemmRunner._bf8_decode(GpuGemmRunner._bf8_encode(A))
+            Bq = GpuGemmRunner._bf8_decode(GpuGemmRunner._bf8_encode(B))
+            ref = Aq @ Bq
+        elif kdt == "int8":
+            Aq = A.astype(np.int8).astype(np.int32)
+            Bq = B.astype(np.int8).astype(np.int32)
+            ref = (Aq @ Bq).astype(np.float32)
         else:
             Aq = A.astype(np.float16).astype(np.float32)
             Bq = B.astype(np.float16).astype(np.float32)
-        ref = Aq @ Bq
+            ref = Aq @ Bq
         got = result.output.astype(np.float32)
         denom = float(np.max(np.abs(ref))) or 1.0
         max_rel = float(np.max(np.abs(got - ref)) / denom)
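
The comment in this hunk says the reference is built from the same quantized inputs the device sees, so `max_rel` measures compute error rather than input quantization. Below is a small CPU-only sketch of that idea. It uses fp16 because NumPy has a native type for it, and the "device" result is simulated on the host, which is an assumption made purely for illustration. The helper names are hypothetical.

```python
import numpy as np


def quantize_fp16(x: np.ndarray) -> np.ndarray:
    """Round-trip through fp16 so the reference sees what the kernel would see."""
    return x.astype(np.float16).astype(np.float32)


def max_rel(got: np.ndarray, ref: np.ndarray) -> float:
    """Max absolute error normalised by the largest reference magnitude."""
    denom = float(np.max(np.abs(ref))) or 1.0  # guard against an all-zero reference
    return float(np.max(np.abs(got - ref)) / denom)


rng = np.random.default_rng(0)
A = rng.standard_normal((64, 32)).astype(np.float32)
B = rng.standard_normal((32, 48)).astype(np.float32)

# Simulated device output (assumption): quantized operands, fp16 store.
got = (quantize_fp16(A) @ quantize_fp16(B)).astype(np.float16).astype(np.float32)

err_vs_quantized = max_rel(got, quantize_fp16(A) @ quantize_fp16(B))
err_vs_raw = max_rel(got, A @ B)  # also contains the input quantization error
print(err_vs_quantized, err_vs_raw)
```

Against the quantized reference the only remaining error is the fp16 store, which stays well under the default `verify_tol=2e-2` visible in the hunk header.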

projects/composablekernel/tile_engine/ops/gemm/streamk_gemm_full_benchmark.py

Lines changed: 10 additions & 3 deletions
@@ -63,9 +63,16 @@
 # Bridge surface for Stream-K. The dispatcher host path
 # (streamk_gemm_ctypes_lib.cpp) derives strides from the kernel's layouts and the
 # worker (run_one_streamk_gemm_kernel.py) reads dtype/layout off the kernel name,
-# so all 4 A/B/C layouts are supported; dtypes cover fp16 + bf16 (the codecs the
-# bridge runner implements). fp8/bf8/int8 await runner codecs.
-SUPPORTED_DTYPES = ("fp16", "bf16")
+# so all 4 A/B/C layouts are supported. dtypes cover fp16 + bf16 + fp8 + bf8 (the
+# codecs the bridge runner implements); fp8/bf8 use the gfx942 FNUZ format and
+# accumulate into fp16. int8 is left out: it is blocked at the ck_tile engine
+# level, not the bridge -- the int8 kernel codegens but fails to COMPILE for
+# every reduction strategy (atomic/linear/tree). warp_gemm_dispatcher has no
+# Dispatcher<int8,int8,float,32,32,16,...> specialization for the streamk
+# CompV3 path, so WarpGemm resolves to `int` and the BlockUniversalGemmAsBsCr
+# WarpGemm::kM/kN static_asserts fail. The runner keeps an int8 codec ready for
+# when the engine adds that instantiation; this matches PR #8094 leaving int8 out.
+SUPPORTED_DTYPES = ("fp16", "bf16", "fp8", "bf8")
 SUPPORTED_LAYOUTS = ("rcr", "rrr", "ccr", "crr")
 
 
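
For readers unfamiliar with the three-letter layout codes, the sketch below shows how a code such as "rcr" can be read as per-operand memory orders for A, B and C, following the runner's convention that "r" means C-contiguous (row-major) and anything else means F-contiguous (column-major). `operand_orders` and `element_strides` are hypothetical helpers for illustration. The real stride derivation lives in streamk_gemm_ctypes_lib.cpp and is not reproduced here.

```python
import numpy as np

SUPPORTED_LAYOUTS = ("rcr", "rrr", "ccr", "crr")


def operand_orders(layout: str) -> tuple:
    """Map an A/B/C layout code to NumPy memory orders (hypothetical helper)."""
    return tuple("C" if ch == "r" else "F" for ch in layout)


def element_strides(arr: np.ndarray) -> tuple:
    """Strides in elements rather than bytes."""
    return tuple(s // arr.itemsize for s in arr.strides)


M, N, K = 4, 6, 8
for layout in SUPPORTED_LAYOUTS:
    order_a, order_b, order_c = operand_orders(layout)
    A = np.zeros((M, K), dtype=np.float16, order=order_a)
    B = np.zeros((K, N), dtype=np.float16, order=order_b)
    C = np.zeros((M, N), dtype=np.float16, order=order_c)
    print(layout, element_strides(A), element_strides(B), element_strides(C))
```

A row-major M x K operand has element strides (K, 1) and a column-major one has (1, M), which is why the flat buffer handed to the .so must already be in the order the kernel's layout implies.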