feat(ck_tile): add stream_k variant to GEMM Dispatcher codegen #8094
ozturkosu wants to merge 21 commits
Add the stream-K GEMM variant to the unified GEMM dispatcher codegen the dispatcher way: a single-GEMM launch(args, stream) that allocates the reduction workspace internally via DeviceMem (GetWorkSpaceSize / SetWorkSpacePointer), zeroes it, and launches StreamKKernel with an atomic-reduction preprocess that resets C between timed iterations. No external workspace pointer (not the Tile Engine way).

- arch_filter.py: add OperatorType.GEMM_STREAMK + tile constraints.
- unified_gemm_codegen.py: add GemmVariant.STREAM_K, CLI --variants stream_k, naming, includes, _launch_function_streamk, variant->operator map, cshuffle-only config selection, and A/B/CLayout export in the CK_TILE_SINGLE_KERNEL_INCLUDE block.
- examples/gemm/cpp/03_streamk_gemm_driver.cpp: standalone single-kernel driver that calls SelectedKernel::launch and verifies vs reference_gemm.

Parity vs Tile Engine on MI300X (gfx942), fp16 rcr atomic 128x128x64_2x2x1_32x32x16, 3840x4096x2048, warmup=10/repeat=50: dispatcher 0.242 ms / 266 TFLOPS PASS vs TE 0.24 ms / 266 TFLOPS correct.
Pull request overview
This PR extends the Composable Kernel dispatcher’s unified GEMM codegen to generate a Stream-K GEMM variant (workspace owned internally via DeviceMem, launch(args, stream) API), and adds a minimal standalone C++ driver that builds/runs a single generated Stream-K kernel header.
Changes:
- Added `GEMM_STREAMK` operator type and tile constraints to the architecture filter.
- Added `stream_k` as a selectable variant in `unified_gemm_codegen.py`, including naming, includes, variant→operator mapping, config selection restrictions (cshuffle-only), and a Stream-K launch-path implementation.
- Added `03_streamk_gemm_driver.cpp` example for single-kernel include builds, benchmarking, and reference verification.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| projects/composablekernel/dispatcher/examples/gemm/cpp/03_streamk_gemm_driver.cpp | New minimal driver that includes a generated Stream-K header and runs/validates a single Stream-K GEMM kernel. |
| projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py | Adds Stream-K variant plumbing and a generated launcher that allocates/zeros workspace internally and launches via launch_kernel_time_mask. |
| projects/composablekernel/dispatcher/codegen/arch_filter.py | Introduces OperatorType.GEMM_STREAMK and associated tile constraints for arch filtering. |
    const ck_tile::index_t M = std::stoi(get_opt(argc, argv, "--m", "3840"));
    const ck_tile::index_t N = std::stoi(get_opt(argc, argv, "--n", "4096"));
    const ck_tile::index_t K = std::stoi(get_opt(argc, argv, "--k", "2048"));
    const int warmup = std::stoi(get_opt(argc, argv, "--warmup", "10"));
    const int repeat = std::stoi(get_opt(argc, argv, "--repeat", "50"));
Fixed in 23afd2c: M/N/K are now parsed with std::stoll (then narrowed to ck_tile::index_t) in both 03_streamk_gemm_driver.cpp and 04_streamk_registry_driver.cpp, so dimensions beyond the int range no longer make std::stoi throw std::out_of_range. (--warmup/--repeat stay std::stoi as they are genuinely int.)
    // Atomic reduction accumulates into C, so reset buffers before each run.
    auto reset_data_buffers = [&]() {{
        if constexpr (ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) {{
            (void)hipMemsetAsync(args.e_ptr, 0,
                                 args.M * args.N * sizeof(CDataType), stream.stream_id_);
        }} else {{
            workspace_dev.SetZero();
        }}
    }};
Resolved in e8120f5. The atomic C-reset in _launch_function_streamk (both the 2-arg internal and 3-arg external launch overloads) is now stride-aware: it uses hipMemset2DAsync with pitch stride_E * sizeof(CDataType), width N * sizeof(CDataType), height M instead of a flat M*N memset. A padded / non-contiguous row-major C is now zeroed correctly, and contiguous rcr is covered identically. Verified on MI300X/gfx942 (atomic verifies vs reference for the 128x128x64_2x2x1_32x32x16 kernel).
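The gap between the flat reset and the stride-aware reset can be modelled on the host. The sketch below is a pure-Python stand-in and not the HIP call: a `bytearray` plays the role of the C buffer, and `memset_flat`, `memset_2d`, and `stale_used_bytes` are hypothetical helpers that mimic the pitch/width/height semantics described above on a small padded row-major matrix.

```python
def memset_flat(buf: bytearray, nbytes: int) -> None:
    """Zero the first nbytes bytes, like a flat memset over M*N elements."""
    buf[:nbytes] = bytes(nbytes)


def memset_2d(buf: bytearray, pitch: int, width: int, height: int) -> None:
    """Zero `width` bytes at the start of each of `height` rows spaced `pitch` bytes apart."""
    for row in range(height):
        start = row * pitch
        buf[start:start + width] = bytes(width)


# Assumed toy shape: row-major C, M=3 rows, N=4 columns,
# padded leading dimension stride_E=6, 2-byte elements.
M, N, STRIDE_E, ELEM = 3, 4, 6, 2
PITCH = STRIDE_E * ELEM
WIDTH = N * ELEM


def stale_used_bytes(buf: bytearray) -> int:
    """Count non-zero bytes inside the used MxN region (padding is ignored)."""
    return sum(
        1
        for row in range(M)
        for b in buf[row * PITCH:row * PITCH + WIDTH]
        if b
    )


flat = bytearray(b"\xff" * (M * PITCH))
memset_flat(flat, M * N * ELEM)              # flat reset over M*N elements

strided = bytearray(b"\xff" * (M * PITCH))
memset_2d(strided, pitch=PITCH, width=WIDTH, height=M)  # stride-aware reset

print("stale bytes after flat reset:   ", stale_used_bytes(flat))
print("stale bytes after strided reset:", stale_used_bytes(strided))
```

In this toy case the flat reset stops after 24 bytes and never reaches the last row, while the 2D reset clears every used element and leaves the padding untouched.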
Two Copilot findings on the Stream-K codegen carried in from #8094:
- 03_streamk_gemm_driver.cpp: parse M/N/K with std::stoll (not std::stoi) before narrowing to ck_tile::index_t; stoi throws std::out_of_range past INT_MAX, needlessly rejecting large GEMM sizes.
- unified_gemm_codegen.py (_launch_function_streamk): the Atomic reduction's per-iteration C reset zeroed args.M*args.N as a flat contiguous block, which skips elements when C has a padded leading dimension and corrupts the accumulation. Zero the used MxN region honoring stride_E via hipMemset2DAsync (CLayout-aware row/col-major), and check the HIP status instead of discarding it.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Superseded by #8136; keeping open until that lands
This PR's stream-K codegen (commit … #8136 is a superset: it contains this codegen and adds the full Python bridge that this PR doesn't have:
It then evolves the codegen further (selectable atomic/linear/tree reduction strategy + Copilot fixes; +102/−18 on …
Why keep this open for now: #8136 is stacked on #8123, which isn't in …
…rategy fields
First slice of moving Stream-K into the dispatcher core (registry-addressable),
per the deep-core checklist. Additive and inert by default:
- KernelKey: new ReductionStrategy enum {None,Atomic,Linear,Tree}; Algorithm
gains streamk / reduction_strategy / workspace. tie() includes them so the
three strategies are distinct keys. encode_identifier() appends the Stream-K
suffix ("_streamk" / "_streamk_linear" / "_streamk_tree") byte-for-byte with
unified_gemm_codegen.py KernelNaming.generate(), guarded by algorithm.streamk
so non-Stream-K identifiers are unchanged.
- Problem: streamk / reduction_strategy request fields + ProblemBuilder::stream_k().
Validated on gfx90a (hipcc 7.12): non-SK encode_identifier byte-identical;
atomic/linear/tree suffixes correct; tie() distinguishes strategies.
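A minimal sketch of the suffix scheme described in this commit message, written in Python rather than the C++ of `KernelKey`. The function names mirror the ones above, but the signatures, the base identifier `gemm_fp16_rcr`, and the choice to raise on a Stream-K key without a strategy are assumptions made only for illustration.

```python
from enum import Enum


class ReductionStrategy(Enum):
    NONE = "none"
    ATOMIC = "atomic"
    LINEAR = "linear"
    TREE = "tree"


# Suffixes as quoted in the commit message.
_STREAMK_SUFFIX = {
    ReductionStrategy.ATOMIC: "_streamk",
    ReductionStrategy.LINEAR: "_streamk_linear",
    ReductionStrategy.TREE: "_streamk_tree",
}


def encode_identifier(base: str,
                      streamk: bool = False,
                      strategy: ReductionStrategy = ReductionStrategy.NONE) -> str:
    """Append the Stream-K suffix only when the key is a Stream-K key."""
    if not streamk:
        return base  # guarded: non-Stream-K identifiers stay unchanged
    if strategy not in _STREAMK_SUFFIX:
        # Assumption of this sketch: a Stream-K key must name a strategy.
        raise ValueError("Stream-K key needs atomic, linear, or tree")
    return base + _STREAMK_SUFFIX[strategy]


BASE = "gemm_fp16_rcr"  # hypothetical base identifier
names = {s: encode_identifier(BASE, True, s) for s in _STREAMK_SUFFIX}
print(encode_identifier(BASE))
print(sorted(names.values()))
```

Because the three strategies map to three different strings, they can sit side by side in one registry, which is the property the `tie()` change protects on the key side.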
Deep-core Stream-K integration: progress (PR-A) ✅
Beginning the deep-core integration (Stream-K as a first-class registry citizen) on this branch, per the design in the MLSE Confluence page Stream-K Deep-Core Integration & Bridge Convergence. Each commit updates the checklist below.
Commit …
Deep-core checklist
C++ Dispatcher Core
C++ Backend
Codegen / registration
Functional (re-homed from bridge into the core)
Validation: gfx90a (MI210, hipcc 7.12), host compile + run: non-SK …
Add two non-pure virtuals so existing GEMM/FMHA/Conv instances compile unchanged:
- get_workspace_size(Problem) -> bytes (default 0)
- run(a,b,c,d_ptrs, void* workspace, problem, stream) overload whose default forwards to the existing no-workspace run().

The Dispatcher invokes these through a base KernelInstance* pointer (so the new overload is visible despite derived 6-arg run() overrides). The Stream-K backend (PR-C) overrides both to size and bind the reduction workspace.

Validated on gfx90a (hipcc 7.12): a concrete instance overriding only the pre-existing pure virtuals compiles; default get_workspace_size==0 and the workspace-run forwards correctly via base pointer.
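The additive-default pattern above can be sketched in Python, with the caveat that Python has no overloading: the workspace-aware `run` overload is modelled here as a separately named method, `run_with_workspace`, which is an assumption of this sketch. `LegacyGemm` and `StreamKGemm` are hypothetical instances, not classes from the PR.

```python
class KernelInstance:
    """Toy base class: run() is required, the two new hooks have defaults."""

    def run(self, a, b, c, d_ptrs, problem, stream):
        raise NotImplementedError

    def get_workspace_size(self, problem) -> int:
        return 0  # default: no reduction workspace needed

    def run_with_workspace(self, a, b, c, d_ptrs, workspace, problem, stream):
        # Default forwards to the existing no-workspace run().
        return self.run(a, b, c, d_ptrs, problem, stream)


class LegacyGemm(KernelInstance):
    """Overrides only the pre-existing entry point and keeps working."""

    def run(self, a, b, c, d_ptrs, problem, stream):
        return "legacy run"


class StreamKGemm(KernelInstance):
    """Overrides both hooks to size and consume a workspace."""

    def run(self, a, b, c, d_ptrs, problem, stream):
        return self.run_with_workspace(a, b, c, d_ptrs, None, problem, stream)

    def get_workspace_size(self, problem) -> int:
        return problem["workspace_bytes"]  # hypothetical problem field

    def run_with_workspace(self, a, b, c, d_ptrs, workspace, problem, stream):
        return "streamk run, workspace=" + ("external" if workspace else "internal")


legacy, sk = LegacyGemm(), StreamKGemm()
print(legacy.run_with_workspace(1, 2, 3, [], b"ws", {}, None))
print(sk.run_with_workspace(1, 2, 3, [], b"ws", {"workspace_bytes": 64}, None))
```

The caller only ever talks to the base interface, so existing instances need no edits while the Stream-K instance opts in.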
Deep-core Stream-K integration: progress (PR-B) ✅
Commit …
Deep-core checklist
C++ Dispatcher Core
C++ Backend
Codegen / registration
Functional (re-homed from bridge into the core)
Validation: gfx90a (hipcc 7.12), a concrete …
Next: PR-C, the heavy one: new …
…tree codegen

Adds the C++ backend that lets Stream-K ride the registry, plus the reduction strategy codegen needed to generate the three variants on this branch.
- generated_tile_backend_streamk.hpp (NEW): GeneratedStreamKKernelInstance wraps a generated Stream-K kernel and builds ck_tile::StreamKHostArgs (the ABI-incompatible args the GemmHostArgs path could not). supports() gates on Problem.streamk + reduction_strategy so atomic/linear/tree coexist in the registry and the Dispatcher's first-fit selection picks the requested one. create_generated_streamk_kernel<> mirrors create_generated_tile_kernel<>.
- codegen: reduction_strategy axis (atomic/linear/tree) -> KernelConfig field, key_name redux_*, KernelNaming "_streamk"/"_streamk_linear"/"_streamk_tree" (matches KernelKey::encode_identifier from PR-A), per-strategy StreamKReductionStrategy in the generated launch, and a streamk_config sweep axis. (Ported from the bridge branch reduction-strategy work.)

PR-C keeps the generated launch's internal workspace/reset; PR-D relocates those to Dispatcher::run() via get_workspace_size()/the workspace-aware run().

Validated on gfx90a (hipcc 7.12): codegen emits 584 atomic + 584 linear + 584 tree headers with correct names; the backend device-compiles (22s) against a generated header and supports() accepts the matching strategy while rejecting the others and non-Stream-K problems.
Deep-core Stream-K integration: progress (PR-C) ✅
Commit …
Deep-core checklist
C++ Dispatcher Core
C++ Backend
Codegen / registration
Functional (re-homed from bridge into the core)
Validation (gfx90a / MI210, hipcc 7.12):
Next: PR-D, wire registration (ctypes init builds the SK …
…pace

Relocate the Stream-K reduction-workspace buffer from the per-call generated launch() to a grow-on-demand buffer owned by the Dispatcher, so a long-lived dispatcher stops paying a hipMalloc/hipFree on every invocation.
- codegen: hoist the StreamKGemmKernel type to struct scope and add GetWorkSpaceSize() + an external-workspace launch(args, cfg, workspace) overload. The existing 2-arg launch (internal DeviceMem) is unchanged so the bridge ctypes lib and the standalone 03 driver keep working.
- backend: override get_workspace_size() and the workspace-aware run(); the no-workspace run() delegates with a null buffer. The per-iteration reset stays in the backend (it needs CDataType + the reduction strategy).
- dispatcher: own a grow-on-demand workspace (raw void*/size_t to keep HIP out of the public header), size it via get_workspace_size(), and pass it through run_fused()/run_explicit(); free it in the destructor. Atomic needs none (size 0 -> null -> internal path); linear/tree consume the owned buffer.

Validated on MI210/gfx90a: atomic/linear/tree all verify vs reference_gemm at unchanged perf, with linear/tree now running on the dispatcher-owned workspace.
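The grow-on-demand ownership described above can be sketched as follows. This is a host-only Python model under stated assumptions: `WorkspaceOwner` and its `allocations` counter are hypothetical, and a `bytearray` stands in for the device allocation that the real Dispatcher would obtain through HIP.

```python
from typing import Optional


class WorkspaceOwner:
    """Toy owner of one reusable buffer that only ever grows."""

    def __init__(self) -> None:
        self._buf: Optional[bytearray] = None
        self._capacity = 0
        self.allocations = 0  # counts how often a (re)allocation was paid for

    def ensure(self, nbytes: int) -> Optional[bytearray]:
        # Size 0 means "no workspace": hand back nothing (the atomic case above).
        if nbytes == 0:
            return None
        # Reallocate only when the request exceeds what is already owned.
        if nbytes > self._capacity:
            self._buf = bytearray(nbytes)
            self._capacity = nbytes
            self.allocations += 1
        return self._buf


owner = WorkspaceOwner()
for request in (0, 1024, 512, 1024, 2048, 256):
    owner.ensure(request)
print("allocations:", owner.allocations)
```

Six dispatches cost two allocations here, which is the saving the commit is after for a long-lived dispatcher.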
…river
Add 04_streamk_registry_driver.cpp: a runnable proof of the full deep-core path
(Registry::register_kernel -> Dispatcher::run -> first-fit supports() gate on
reduction_strategy -> GeneratedStreamKKernelInstance::run -> generated launch ->
verify vs reference_gemm). Unlike 03_streamk_gemm_driver.cpp, which calls
SelectedKernel::launch() directly and bypasses the dispatcher, this exercises the
registry selection and the Dispatcher-owned workspace.
Selectable strategy via --strategy {atomic,linear,tree}. Validated on
MI210/gfx90a for all three (distinct registry identifiers, each PASS).
Pushed PR-D + PR-E of the Stream-K deep-core integration (head
Validated on MI210/gfx90a @ 3840×4096×2048, all PASS: atomic 79.2, linear 86.4, tree 86.6 TFLOPS; linear/tree now run on the dispatcher-owned workspace at unchanged perf. gfx942/MI300X re-validation still pending.
…K backend

The dispatcher-wrapper generator emitted ONE template for every variant: backends::GeneratedKernelInstance<KernelStruct> with no streamk/reduction_strategy on the key. For Stream-K that is wrong twice over -- the regular backend calls launch(GemmHostArgs,...) which the SK kernel struct does not have (so the aggregate register_all_kernels.hpp would not compile against SK), and the key omits the SK fields so encode_identifier() emits no _streamk suffix and atomic/linear/tree collide in the registry.

Make the wrapper variant-aware: for STREAM_K configs include generated_tile_backend_streamk.hpp, set key.algorithm.streamk + reduction_strategy + workspace (and pad flags for identifier parity), and return create_generated_streamk_kernel<KernelStruct, KernelStruct::ADataType, ...>. All other variants are unchanged.

Validated on MI210/gfx90a: a registry populated via the generated wrappers holds atomic+linear+tree side by side; Dispatcher::run() selects each by Problem::reduction_strategy and all three verify vs reference_gemm.
…are atomic reset

P2: GeneratedStreamKKernelInstance::supports() now ends with SelectedKernel::IsSupported(make_args(problem)) (a new generated static that runs MakeKernelArgs + IsSupportedArgument). A problem too small to partition across CUs is rejected during selection, so first-fit falls back to a non-Stream-K kernel instead of throwing std::runtime_error at launch.

P3: the atomic reduction reset zeroes C with a stride-aware hipMemset2DAsync (pitch = stride_E * sizeof(C), width = N * sizeof(C), height = M) instead of a flat hipMemsetAsync over M*N. Correct for a padded/strided C; identical coverage for the contiguous rcr case. Applied to both the internal and external-workspace launch overloads.

Validated on MI210/gfx90a: atomic/linear/tree still select + run + verify from a multi-kernel registry; valid small problems are accepted (no false-negatives).
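The selection behaviour in P2 can be illustrated with a small first-fit model. Everything here is a sketch under assumptions: `Problem`, `Kernel`, and `select_first_fit` are simplified stand-ins for the dispatcher types, `is_supported_argument` uses a placeholder size rule rather than the real ck_tile check, and the fallback kernel is assumed to accept any problem.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(frozen=True)
class Problem:
    M: int
    N: int
    K: int
    streamk: bool = False
    reduction_strategy: str = "none"


@dataclass(frozen=True)
class Kernel:
    name: str
    supports: Callable[[Problem], bool]


def is_supported_argument(p: Problem) -> bool:
    # Placeholder feasibility rule, NOT the ck_tile check.
    return p.M >= 128 and p.N >= 128


def streamk_kernel(strategy: str) -> Kernel:
    def supports(p: Problem) -> bool:
        # Gate on the request fields first, then on argument feasibility.
        return (p.streamk
                and p.reduction_strategy == strategy
                and is_supported_argument(p))
    return Kernel(f"streamk_{strategy}", supports)


def select_first_fit(kernels: List[Kernel], problem: Problem) -> Optional[Kernel]:
    """Return the first registered kernel whose supports() accepts the problem."""
    for kernel in kernels:
        if kernel.supports(problem):
            return kernel
    return None


REGISTRY = [
    streamk_kernel("atomic"),
    streamk_kernel("linear"),
    streamk_kernel("tree"),
    Kernel("plain_gemm", lambda p: True),  # assumed catch-all fallback
]

big = Problem(3840, 4096, 2048, streamk=True, reduction_strategy="linear")
tiny = Problem(64, 64, 64, streamk=True, reduction_strategy="atomic")
print(select_first_fit(REGISTRY, big).name)
print(select_first_fit(REGISTRY, tiny).name)
```

Moving the feasibility check into `supports()` means the rejection happens while other candidates are still available, instead of surfacing as an exception once a kernel has already been chosen.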
b74ada4 to e8120f5
…eductionStrategy)

Close two review nits on the Stream-K drivers:
- Parse M/N/K with std::stoll instead of std::stoi in the 03/04 drivers so dimensions beyond the int range no longer throw std::out_of_range (Copilot nit).
- Add inline to_string(ReductionStrategy) in kernel_key.hpp and route the 04 driver through it, removing the driver-local strategy_name() duplicate so callers share one spelling that matches the codegen suffix scheme.
Adds dispatcher_test_streamk_registry, a GPU test that generates the three reduction-strategy kernels (atomic/linear/tree) from one tile config, builds the 04 registry driver once per strategy (each force-including its own header, since SkReductionStrategy is a compile-time constexpr), and asserts for each that the encode_identifier() suffix matches, the Dispatcher selects it by Problem::reduction_strategy, and the result verifies against the reference. This converts the previously manual deep-core validation into a regression-guarded CTest. It SKIPs (return 77) when no GPU or hipcc is present, so CPU-only CI is unaffected.
| strategy | identifier | TFLOPS | verify |
|---|---|---|---|
| atomic | …_streamk | 265.6 | PASS |
| linear | …_streamk_linear | 306.6 | PASS |
| tree | …_streamk_tree | 306.4 | PASS |
Too-small problem (64³) is rejected at selection (supports() → false via IsSupportedArgument), so first-fit falls back gracefully instead of throwing at launch — confirmed on gfx942.
Note: in the earlier MI210 numbers, and in any run that force-includes a single header and passes `--strategy linear/tree`, the flag only relabels the registry key; the compiled kernel stays whatever the included header baked in. The new test (and the table above) build each strategy from its own header, so this is a genuine per-strategy validation.
… driver

The standalone stream-K driver verified atomic results with the single-pass GEMM tolerance get_*_threshold<...>(K). Atomic reduction accumulates K-split partials directly into low-precision C (workspace size 0), incurring rounding error that grows with the split factor -- correct results were flagged FAIL on small-M/N, large-K shapes (e.g. 512x512x8192) where tiles < CUs.

Mirror tile_engine's calculate_rtol_atol (validation.hpp): derive kbatch from the kernel's tile partitioner (estimate_num_wgs_per_tile), widen atol/rtol with the split-K CDataType accumulation term, and take the max with the per-split tolerance. The driver and tile_engine now verify identically; the kernel is unchanged.
…gine
The standalone stream-K driver built its stream_config as {stream, true, 0,
warmup, repeat}, leaving is_gpu_timer/flush_cache/rotating_count at defaults
(flush_cache=false, rotating_count=1). The tile_engine benchmark instead times
with flush_cache=true and rotating_count=1000, so the driver measured a
warm-cache best case while tile_engine measured cold-cache -- the entire source
of the reported dispatcher-vs-TE "performance gap" at low tile counts.
Add --timer/--flush_cache/--rotating_count (defaulting to the tile_engine
values) and pass them through stream_config so both sides use identical timing
methodology. A validating run still times a single cold shot, mirroring
tile_engine's repeat_once_if_verify(); collect perf with a separate --validate 0
pass.
The 04 registry driver hardcoded the KernelKey signature to DataType::FP16 and an rcr layout, so fp8/bf8/bf16 Stream-K kernels registered under the wrong key and failed dispatch/identifier checks. Derive dtype_a/b/c/acc and layout tags from the generated kernel's actual A/B/C types via compile-time dtype_enum_of<T>()/layout_tag_of<Layout>() helpers (fp8/bf8 inputs accumulate in fp32 and write fp16 C, matching Tile Engine). Parametrize test_streamk_registry.py over fp16/bf16/fp8/bf8 (dtype-independent core objects built once; per-dtype codegen + build + verify with per-dtype identifier assertions). All four datatypes register, dispatch, and verify across atomic/linear/tree on gfx942 (MI300X).
Parametrize the registry test over all four layouts Tile Engine builds
Stream-K for (rcr/rrr/ccr/crr) in addition to datatypes and reduction
strategies. Full coverage is now {fp16,bf16,fp8,bf8} x {rcr,rrr,ccr,crr}
x {atomic,linear,tree}; all four layouts keep C row-major, which the
atomic C-reset relies on. The encode_identifier assertion is generalized
to {dtype}_{layout}, and --datatypes/--layouts flags allow trimming the
matrix for faster CI runs.
make_args() hardcoded rcr leading dims (stride_a=K, stride_b=K, stride_c=N)
for every layout, so non-rcr Stream-K kernels (rrr/ccr/crr) ran with wrong
strides and failed verification. Derive the leading dims from the kernel
key's layouts instead: A is MxK (row->K, col->M), B is KxN (row->N, col->K),
C is MxN (row->N, col->M). rcr is unchanged; rrr/ccr/crr now match the host
tensor strides the driver builds via get_default_stride.
Verified on gfx942/MI300X: full {fp16,bf16,fp8,bf8} x {rcr,rrr,ccr,crr} x
{atomic,linear,tree} matrix (48/48) registers, dispatches, and verifies.
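The layout-to-stride rule in this commit message is small enough to write out. The sketch below is a Python restatement of that rule only; `leading_dims` is a hypothetical helper name, and the three-letter layout string is assumed to use `r` for row-major and `c` for column-major in A, B, C order.

```python
from typing import Tuple


def leading_dims(layout: str, M: int, N: int, K: int) -> Tuple[int, int, int]:
    """Return (stride_a, stride_b, stride_c) for a layout tag such as 'rcr'."""
    if len(layout) != 3 or any(ch not in "rc" for ch in layout):
        raise ValueError(f"unexpected layout tag: {layout!r}")
    a, b, c = layout
    stride_a = K if a == "r" else M  # A is MxK: row -> K, col -> M
    stride_b = N if b == "r" else K  # B is KxN: row -> N, col -> K
    stride_c = N if c == "r" else M  # C is MxN: row -> N, col -> M
    return stride_a, stride_b, stride_c


M, N, K = 3840, 4096, 2048
for tag in ("rcr", "rrr", "ccr", "crr"):
    print(tag, leading_dims(tag, M, N, K))
```

For `rcr` this reproduces the previously hardcoded `stride_a=K, stride_b=K, stride_c=N`, which is why that layout was unaffected by the bug.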
Port #8136's Tile-Engine->Dispatcher Stream-K bridge onto the rewritten deep-core #8094 engine (KernelKey reduction fields, KernelInstance workspace virtuals, StreamK backend, Dispatcher-owned reduction workspace, registry + validation driver).

3-way merge over the shared stream_k ancestor; only the streamk launch emitter in unified_gemm_codegen.py and 03_streamk_gemm_driver.cpp conflicted -- both resolved to the deep-core side:
- codegen now emits the struct-scope Sk* kernel type + GetWorkSpaceSize + IsSupported, keeps the 2-arg internal-workspace launch the bridge ctypes lib calls, and adds the 3-arg dispatcher-owned-workspace launch.
- driver takes deep-core's stoll parse + apple-to-apple timing + validate cold shot.

Bridge ctypes lib still bypasses the registry and calls the 2-arg launch directly, so the bridge runs the exact deep-core kernels. Codegen smoke: atomic/linear/tree + regular gemm all generate cleanly (0 failed).
Extend the Tile-Engine -> Dispatcher Stream-K bridge (PR #8136) beyond fp16/bf16 to the FNUZ fp8 (E4M3) and bf8 (E5M2) formats used by gfx942/MI300.

GpuGemmRunner (dispatcher/python/gemm_utils.py):
- Port the tested FNUZ codecs from the sibling fp8 bridge (PR #8887): bit-exact decode tables + nearest-representable/saturating encode, carried as uint8 bit patterns (sizeof fp8_t/bf8_t == 1). Encode preserves operand C/F contiguity so the layout-generic _to_buf path holds for the new dtypes.
- run() now sizes the C buffer per get_output_dtype: fp8/bf8 -> fp16 store, int8 -> int32; bf16 still carried as raw uint16. fp16/bf16 paths unchanged.
- Arch guard: fp8/bf8 raise a clear error on a non-gfx942 GPU (gfx950/MI350 uses OCP fp8, a different bit layout) rather than silently mis-decoding.
- An int8 codec is included for when the engine supports it (see below).

Reference + surface:
- run_one_streamk_gemm_kernel.py verify reference is now dtype-aware (decode(encode(x)) per dtype; int8 = exact int32 matmul).
- streamk_gemm_full_benchmark.py SUPPORTED_DTYPES += fp8, bf8.

int8 is intentionally left OUT of SUPPORTED_DTYPES: it is blocked at the ck_tile engine, not the bridge. The int8 kernel codegens but fails to compile for every reduction strategy -- warp_gemm_dispatcher has no Dispatcher<int8,int8,float,32,32,16,...> specialization for the streamk CompV3 path, so the BlockUniversalGemmAsBsCr WarpGemm static_asserts fail. Matches the PR #8094 decision to leave int8 out.

GPU-validated on gfx942 (MI300X), 2048^3, both reduction + layout variants:
- fp8 atomic/linear/tree rcr: PASS (192/180/183 TFLOPS, max_rel <= 9.4e-4)
- bf8 atomic/linear/tree rcr: PASS (192/181/181 TFLOPS, max_rel <= 7.8e-4)
- fp8 ccr / bf8 crr (col-major): PASS (245/210 TFLOPS)
Stream-K dispatcher ⇄ Old-TE equivalence: validation summary (MI300X + MI350)
Validated the Stream-K dispatcher against legacy Tile Engine (Old-TE) across datatypes {fp16, bf16, fp8, bf8} × layouts {rcr, rrr, ccr, crr} × reduction {atomic, linear, tree} on MI300X (gfx942) and MI350 (gfx950).
Equivalence verdict: ✅ 100% equivalent to Old-TE
There is no case where Old-TE works but the dispatcher fails.
MI300X (gfx942): fully passing
MI350 (gfx950): fp16/bf16 pass; fp8/bf8 fail (pre-existing shared bug)
Code in this PR
Full report + CSV: MLSE Confluence (PR #8094 Stream-K equivalence page).
…wing, timing, coverage)

- 04 registry driver: port the split-K-aware verification tolerance from the 03 driver (a0ff521) so the deep-core path uses kbatch-derived rtol/atol instead of the plain single-pass threshold that spuriously FAILs correct atomic results on small-M/N, large-K shapes.
- Stream-K backend make_args: guard the int64->int32 (index_t) narrowing of M/N/K and derived leading dims; throw on overflow instead of silently wrapping (the parser was widened to std::stoll for exactly this reason).
- Stream-K backend run(): flush the L2 between timed iterations so the measurement is cold like tile_engine/the 03 driver (warm cache over-reported TFlops); document that this path is cold-but-non-rotated and not the calibrated apple-to-apple perf surface.
- codegen: document the Stream-K tile-coverage limitation (dispatcher emits a narrower tile surface than TE; equivalence is per matched tile config).

Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch, and verify PASS.
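To make the narrowing hazard concrete, the sketch below contrasts a silent 32-bit wrap with a checked conversion. It is a Python model of the idea only: `wrap_int32` emulates what an unchecked two's-complement narrowing would yield, and `to_index_t` is a hypothetical guard that raises `OverflowError` where the C++ backend is described as throwing.

```python
INT32_MIN = -(2 ** 31)
INT32_MAX = 2 ** 31 - 1


def wrap_int32(value: int) -> int:
    """Emulate an unchecked int64 -> int32 narrowing (two's-complement wrap)."""
    return (value - INT32_MIN) % (2 ** 32) + INT32_MIN


def to_index_t(value: int, name: str) -> int:
    """Checked narrowing: refuse values that do not fit in 32 bits."""
    if not INT32_MIN <= value <= INT32_MAX:
        raise OverflowError(f"{name}={value} does not fit in a 32-bit index")
    return value


print(wrap_int32(2 ** 31))        # silently becomes a negative dimension
print(to_index_t(4096, "N"))      # in-range values pass through unchanged
try:
    to_index_t(2 ** 31, "M")
except OverflowError as exc:
    print("rejected:", exc)
```

The wrapped value is the dangerous one because it still looks like a valid integer to every later stride computation.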
Addressed review majors (commit ae7b571)
Pushed fixes for the four "major" findings from the deep-core review.
M1: …
M4: silent int64→int32 narrowing in …
M2: registry …
Validation on MI300X (gfx942)
Run via …
Still open from the review (not in this commit): rebase onto develop (B1), gate/track gfx950 fp8/bf8 (B2), and the redundancy decision vs #8136 (B3).
…tadata, arch-filter, test)
- Backend validate(): stop returning a blind "true". Without a host reference it
cannot do a numeric check, but it now validates what it can -- non-null
operands, a well-formed problem, and that this instance supports() it -- so an
unrunnable config is not mis-reported as valid.
- 04 registry driver make_streamk_key(): derive wave_shape/block_size/
transpose_c/double_buffer/persistent/preshuffle/num_wave_groups from the
generated kernel's own static traits instead of hardcoding (the wave_shape was
wrong, e.g. {2,2,1} vs the actual {1,4,1}); the registry identifier now
describes the kernel that was built. pipeline/scheduler stay fixed (baked into
the kernel type, not on the SK selection axis) -- now documented.
- arch_filter GEMM_STREAMK: document that the min_tile constraints are copied
from plain GEMM and the real feasibility gate (enough tiles to partition K
across CUs) is runtime IsSupportedArgument/supports(), not these numbers.
- test_streamk_registry.py: replace the fragile rocminfo line-splitting in
detect_arch with a regex; verify each built kernel against the CLI shape AND a
small-M/large-K (128x128x16384) shape that stresses the split-K tolerance.
Compile-validated for gfx942; small-M/large-K shape verified PASS on MI300X.
Addressed review minors (commit 9470ae9)
Compile-validated for gfx942; the small-M/large-K shape verifies PASS on MI300X.
Document the Stream-K deep-core path (the feature this branch adds): generate kernels via unified_gemm_codegen --variants stream_k, build/run the 03 standalone driver (perf surface) and the 04 registry driver (Registry -> Dispatcher -> verify), run the CTest, plus the reduction strategies, supported dtypes/layouts, the cold-cache perf methodology, the split-K-aware tolerance, and known limitations (gfx950 fp8/bf8, tile-coverage gap). Linked from the dispatcher README.
…, thread-safety, row-major C assert)

Three correctness fixes from a tough review pass on the deep-core path:

1. Zero the reduction workspace in Dispatcher::ensure_workspace() before every Linear/Tree dispatch. The dispatcher previously only hipMalloc'd the buffer and relied on the backend's per-iteration preprocess reset to zero it, so on the non-benchmarking (nrepeat=1) path -- or with set_benchmarking(false) -- the buffer could be garbage and corrupt the reduction. Now correctness is independent of whether the preprocess runs, mirroring the internal DeviceMem::SetZero() the standalone launch already does.
2. Serialize the shared workspace_ buffer with a mutex. workspace_ is mutable state reused across const run() calls; two concurrent linear/tree dispatches on different streams could share one buffer and corrupt each other. The lock spans size->zero->launch (the buffer is in use for the whole run). Atomic / non-Stream-K paths use no workspace and take no lock.
3. Emit a static_assert(is_row_major<CLayout>) in the generated Atomic launch. The atomic C-reset uses a row-major hipMemset2D; a column-major C would be zeroed incorrectly and silently corrupt atomic accumulation. The assert is atomic-only (Linear/Tree zero the workspace, not C).

Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch, and verify PASS across 3840x4096x2048 and 128x128x16384.
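Fixes 1 and 2 combine into one pattern: hold a lock across size, zero, and launch. The following is a host-only Python sketch of that pattern, not the C++ Dispatcher; `WorkspaceDispatcher` and `dirty_launch` are hypothetical, a `bytearray` stands in for device memory, and `threading.Lock` stands in for the mutex.

```python
import threading


class WorkspaceDispatcher:
    """Toy dispatcher that owns one shared reduction workspace."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._workspace = bytearray()

    def run(self, workspace_bytes: int, launch):
        # Paths that need no workspace take no lock.
        if workspace_bytes == 0:
            return launch(None)
        # The lock spans size -> zero -> launch: the buffer is busy for the whole run.
        with self._lock:
            if len(self._workspace) < workspace_bytes:
                self._workspace = bytearray(workspace_bytes)  # grow on demand
            view = memoryview(self._workspace)[:workspace_bytes]
            view[:] = bytes(workspace_bytes)  # zero before every dispatch
            return launch(view)


def dirty_launch(workspace) -> bool:
    """Hypothetical kernel: report whether the buffer arrived clean, then dirty it."""
    clean = not any(workspace)
    workspace[:] = b"\xff" * len(workspace)
    return clean


dispatcher = WorkspaceDispatcher()
results = []


def worker() -> None:
    for _ in range(50):
        results.append(dispatcher.run(256, dirty_launch))


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("all dispatches saw a zeroed buffer:", all(results))
```

Every launch leaves the buffer dirty, so a clean buffer on entry shows both that the zeroing ran and that no other dispatch touched the buffer in between.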
…t preprocess

The atomic C-reset (hipMemset2DAsync) and the linear/tree workspace reset (hipMemsetAsync) in the generated launch() were (void)-cast, so a failed reset silently corrupted atomic accumulation / the reduction buffer with no signal. Both now throw std::runtime_error on non-hipSuccess, consistent with the IsSupportedArgument failure path in the same function. Applied to both the internal-workspace and external-workspace launch overloads.

Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree compile and verify PASS.
❌ PR Check: Action Required
🚫 Please fix the failed policies before requesting reviews. The following policy checks failed: …
Superseded by #8985. This PR auto-closed when its branch was renamed to the policy-compliant path …
Add stream-K variant to the GEMM Dispatcher codegen (the dispatcher way)
This is the next slice of the Tile Engine → Dispatcher consolidation, following the same pattern as the grouped_gemm PR (#8075). It adds the stream-K GEMM variant to the unified GEMM codegen, implemented the dispatcher way (workspace owned internally via `DeviceMem`, clean `launch(args, stream)` signature), and proves numeric + performance parity against Tile Engine.

Branch is based on `develop` and contains only the stream-K work (no grouped_gemm commits).

What I did
- `codegen/arch_filter.py`: added `OperatorType.GEMM_STREAMK` and its tile constraints.
- `codegen/unified_gemm_codegen.py`:
  - Added `GemmVariant.STREAM_K`, made it reachable from the CLI (`--variants stream_k`), wired naming (`_streamk` suffix), includes, and the variant→operator map.
  - `_launch_function_streamk`: builds a single `StreamKHostArgs`, then `MakeKernelArgs` → `GetWorkSpaceSize` → allocate `DeviceMem` workspace internally + `SetZero` → `SetWorkSpacePointer` → `IsSupportedArgument` check → `make_kernel` via `launch_kernel_time_mask` with an Atomic-reduction preprocess that zeros C between timed iterations. No external `kargs_ptr` (not the Tile Engine way).
  - Exports `A/B/CLayout` in the `CK_TILE_SINGLE_KERNEL_INCLUDE` block so a single-kernel driver is layout-generic.
  - Restricts config selection to the `cshuffle` epilogue (only one the kernel supports).
- `examples/gemm/cpp/03_streamk_gemm_driver.cpp` (NEW): minimal standalone driver that `-include`s one generated stream-K header, builds a single A/B/C tensor, calls `SelectedKernel::launch(args, stream)`, verifies against `ck_tile::reference_gemm`, and prints TFLOPS/GB/s.

Problem tried (config + shape)
- `fp16_rcr_compv3_cshuffle_intrawave_..._128x128x64_2x2x1_32x32x16` (atomic reduction; exists identically in TE and the dispatcher).
- `M=3840, N=4096, K=2048`, `warmup=10`, `repeat=50`, MI300X (gfx942), ROCm 7.1.1.

Performance + numerical verification (Dispatcher vs Tile Engine)
The generated GPU kernel (`StreamKKernel<StreamKTilePartitioner, GemmPipeline, GemmEpilogue>`) is identical to TE's; only host-side workspace ownership differs (internal `DeviceMem` vs TE's external pointer). Numerics match.

Next
- … `tile_engine/ops/gemm_streamk/`.
- … `dispatcher` GEMM interface folder (roadmap step 5).

Deep-core integration (PR-A…E): accepted design deviations
The deep-core commits make Stream-K a first-class registry citizen (selectable through `Dispatcher::run()` by `Problem::reduction_strategy`). Two deliberate deviations from the literal deep-core spec are worth calling out for reviewers:

1. Per-iteration reset lives in the backend, not in `Dispatcher::run()`. The Dispatcher owns/sizes/frees the reduction workspace (`ensure_workspace`, grow-on-demand, freed in dtor), but the strategy-aware reset stays inside the generated launch (`generated_tile_backend_streamk` / `_launch_function_streamk`). Reason: the reset is per-repeat (it runs inside `launch_kernel_time_mask`'s preprocess) and dtype-dependent (atomic C-reset needs `sizeof(CDataType)`), which the dtype-erased Dispatcher does not have. Net: workspace owned by Dispatcher, reset owned by backend.
2. Hardware grid is delegated to the ck_tile partitioner. Grid sizing uses `StreamKGemmKernel::GridSize(tile_partitioner)` rather than a dispatcher-side `NumCU`/`Occupancy`/`get_num_xccs` calculation, matching the bridge and keeping ck_tile as the single source of truth for Stream-K work partitioning.

Both are sound and additive: non-Stream-K kernels are byte-identical (the `encode_identifier()` Stream-K suffix is guarded by `algorithm.streamk`), and the 2-arg internal launch is preserved for the bridge / `03` driver.

See the PR comments for the per-commit detail and the gfx942/MI300X validation table.
Update (2026-06-27) — multi-datatype support: fp16 / bf16 / fp8 / bf8
The Stream-K dispatcher path now supports every float datatype Tile Engine builds for Stream-K, and then some: `fp16`, `bf16`, `fp8`, and `bf8`. fp8/bf8 inputs accumulate in fp32 and write an fp16 C tensor (`get_output_dtype`), exactly matching Tile Engine.

What was fixed: the codegen (`unified_gemm_codegen.py`, `arch_filter.py`) was already datatype-generic; only `fp16` had been proven. The one real lock-in was in `04_streamk_registry_driver.cpp`, which hardcoded the `KernelKey` signature to `DataType::FP16` + rcr layout, so fp8/bf8/bf16 kernels registered under the wrong key and failed dispatch. It now derives `dtype_a/b/c/acc` and layout tags from the generated kernel's actual A/B/C types via compile-time `dtype_enum_of<T>()` / `layout_tag_of<Layout>()` helpers. `test_streamk_registry.py` is parametrized over all four datatypes (dtype-independent core objects built once; per-dtype codegen + build + verify + identifier assertions).

Validation (gfx942 / MI300X, M=3840 N=4096 K=2048, all verify against `ck_tile::reference_gemm`):

All four register with the correct identifier (`fp8_rcr…`, `bf8_rcr…`, etc.), are selected by `Problem::reduction_strategy`, and verify. (int8 deliberately out of scope: atomic integer reduction is unproven in TE as well.)

Update (2026-06-27, #2): full Old-TE functional equivalence, + all layouts
Following the multi-datatype work above, Stream-K is now equivalent to legacy Tile Engine across both axes TE builds for: datatypes {fp16, bf16, fp8, bf8} × layouts {rcr, rrr, ccr, crr} × reduction strategies {atomic, linear, tree}.
Bug found + fixed (caught by an independent review pass): the Stream-K backend's `make_args` (`generated_tile_backend_streamk.hpp`) hardcoded rcr leading dims (`stride_a=K, stride_b=K, stride_c=N`) for every layout. The `04` driver fills host tensors with correct per-layout strides via `get_default_stride` and calls `Dispatcher::run`, but `make_args` then overrode them, so rrr/ccr/crr ran with wrong strides and would fail verification. Fixed by deriving leading dims from the kernel key's layouts: A is MxK (row→K, col→M), B is KxN (row→N, col→K), C is MxN (row→N, col→M). rcr is unchanged; all four TE layouts keep C row-major, so the atomic C-reset assumption still holds.

Validation: full 48-combo matrix, gfx942/MI300X, M3840 N4096 K2048, all `Verification: PASS`:

(TFLOPS; all 48 verify against `ck_tile::reference_gemm`.) The registry test `test_streamk_registry.py` now covers this full matrix (with `--datatypes` / `--layouts` flags to trim for faster CI). int8/fp32/fp64 remain out of scope (TE builds no Stream-K configs for them).

Commits: `9d033fde99` (dtypes), `7190a19edd` (test layout coverage), `8595984d5c` (backend layout-stride fix).

Update (2026-06-30, #3): scope correction + review blockers addressed
Scope of "equivalence" (correcting the earlier "full functional equivalence" / "~0% perf delta" framing). Numeric correctness (verify vs `ck_tile::reference_gemm`) is validated across the full {fp16,bf16,fp8,bf8} × {rcr,rrr,ccr,crr} × {atomic,linear,tree} matrix on matched tile configs. Performance parity vs Tile Engine is demonstrated on one config (fp16 rcr atomic, 3840×4096×2048); the 48-combo tables above report dispatcher-only TFLOPS, not a TE comparison. Tile coverage is narrower than TE (e.g. fp16 rcr: TE=180 vs DISP=73 tiles), so "functional equivalence" should be read as per matched tile config, not over TE's whole tile surface.

Review blockers fixed (correctness):
- … `ensure_workspace`), so correctness no longer depends on the backend's per-iteration preprocess running (the non-benchmarking `nrepeat=1` path could otherwise hand the kernel a garbage buffer).
- … `DeviceMem` that PR-D replaced was concurrency-safe; the owned buffer needed the guard back).
- … `static_assert`s a row-major C: the `hipMemset2D` C-reset assumes row-major and would silently miszero a column-major C under atomic accumulation.
- … `(void)`-discarded.

Commits: `111e1f48ce8` (workspace zero + mutex + row-major assert), `aa0181a8906` (memset return checks). Re-verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch, and verify PASS across 3840×4096×2048 and 128×128×16384.

Next
- … `tile_engine/ops/gemm_streamk/`.
- … `dispatcher` GEMM interface folder (roadmap step 5).
- … `validate()` a numeric path or an explicit "no-reference" return instead of reusing `supports()`; de-duplicate the reset lambda across the two generated launch overloads.