[CK_TILE] gemm bridge: TE<->Dispatcher for fp8/bf8/int8 (all layouts)#8887
Closed
ozturkosu wants to merge 1 commit into
Closed
Conversation
…(all layouts) Adds the remaining data types Tile Engine's plain GEMM has MFMA warp tiles for beyond the fp16/bf16 surface of PR #8479: fp8 (E4M3) and bf8 (E5M2) accumulating into fp16, and int8 accumulating into int32 (gfx942). Covers all four A/B layout combinations per dtype (row-major C only, as ck_tile rejects column-major C). Codegen (codegen_common.py, unified_gemm_codegen.py): - add int32 to the CK / qualified / dispatcher dtype maps - get_output_dtype: int8 -> int32 (fp8/bf8 -> fp16 unchanged) - new get_acc_dtype: int8 -> int32, else fp32 - derive AccDataType, CDataType, the GEMM_KEY_DTYPE_{C,ACC} macros and the registry dtype_c/dtype_acc from the dtype instead of hard-coding float/fp32 Host harness (gemm_utils.py): - fp8/bf8 FNUZ (gfx942) uint8 codecs: exact decode (matches device fp8_t/bf8_t), nearest-representable saturating encode, mirroring the existing bf16 helper - GpuGemmRunner.run encodes A/B and sizes the C buffer per dtype (fp16 for fp8/bf8, int32 for int8) - expand_sweep sets dtype_c/dtype_acc from the input dtype Tests: - test_gemm_utils.py: fp8/bf8 codec round-trip, format ranges, NaN/zero slots, saturation, byte size; output-dtype mapping (CPU-only) - test_gemm_parity.py: fp8/bf8/int8 cases with dtype-aware inputs, references and tolerances (int8 exact); GPU-gated like the existing fp16/bf16 cases GPU parity validation deferred to a follow-up run on an MI300X node.
ozturkosu
added a commit
that referenced
this pull request
Jun 28, 2026
Extend the Tile-Engine -> Dispatcher Stream-K bridge (PR #8136) beyond fp16/bf16 to the FNUZ fp8 (E4M3) and bf8 (E5M2) formats used by gfx942/MI300. GpuGemmRunner (dispatcher/python/gemm_utils.py): - Port the tested FNUZ codecs from the sibling fp8 bridge (PR #8887): bit-exact decode tables + nearest-representable/saturating encode, carried as uint8 bit patterns (sizeof fp8_t/bf8_t == 1). Encode preserves operand C/F contiguity so the layout-generic _to_buf path holds for the new dtypes. - run() now sizes the C buffer per get_output_dtype: fp8/bf8 -> fp16 store, int8 -> int32; bf16 still carried as raw uint16. fp16/bf16 paths unchanged. - Arch guard: fp8/bf8 raise a clear error on a non-gfx942 GPU (gfx950/MI350 uses OCP fp8, a different bit layout) rather than silently mis-decoding. - An int8 codec is included for when the engine supports it (see below). Reference + surface: - run_one_streamk_gemm_kernel.py verify reference is now dtype-aware (decode(encode(x)) per dtype; int8 = exact int32 matmul). - streamk_gemm_full_benchmark.py SUPPORTED_DTYPES += fp8, bf8. int8 is intentionally left OUT of SUPPORTED_DTYPES: it is blocked at the ck_tile engine, not the bridge. The int8 kernel codegens but fails to compile for every reduction strategy -- warp_gemm_dispatcher has no Dispatcher<int8,int8,float,32,32,16,...> specialization for the streamk CompV3 path, so the BlockUniversalGemmAsBsCr WarpGemm static_asserts fail. Matches the PR #8094 decision to leave int8 out. GPU-validated on gfx942 (MI300X), 2048^3, both reduction + layout variants: fp8 atomic/linear/tree rcr: PASS (192/180/183 TFLOPS, max_rel <= 9.4e-4) bf8 atomic/linear/tree rcr: PASS (192/181/181 TFLOPS, max_rel <= 7.8e-4) fp8 ccr / bf8 crr (col-major): PASS (245/210 TFLOPS)
5 tasks
ozturkosu
added a commit
that referenced
this pull request
Jun 29, 2026
The grouped GpuGroupedGemmRunner sized the host C buffer with the INPUT numpy dtype (numpy_dtype_for(dtype); fp8/bf8 = 1 byte). But the generated grouped kernel's CDataType is fp16 for fp8/bf8 inputs (codegen_common CommonTypeMappings.get_output_dtype maps fp8/bf8 -> fp16). The ctypes copy-back therefore wrote 2*M*N bytes into an M*N-byte host buffer -> munmap_chunk(): invalid pointer (heap corruption). fp16/bf16 were unaffected because there CDataType == input dtype. Mirror the #8887 regular-bridge output-dtype handling: add output_dtype_for / output_numpy_dtype_for helpers (fp8/bf8 -> fp16, else identity) and allocate C_h with the OUTPUT numpy dtype, computed once in the runner constructor (self._c_np_dtype). A/B operands keep the input codec. The verify worker reads result.outputs (now fp16) for its non-zero check, so it stays dtype-correct. The single-problem run() hardcoded np.float16 is the regular path and is left untouched. Co-Authored-By: Claude <noreply@anthropic.com>
ozturkosu
added a commit
that referenced
this pull request
Jun 30, 2026
Grouped GEMM variant of the bridge, stacked on the fp8/bf8/int8 bridge (#8887, itself on #8479). The grouped kernel is multi-problem, so it uses a dedicated registry-bypass ctypes ABI; TE only generates configs and benchmarks. - codegen: grouped variant + launch generator (arch_filter.py, unified_gemm_codegen.py), standalone 02_grouped_gemm_driver.cpp, README. - bridge: grouped_gemm_ctypes_lib.cpp (multi-problem ABI), GpuGroupedGemmRunner + dtype/layout codecs (gemm_utils.py), TE driver/worker harness, GROUPED_GEMM_BRIDGE.md. Coverage: {rcr,rrr,ccr,crr} x {fp16,bf16,fp8,bf8}; validated at Old-TE parity on MI300X/gfx942 (64/64 correct, 64/64 within +/-15%).
3 tasks
Contributor
Author
|
Superseded by #8998, which is the same commit ( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Extends the Tile Engine ↔ Dispatcher GEMM bridge to the remaining data types TE's plain GEMM has MFMA warp tiles for, beyond the fp16/bf16 surface of #8479:
All four A/B layout combinations per dtype (row-major C only, matching #8479).
fp32/fp64are intentionally excluded — they appear in TE's dtype-string map but have no MFMA warp tiles inGEMM_WARP_TILE_SUPPORTED_COMBINATIONS, so no kernel can be generated/run.Stacked on #8479 (
muozturk/gemm-bridge-all-layouts-bf16), which carries the bridge infrastructure and is not yet merged. Please merge #8479 first, then this can rebase ontodevelop.Changes
codegen_common.py,unified_gemm_codegen.py): addint32to the dtype maps;get_output_dtypeint8→int32; newget_acc_dtype(int8→int32, else fp32); deriveAccDataType/CDataType, theGEMM_KEY_DTYPE_{C,ACC}macros, and the registrydtype_c/dtype_accfrom the dtype instead of hard-codingfloat/fp32.gemm_utils.py): fp8/bf8 FNUZ (gfx942) uint8 codecs — exact decode (matches devicefp8_t/bf8_t), nearest-representable saturating encode (same pattern as the existing bf16 helper);GpuGemmRunner.runencodes A/B and sizes the C buffer per dtype;expand_sweepsetsdtype_c/dtype_acc.test_gemm_utils.pyadds CPU-only fp8/bf8 codec + output-dtype tests (all green);test_gemm_parity.pyadds fp8/bf8/int8 cases with dtype-aware inputs/references/tolerances (int8 is bit-exact), GPU-gated like the existing cases.Verification done
test_gemm_utils.py+test_codegen_common.py: 54 passed (CPU).ADataType/CDataType/AccDataTypeandGEMM_KEY_*macros are correct (int8→int32_t acc/C; fp8→fp16_t C).test_gemm_parity.pycollects 60 cases and skips cleanly without a GPU.test_examples_integration/test_grouped_conv_codegen/test_library_cachingare pre-existing (verified identical on the base branch; they require a built dispatcher.a/ GPU).Test plan
develop.python3 tests/test_gemm_parity.pyand confirm fp8/bf8/int8 parity; tune the fp8/bf8 tolerances if needed (current values are first-cut headroom).🤖 Generated with Claude Code