[CK_TILE] gemm bridge: TE<->Dispatcher for fp8/bf8/int8 (all layouts)#8887

Closed
ozturkosu wants to merge 1 commit into users/muozturk/ck-tile/gemm-bridge-all-layout-bf16-fp16 from muozturk/gemm-bridge-all-layouts-fp8-bf8-int8

@ozturkosu

Contributor

Summary

Extends the Tile Engine ↔ Dispatcher GEMM bridge beyond the fp16/bf16 surface of #8479 to the remaining data types for which TE's plain GEMM has MFMA warp tiles:

  • fp8 (E4M3) and bf8 (E5M2) → fp16 output, fp32 accumulate
  • int8 → int32 output and accumulate (gfx942)

Covers all four A/B layout combinations per dtype (row-major C only, matching #8479). fp32/fp64 are intentionally excluded: they appear in TE's dtype-string map but have no MFMA warp tiles in GEMM_WARP_TILE_SUPPORTED_COMBINATIONS, so no kernel can be generated or run for them.

Stacked on #8479 (muozturk/gemm-bridge-all-layouts-bf16), which carries the bridge infrastructure and is not yet merged. Please merge #8479 first, then this can rebase onto develop.

Changes

  • Codegen (codegen_common.py, unified_gemm_codegen.py): add int32 to the dtype maps; get_output_dtype int8→int32; new get_acc_dtype (int8→int32, else fp32); derive AccDataType/CDataType, the GEMM_KEY_DTYPE_{C,ACC} macros, and the registry dtype_c/dtype_acc from the dtype instead of hard-coding float/fp32.
  • Host harness (gemm_utils.py): fp8/bf8 FNUZ (gfx942) uint8 codecs — exact decode (matches device fp8_t/bf8_t), nearest-representable saturating encode (same pattern as the existing bf16 helper); GpuGemmRunner.run encodes A/B and sizes the C buffer per dtype; expand_sweep sets dtype_c/dtype_acc.
  • Tests: test_gemm_utils.py adds CPU-only fp8/bf8 codec + output-dtype tests (all green); test_gemm_parity.py adds fp8/bf8/int8 cases with dtype-aware inputs/references/tolerances (int8 is bit-exact), GPU-gated like the existing cases.
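
As a rough illustration of the dtype derivation described in the Codegen bullet, the mapping can be sketched as below. The function names come from this PR, but the bodies, the string spellings of the dtypes, and the identity fallback for fp16/bf16 are assumptions, not the actual contents of codegen_common.py.

```python
# Hypothetical stand-ins for the helpers named in this PR; the real
# implementations in codegen_common.py may differ in signature and detail.

def get_output_dtype(dtype: str) -> str:
    """C (output) dtype for a given A/B input dtype."""
    if dtype in ("fp8", "bf8"):
        return "fp16"   # 8-bit float inputs store fp16 results
    if dtype == "int8":
        return "int32"  # int8 inputs store int32 results
    return dtype        # assumption: fp16/bf16 keep their own dtype


def get_acc_dtype(dtype: str) -> str:
    """Accumulator dtype: int32 for int8, fp32 for everything else."""
    return "int32" if dtype == "int8" else "fp32"


for d in ("fp16", "bf16", "fp8", "bf8", "int8"):
    print(f"{d:>5}: C={get_output_dtype(d):<5} Acc={get_acc_dtype(d)}")
```

Deriving both types from the input dtype in one place is what lets the emitted AccDataType/CDataType and the registry entries stay consistent with each other.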

Verification done

  • test_gemm_utils.py + test_codegen_common.py: 54 passed (CPU).
  • Codegen smoke: fp8/int8/fp16 each generate 1 kernel + 1 wrapper, 0 failed; emitted ADataType/CDataType/AccDataType and GEMM_KEY_* macros are correct (int8→int32_t acc/C; fp8→fp16_t C).
  • test_gemm_parity.py collects 60 cases and skips cleanly without a GPU.
  • The 16 unrelated failures in test_examples_integration / test_grouped_conv_codegen / test_library_caching are pre-existing (verified identical on the base branch; they require a built dispatcher .a / GPU).

Test plan

  • Merge #8479 (feat(ck_tile): TE -> Dispatcher GEMM bridge (all layouts, fp16/bf16)), then rebase this onto develop.
  • On an MI300X (gfx942) node: run python3 tests/test_gemm_parity.py and confirm fp8/bf8/int8 parity; tune the fp8/bf8 tolerances if needed (current values are first-cut headroom).
  • FNUZ vs OCP: the fp8/bf8 host codec targets the gfx942 FNUZ format; validate / extend for gfx950 (OCP) before enabling there.

🤖 Generated with Claude Code

…(all layouts)

Adds the remaining data types for which Tile Engine's plain GEMM has MFMA warp
tiles, beyond the fp16/bf16 surface of PR #8479: fp8 (E4M3) and bf8 (E5M2) with
fp16 output and fp32 accumulation, and int8 with int32 output and accumulation
(gfx942). Covers all four A/B layout combinations per dtype (row-major C only,
as ck_tile rejects column-major C).

Codegen (codegen_common.py, unified_gemm_codegen.py):
- add int32 to the CK / qualified / dispatcher dtype maps
- get_output_dtype: int8 -> int32 (fp8/bf8 -> fp16 unchanged)
- new get_acc_dtype: int8 -> int32, else fp32
- derive AccDataType, CDataType, the GEMM_KEY_DTYPE_{C,ACC} macros and the
  registry dtype_c/dtype_acc from the dtype instead of hard-coding float/fp32

Host harness (gemm_utils.py):
- fp8/bf8 FNUZ (gfx942) uint8 codecs: exact decode (matches device fp8_t/bf8_t),
  nearest-representable saturating encode, mirroring the existing bf16 helper
- GpuGemmRunner.run encodes A/B and sizes the C buffer per dtype (fp16 for
  fp8/bf8, int32 for int8)
- expand_sweep sets dtype_c/dtype_acc from the input dtype
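
The codec described above can be sketched in plain Python as follows. This is a minimal illustration, not the gemm_utils.py code: the function names are hypothetical, it works on single values instead of arrays, and the tie-breaking of the nearest-value search is an artifact of the sketch. It relies on the FNUZ conventions (exponent bias 8 for E4M3 and 16 for E5M2, a single NaN at byte 0x80, no infinities, no negative zero).

```python
import math


def decode_fnuz(byte: int, exp_bits: int, man_bits: int) -> float:
    """Exact decode of one FNUZ byte (E4M3: 4,3; E5M2: 5,2)."""
    bias = 1 << (exp_bits - 1)            # FNUZ bias: 8 (E4M3), 16 (E5M2)
    if byte == 0x80:                      # the only NaN slot; there is no -0
        return math.nan
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    if exp == 0:                          # zero and subnormals
        return sign * man * 2.0 ** (1 - bias - man_bits)
    return sign * (1.0 + man / (1 << man_bits)) * 2.0 ** (exp - bias)


def encode_fnuz(x: float, exp_bits: int, man_bits: int) -> int:
    """Nearest-representable, saturating encode via a 255-entry table."""
    if math.isnan(x):
        return 0x80
    table = [(decode_fnuz(b, exp_bits, man_bits), b)
             for b in range(256) if b != 0x80]
    max_val = max(v for v, _ in table)
    x = max(-max_val, min(max_val, x))    # saturate; also handles +/-inf
    return min(table, key=lambda vb: abs(vb[0] - x))[1]


print(decode_fnuz(0x7F, 4, 3))        # largest fp8 (E4M3 FNUZ) value
print(decode_fnuz(0x7F, 5, 2))        # largest bf8 (E5M2 FNUZ) value
print(hex(encode_fnuz(1000.0, 4, 3))) # saturates to the largest byte
```

A table-driven encode is slow but makes the "nearest representable" property easy to check against the exact decode, which is the point of a host-side test codec.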

Tests:
- test_gemm_utils.py: fp8/bf8 codec round-trip, format ranges, NaN/zero slots,
  saturation, byte size; output-dtype mapping (CPU-only)
- test_gemm_parity.py: fp8/bf8/int8 cases with dtype-aware inputs, references and
  tolerances (int8 exact); GPU-gated like the existing fp16/bf16 cases

GPU parity validation deferred to a follow-up run on an MI300X node.
@ozturkosu ozturkosu requested a review from a team as a code owner June 27, 2026 22:39
@ozturkosu ozturkosu self-assigned this Jun 27, 2026
ozturkosu added a commit that referenced this pull request Jun 28, 2026
Extend the Tile-Engine -> Dispatcher Stream-K bridge (PR #8136) beyond
fp16/bf16 to the FNUZ fp8 (E4M3) and bf8 (E5M2) formats used by gfx942/MI300.

GpuGemmRunner (dispatcher/python/gemm_utils.py):
- Port the tested FNUZ codecs from the sibling fp8 bridge (PR #8887):
  bit-exact decode tables + nearest-representable/saturating encode, carried
  as uint8 bit patterns (sizeof fp8_t/bf8_t == 1). Encode preserves operand
  C/F contiguity so the layout-generic _to_buf path holds for the new dtypes.
- run() now sizes the C buffer per get_output_dtype: fp8/bf8 -> fp16 store,
  int8 -> int32; bf16 still carried as raw uint16. fp16/bf16 paths unchanged.
- Arch guard: fp8/bf8 raise a clear error on a non-gfx942 GPU (gfx950/MI350
  uses OCP fp8, a different bit layout) rather than silently mis-decoding.
- An int8 codec is included for when the engine supports it (see below).
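
A guard of the kind described in the arch-guard item might look like the sketch below. The function name, the way the architecture string is obtained, and the error type are all assumptions; only the policy (fp8/bf8 are accepted on gfx942 and rejected elsewhere) comes from the commit message.

```python
def check_fnuz_arch(dtype: str, gpu_arch: str) -> None:
    """Hypothetical guard: refuse FNUZ-coded dtypes on anything but gfx942.

    gpu_arch is assumed to be a string such as "gfx942" or
    "gfx942:sramecc+:xnack-", hence the prefix match.
    """
    if dtype in ("fp8", "bf8") and not gpu_arch.startswith("gfx942"):
        raise RuntimeError(
            f"{dtype} host codec targets the gfx942 FNUZ format; "
            f"{gpu_arch} uses a different fp8 bit layout"
        )


check_fnuz_arch("fp8", "gfx942")       # accepted
check_fnuz_arch("fp16", "gfx950")      # non-fp8 dtypes are not restricted
try:
    check_fnuz_arch("bf8", "gfx950")   # rejected instead of mis-decoding
except RuntimeError as err:
    print(err)
```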

Reference + surface:
- run_one_streamk_gemm_kernel.py verify reference is now dtype-aware
  (decode(encode(x)) per dtype; int8 = exact int32 matmul).
- streamk_gemm_full_benchmark.py SUPPORTED_DTYPES += fp8, bf8.
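
The idea behind the dtype-aware reference can be sketched as follows: the reference multiplies the operands as the device will see them (after an encode/decode round trip), so the comparison measures kernel error and not quantization error. The names below are hypothetical, the quantizer is a toy stand-in for decode(encode(x)), and the real script presumably uses array operations where this sketch uses nested lists.

```python
from typing import Callable, List, Sequence

Number = float  # ints are accepted too and stay exact in Python


def reference_gemm(a: Sequence[Sequence[Number]],
                   b: Sequence[Sequence[Number]],
                   quantize: Callable[[Number], Number]) -> List[List[Number]]:
    """C = q(A) @ q(B), with q applied elementwise before the product."""
    qa = [[quantize(v) for v in row] for row in a]
    qb = [[quantize(v) for v in row] for row in b]
    k, n = len(qb), len(qb[0])
    return [[sum(row[i] * qb[i][j] for i in range(k)) for j in range(n)]
            for row in qa]


def toy_quantize(x: float) -> float:
    """Stand-in for decode(encode(x)): snap to multiples of 1/8."""
    return round(x * 8) / 8


# Low-precision float case: inputs are snapped before the product.
print(reference_gemm([[0.3]], [[0.3]], toy_quantize))

# int8 case: identity quantizer, exact integer accumulation.
print(reference_gemm([[127, -128]], [[127], [-128]], lambda v: v))
```

For int8 the integer product is exact, which is why a bit-exact comparison is possible there while fp8/bf8 need tolerances.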

int8 is intentionally left OUT of SUPPORTED_DTYPES: it is blocked at the
ck_tile engine, not the bridge. The int8 kernel codegens but fails to compile
for every reduction strategy -- warp_gemm_dispatcher has no
Dispatcher<int8,int8,float,32,32,16,...> specialization for the streamk CompV3
path, so the BlockUniversalGemmAsBsCr WarpGemm static_asserts fail. Matches the
PR #8094 decision to leave int8 out.

GPU-validated on gfx942 (MI300X), 2048^3, across reduction strategies and layout variants:
  fp8 atomic/linear/tree rcr: PASS  (192/180/183 TFLOPS, max_rel <= 9.4e-4)
  bf8 atomic/linear/tree rcr: PASS  (192/181/181 TFLOPS, max_rel <= 7.8e-4)
  fp8 ccr / bf8 crr (col-major): PASS (245/210 TFLOPS)
ozturkosu added a commit that referenced this pull request Jun 29, 2026
The grouped GpuGroupedGemmRunner sized the host C buffer with the INPUT
numpy dtype (numpy_dtype_for(dtype); fp8/bf8 = 1 byte). But the generated
grouped kernel's CDataType is fp16 for fp8/bf8 inputs (codegen_common
CommonTypeMappings.get_output_dtype maps fp8/bf8 -> fp16). The ctypes
copy-back therefore wrote 2*M*N bytes into an M*N-byte host buffer ->
munmap_chunk(): invalid pointer (heap corruption). fp16/bf16 were
unaffected because there CDataType == input dtype.

Mirror the #8887 regular-bridge output-dtype handling: add
output_dtype_for / output_numpy_dtype_for helpers (fp8/bf8 -> fp16, else
identity) and allocate C_h with the OUTPUT numpy dtype, computed once in
the runner constructor (self._c_np_dtype). A/B operands keep the input
codec. The verify worker reads result.outputs (now fp16) for its
non-zero check, so it stays dtype-correct. The np.float16 hardcoded in the
single-problem run() belongs to the regular path and is left untouched.
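
The size mismatch behind the crash can be reproduced with simple arithmetic. In the sketch below, output_dtype_for follows the behavior stated in the commit message (fp8/bf8 -> fp16, else identity); the item-size table and the two byte-count helper functions are hypothetical and only illustrate the before/after sizing.

```python
# Bytes per element for the dtypes the grouped bridge covers.
ITEMSIZE = {"fp16": 2, "bf16": 2, "fp8": 1, "bf8": 1}


def output_dtype_for(dtype: str) -> str:
    """fp8/bf8 kernels write fp16; other dtypes write their own type."""
    return "fp16" if dtype in ("fp8", "bf8") else dtype


def c_nbytes_buggy(dtype: str, m: int, n: int) -> int:
    """Old sizing: host C buffer allocated with the INPUT dtype."""
    return m * n * ITEMSIZE[dtype]


def c_nbytes_fixed(dtype: str, m: int, n: int) -> int:
    """New sizing: host C buffer allocated with the OUTPUT dtype."""
    return m * n * ITEMSIZE[output_dtype_for(dtype)]


m, n = 128, 256
for dtype in ("fp16", "fp8"):
    device_writes = m * n * ITEMSIZE[output_dtype_for(dtype)]
    print(dtype,
          "buggy host bytes:", c_nbytes_buggy(dtype, m, n),
          "fixed host bytes:", c_nbytes_fixed(dtype, m, n),
          "copied back:", device_writes)
```

For fp8 the old sizing yields half the bytes that the copy-back writes, while for fp16 both sizings agree, which matches the observation that fp16/bf16 were unaffected.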

Co-Authored-By: Claude <noreply@anthropic.com>
ozturkosu added a commit that referenced this pull request Jun 30, 2026
Grouped GEMM variant of the bridge, stacked on the fp8/bf8/int8 bridge
(#8887, itself on #8479). The grouped kernel is multi-problem, so it uses a
dedicated registry-bypass ctypes ABI; TE only generates configs and benchmarks.

- codegen: grouped variant + launch generator (arch_filter.py,
  unified_gemm_codegen.py), standalone 02_grouped_gemm_driver.cpp, README.
- bridge: grouped_gemm_ctypes_lib.cpp (multi-problem ABI), GpuGroupedGemmRunner
  + dtype/layout codecs (gemm_utils.py), TE driver/worker harness,
  GROUPED_GEMM_BRIDGE.md.

Coverage: {rcr,rrr,ccr,crr} x {fp16,bf16,fp8,bf8}; validated at Old-TE parity
on MI300X/gfx942 (64/64 correct, 64/64 within +/-15%).
@ozturkosu ozturkosu deleted the branch users/muozturk/ck-tile/gemm-bridge-all-layout-bf16-fp16 July 1, 2026 05:08
@ozturkosu ozturkosu closed this Jul 1, 2026
@ozturkosu

Contributor Author

Superseded by #8998, which is the same commit (e30be76) on a policy-compliant branch (users/muozturk/ck-tile/gemm-bridge-all-layout-fp8-bf8-int8) targeting develop. Please continue review there.
