feat(ck-tile): add stream_k variant to GEMM Dispatcher codegen#8985
feat(ck-tile): add stream_k variant to GEMM Dispatcher codegen#8985ozturkosu wants to merge 21 commits into
Conversation
Add the stream-K GEMM variant to the unified GEMM dispatcher codegen the dispatcher way: a single-GEMM launch(args, stream) that allocates the reduction workspace internally via DeviceMem (GetWorkSpaceSize / SetWorkSpacePointer), zeroes it, and launches StreamKKernel with an atomic-reduction preprocess that resets C between timed iterations. No external workspace pointer (not the Tile Engine way). - arch_filter.py: add OperatorType.GEMM_STREAMK + tile constraints. - unified_gemm_codegen.py: add GemmVariant.STREAM_K, CLI --variants stream_k, naming, includes, _launch_function_streamk, variant->operator map, cshuffle-only config selection, and A/B/CLayout export in the CK_TILE_SINGLE_KERNEL_INCLUDE block. - examples/gemm/cpp/03_streamk_gemm_driver.cpp: standalone single-kernel driver that calls SelectedKernel::launch and verifies vs reference_gemm. Parity vs Tile Engine on MI300X (gfx942), fp16 rcr atomic 128x128x64_2x2x1_32x32x16, 3840x4096x2048, warmup=10/repeat=50: dispatcher 0.242 ms / 266 TFLOPS PASS vs TE 0.24 ms / 266 TFLOPS correct.
…rategy fields
First slice of moving Stream-K into the dispatcher core (registry-addressable),
per the deep-core checklist. Additive and inert by default:
- KernelKey: new ReductionStrategy enum {None,Atomic,Linear,Tree}; Algorithm
gains streamk / reduction_strategy / workspace. tie() includes them so the
three strategies are distinct keys. encode_identifier() appends the Stream-K
suffix ("_streamk" / "_streamk_linear" / "_streamk_tree") byte-for-byte with
unified_gemm_codegen.py KernelNaming.generate(), guarded by algorithm.streamk
so non-Stream-K identifiers are unchanged.
- Problem: streamk / reduction_strategy request fields + ProblemBuilder::stream_k().
Validated on gfx90a (hipcc 7.12): non-SK encode_identifier byte-identical;
atomic/linear/tree suffixes correct; tie() distinguishes strategies.
Add two non-pure virtuals so existing GEMM/FMHA/Conv instances compile unchanged: - get_workspace_size(Problem) -> bytes (default 0) - run(a,b,c,d_ptrs, void* workspace, problem, stream) overload whose default forwards to the existing no-workspace run(). The Dispatcher invokes these through a base KernelInstance* pointer (so the new overload is visible despite derived 6-arg run() overrides). The Stream-K backend (PR-C) overrides both to size and bind the reduction workspace. Validated on gfx90a (hipcc 7.12): a concrete instance overriding only the pre-existing pure virtuals compiles; default get_workspace_size==0 and the workspace-run forwards correctly via base pointer.
…tree codegen Adds the C++ backend that lets Stream-K ride the registry, plus the reduction strategy codegen needed to generate the three variants on this branch. - generated_tile_backend_streamk.hpp (NEW): GeneratedStreamKKernelInstance wraps a generated Stream-K kernel and builds ck_tile::StreamKHostArgs (the ABI-incompatible args the GemmHostArgs path could not). supports() gates on Problem.streamk + reduction_strategy so atomic/linear/tree coexist in the registry and the Dispatcher's first-fit selection picks the requested one. create_generated_streamk_kernel<> mirrors create_generated_tile_kernel<>. - codegen: reduction_strategy axis (atomic/linear/tree) -> KernelConfig field, key_name redux_*, KernelNaming "_streamk"/"_streamk_linear"/"_streamk_tree" (matches KernelKey::encode_identifier from PR-A), per-strategy StreamKReductionStrategy in the generated launch, and a streamk_config sweep axis. (Ported from the bridge branch reduction-strategy work.) PR-C keeps the generated launch's internal workspace/reset; PR-D relocates those to Dispatcher::run() via get_workspace_size()/the workspace-aware run(). Validated on gfx90a (hipcc 7.12): codegen emits 584 atomic + 584 linear + 584 tree headers with correct names; the backend device-compiles (22s) against a generated header and supports() accepts the matching strategy while rejecting the others and non-Stream-K problems.
…pace Relocate the Stream-K reduction-workspace buffer from the per-call generated launch() to a grow-on-demand buffer owned by the Dispatcher, so a long-lived dispatcher stops paying a hipMalloc/hipFree on every invocation. - codegen: hoist the StreamKGemmKernel type to struct scope and add GetWorkSpaceSize() + an external-workspace launch(args, cfg, workspace) overload. The existing 2-arg launch (internal DeviceMem) is unchanged so the bridge ctypes lib and the standalone 03 driver keep working. - backend: override get_workspace_size() and the workspace-aware run(); the no-workspace run() delegates with a null buffer. The per-iteration reset stays in the backend (it needs CDataType + the reduction strategy). - dispatcher: own a grow-on-demand workspace (raw void*/size_t to keep HIP out of the public header), size it via get_workspace_size(), and pass it through run_fused()/run_explicit(); free it in the destructor. Atomic needs none (size 0 -> null -> internal path); linear/tree consume the owned buffer. Validated on MI210/gfx90a: atomic/linear/tree all verify vs reference_gemm at unchanged perf, with linear/tree now running on the dispatcher-owned workspace.
…river
Add 04_streamk_registry_driver.cpp: a runnable proof of the full deep-core path
(Registry::register_kernel -> Dispatcher::run -> first-fit supports() gate on
reduction_strategy -> GeneratedStreamKKernelInstance::run -> generated launch ->
verify vs reference_gemm). Unlike 03_streamk_gemm_driver.cpp, which calls
SelectedKernel::launch() directly and bypasses the dispatcher, this exercises the
registry selection and the Dispatcher-owned workspace.
Selectable strategy via --strategy {atomic,linear,tree}. Validated on
MI210/gfx90a for all three (distinct registry identifiers, each PASS).
…K backend The dispatcher-wrapper generator emitted ONE template for every variant: backends::GeneratedKernelInstance<KernelStruct> with no streamk/reduction_strategy on the key. For Stream-K that is wrong twice over -- the regular backend calls launch(GemmHostArgs,...) which the SK kernel struct does not have (so the aggregate register_all_kernels.hpp would not compile against SK), and the key omits the SK fields so encode_identifier() emits no _streamk suffix and atomic/linear/tree collide in the registry. Make the wrapper variant-aware: for STREAM_K configs include generated_tile_backend_streamk.hpp, set key.algorithm.streamk + reduction_strategy + workspace (and pad flags for identifier parity), and return create_generated_streamk_kernel<KernelStruct, KernelStruct::ADataType, ...>. All other variants are unchanged. Validated on MI210/gfx90a: a registry populated via the generated wrappers holds atomic+linear+tree side by side; Dispatcher::run() selects each by Problem::reduction_strategy and all three verify vs reference_gemm.
…are atomic reset P2: GeneratedStreamKKernelInstance::supports() now ends with SelectedKernel::IsSupported(make_args(problem)) (a new generated static that runs MakeKernelArgs + IsSupportedArgument). A problem too small to partition across CUs is rejected during selection, so first-fit falls back to a non-Stream-K kernel instead of throwing std::runtime_error at launch. P3: the atomic reduction reset zeroes C with a stride-aware hipMemset2DAsync (pitch = stride_E * sizeof(C), width = N * sizeof(C), height = M) instead of a flat hipMemsetAsync over M*N. Correct for a padded/strided C; identical coverage for the contiguous rcr case. Applied to both the internal and external-workspace launch overloads. Validated on MI210/gfx90a: atomic/linear/tree still select + run + verify from a multi-kernel registry; valid small problems are accepted (no false-negatives).
…eductionStrategy) Close two review nits on the Stream-K drivers: - Parse M/N/K with std::stoll instead of std::stoi in the 03/04 drivers so large GEMM dimensions no longer overflow/throw int range (Copilot nit). - Add inline to_string(ReductionStrategy) in kernel_key.hpp and route the 04 driver through it, removing the driver-local strategy_name() duplicate so callers share one spelling that matches the codegen suffix scheme.
Adds dispatcher_test_streamk_registry, a GPU test that generates the three reduction-strategy kernels (atomic/linear/tree) from one tile config, builds the 04 registry driver once per strategy (each force-including its own header, since SkReductionStrategy is a compile-time constexpr), and asserts for each that the encode_identifier() suffix matches, the Dispatcher selects it by Problem::reduction_strategy, and the result verifies against the reference. This converts the previously manual deep-core validation into a regression- guarded CTest. It SKIPs (return 77) when no GPU or hipcc is present, so CPU-only CI is unaffected.
… driver The standalone stream-K driver verified atomic results with the single-pass GEMM tolerance get_*_threshold<...>(K). Atomic reduction accumulates K-split partials directly into low-precision C (workspace size 0), incurring rounding error that grows with the split factor -- correct results were flagged FAIL on small-M/N, large-K shapes (e.g. 512x512x8192) where tiles < CUs. Mirror tile_engine's calculate_rtol_atol (validation.hpp): derive kbatch from the kernel's tile partitioner (estimate_num_wgs_per_tile), widen atol/rtol with the split-K CDataType accumulation term, and take the max with the per-split tolerance. The driver and tile_engine now verify identically; the kernel is unchanged.
…gine
The standalone stream-K driver built its stream_config as {stream, true, 0,
warmup, repeat}, leaving is_gpu_timer/flush_cache/rotating_count at defaults
(flush_cache=false, rotating_count=1). The tile_engine benchmark instead times
with flush_cache=true and rotating_count=1000, so the driver measured a
warm-cache best case while tile_engine measured cold-cache -- the entire source
of the reported dispatcher-vs-TE "performance gap" at low tile counts.
Add --timer/--flush_cache/--rotating_count (defaulting to the tile_engine
values) and pass them through stream_config so both sides use identical timing
methodology. A validating run still times a single cold shot, mirroring
tile_engine's repeat_once_if_verify(); collect perf with a separate --validate 0
pass.
The 04 registry driver hardcoded the KernelKey signature to DataType::FP16 and an rcr layout, so fp8/bf8/bf16 Stream-K kernels registered under the wrong key and failed dispatch/identifier checks. Derive dtype_a/b/c/acc and layout tags from the generated kernel's actual A/B/C types via compile-time dtype_enum_of<T>()/layout_tag_of<Layout>() helpers (fp8/bf8 inputs accumulate in fp32 and write fp16 C, matching Tile Engine). Parametrize test_streamk_registry.py over fp16/bf16/fp8/bf8 (dtype-independent core objects built once; per-dtype codegen + build + verify with per-dtype identifier assertions). All four datatypes register, dispatch, and verify across atomic/linear/tree on gfx942 (MI300X).
Parametrize the registry test over all four layouts Tile Engine builds
Stream-K for (rcr/rrr/ccr/crr) in addition to datatypes and reduction
strategies. Full coverage is now {fp16,bf16,fp8,bf8} x {rcr,rrr,ccr,crr}
x {atomic,linear,tree}; all four layouts keep C row-major, which the
atomic C-reset relies on. The encode_identifier assertion is generalized
to {dtype}_{layout}, and --datatypes/--layouts flags allow trimming the
matrix for faster CI runs.
make_args() hardcoded rcr leading dims (stride_a=K, stride_b=K, stride_c=N)
for every layout, so non-rcr Stream-K kernels (rrr/ccr/crr) ran with wrong
strides and failed verification. Derive the leading dims from the kernel
key's layouts instead: A is MxK (row->K, col->M), B is KxN (row->N, col->K),
C is MxN (row->N, col->M). rcr is unchanged; rrr/ccr/crr now match the host
tensor strides the driver builds via get_default_stride.
Verified on gfx942/MI300X: full {fp16,bf16,fp8,bf8} x {rcr,rrr,ccr,crr} x
{atomic,linear,tree} matrix (48/48) registers, dispatches, and verifies.
…wing, timing, coverage) - 04 registry driver: port the split-K-aware verification tolerance from the 03 driver (a0ff521) so the deep-core path uses kbatch-derived rtol/atol instead of the plain single-pass threshold that spuriously FAILs correct atomic results on small-M/N, large-K shapes. - Stream-K backend make_args: guard the int64->int32 (index_t) narrowing of M/N/K and derived leading dims; throw on overflow instead of silently wrapping (the parser was widened to std::stoll for exactly this reason). - Stream-K backend run(): flush the L2 between timed iterations so the measurement is cold like tile_engine/the 03 driver (warm cache over-reported TFlops); document that this path is cold-but-non-rotated and not the calibrated apple-to-apple perf surface. - codegen: document the Stream-K tile-coverage limitation (dispatcher emits a narrower tile surface than TE; equivalence is per matched tile config). Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch, and verify PASS.
…tadata, arch-filter, test)
- Backend validate(): stop returning a blind "true". Without a host reference it
cannot do a numeric check, but it now validates what it can -- non-null
operands, a well-formed problem, and that this instance supports() it -- so an
unrunnable config is not mis-reported as valid.
- 04 registry driver make_streamk_key(): derive wave_shape/block_size/
transpose_c/double_buffer/persistent/preshuffle/num_wave_groups from the
generated kernel's own static traits instead of hardcoding (the wave_shape was
wrong, e.g. {2,2,1} vs the actual {1,4,1}); the registry identifier now
describes the kernel that was built. pipeline/scheduler stay fixed (baked into
the kernel type, not on the SK selection axis) -- now documented.
- arch_filter GEMM_STREAMK: document that the min_tile constraints are copied
from plain GEMM and the real feasibility gate (enough tiles to partition K
across CUs) is runtime IsSupportedArgument/supports(), not these numbers.
- test_streamk_registry.py: replace the fragile rocminfo line-splitting in
detect_arch with a regex; verify each built kernel against the CLI shape AND a
small-M/large-K (128x128x16384) shape that stresses the split-K tolerance.
Compile-validated for gfx942; small-M/large-K shape verified PASS on MI300X.
Document the Stream-K deep-core path (the feature this branch adds): generate kernels via unified_gemm_codegen --variants stream_k, build/run the 03 standalone driver (perf surface) and the 04 registry driver (Registry -> Dispatcher -> verify), run the CTest, plus the reduction strategies, supported dtypes/layouts, the cold-cache perf methodology, the split-K-aware tolerance, and known limitations (gfx950 fp8/bf8, tile-coverage gap). Linked from the dispatcher README.
…, thread-safety, row-major C assert) Three correctness fixes from a tough review pass on the deep-core path: 1. Zero the reduction workspace in Dispatcher::ensure_workspace() before every Linear/Tree dispatch. The dispatcher previously only hipMalloc'd the buffer and relied on the backend's per-iteration preprocess reset to zero it, so on the non-benchmarking (nrepeat=1) path -- or with set_benchmarking(false) -- the buffer could be garbage and corrupt the reduction. Now correctness is independent of whether the preprocess runs, mirroring the internal DeviceMem::SetZero() the standalone launch already does. 2. Serialize the shared workspace_ buffer with a mutex. workspace_ is mutable state reused across const run() calls; two concurrent linear/tree dispatches on different streams could share one buffer and corrupt each other. The lock spans size->zero->launch (the buffer is in use for the whole run). Atomic / non-Stream-K paths use no workspace and take no lock. 3. Emit a static_assert(is_row_major<CLayout>) in the generated Atomic launch. The atomic C-reset uses a row-major hipMemset2D; a column-major C would be zeroed incorrectly and silently corrupt atomic accumulation. The assert is atomic-only (Linear/Tree zero the workspace, not C). Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch, and verify PASS across 3840x4096x2048 and 128x128x16384.
…t preprocess The atomic C-reset (hipMemset2DAsync) and the linear/tree workspace reset (hipMemsetAsync) in the generated launch() were (void)-cast, so a failed reset silently corrupted atomic accumulation / the reduction buffer with no signal. Both now throw std::runtime_error on non-hipSuccess, consistent with the IsSupportedArgument failure path in the same function. Applied to both the internal-workspace and external-workspace launch overloads. Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree compile and verify PASS.
✅ All Checks Passed — Ready for Review
📖 Need help? See the Policy FAQ for details on every check and how to fix failures. |
|
🎉 All checks passed! This PR is ready for review. |
| // tile, taken from the kernel's own tile partitioner so the driver and | ||
| // tile_engine agree on the split factor. | ||
| auto kargs = SelectedKernel::StreamKGemmKernel::MakeKernelArgs(args); | ||
| const ck_tile::index_t kbatch = |
There was a problem hiding this comment.
I suggest altering the variable name of kbatch because kbatch refers to the split-k value in normal CK Tile terminology. Stream-K does not use a "split-k" since there may not be the same number of workgroups contributing to a tile. I recommend the variable name num_wgs_per_tile or something similar. Please update your comments accordingly as well.
| auto kargs = SelectedKernel::StreamKGemmKernel::MakeKernelArgs(args); | ||
| const ck_tile::index_t kbatch = | ||
| std::max<ck_tile::index_t>(1, kargs.tile_partitioner.estimate_num_wgs_per_tile()); | ||
|
|
There was a problem hiding this comment.
Is this code common to other operators? If so, we should probably avoid duplicating it here. I recommend using the existing helper function (if present) or placing one in a common file.
(If this isn't common code, then disregard this comment)
| // instead of assuming fp16. Keeps the registry identifier and selection correct | ||
| // across every datatype the codegen emits. | ||
| template <typename T> | ||
| static constexpr DataType dtype_enum_of() |
There was a problem hiding this comment.
Is this code common to other ops? If so, can we avoid duplication?
| } | ||
|
|
||
| template <typename Layout> | ||
| static constexpr LayoutTag layout_tag_of() |
There was a problem hiding this comment.
Same thing about duplication here.
| key.signature.transpose_a = false; | ||
| key.signature.transpose_b = false; | ||
| key.signature.grouped = false; | ||
| key.signature.split_k = 1; |
There was a problem hiding this comment.
Can you add a comment here that for Stream-K, split-k must be 1?
| // concurrent Stream-K linear/tree dispatches on different streams cannot | ||
| // corrupt each other's reduction. Atomic / non-Stream-K paths use no | ||
| // workspace and take no lock. | ||
| mutable std::mutex workspace_mutex_; |
There was a problem hiding this comment.
Can you explain why we are using a mutex here? Are concurrent kernel runs using the same workspace? If so, we likely want to avoid this.
|
|
||
| /// Stream-K partial-sum reduction strategy. `None` = not a Stream-K kernel. | ||
| /// Mirrors ck_tile::StreamKReductionStrategy (Atomic/Linear/Tree). | ||
| enum class ReductionStrategy : std::uint8_t |
There was a problem hiding this comment.
Can we avoid duplicating this Reduction strategy? Can we use the one defined in ck_tile::StreamKReductionStrategy instead? I just fear this adds unnecessary maintenance.
| LAYOUTS = ["rcr", "rrr", "ccr", "crr"] | ||
|
|
||
|
|
||
| def detect_arch(fallback=None): |
There was a problem hiding this comment.
I think we want to avoid reliance on rocm-info. If possible, can you remove this dependency?
There was a problem hiding this comment.
Pull request overview
Adds Stream-K GEMM as a first-class “deep-core” citizen in the CK Tile Dispatcher, including codegen support, registry/dispatcher selection by reduction strategy, dispatcher-owned workspace management, and end-to-end drivers + CTest coverage.
Changes:
- Extend unified GEMM codegen to generate Stream-K kernels (atomic/linear/tree) and register them via a Stream-K-specific backend wrapper.
- Add dispatcher-owned, reusable reduction workspace (with serialization) for Stream-K linear/tree dispatch.
- Add standalone + registry drivers, documentation, and a GPU/hipcc-gated CTest to validate dispatch + correctness across datatype/layout/strategy combinations.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| projects/composablekernel/dispatcher/codegen/arch_filter.py | Adds GEMM_STREAMK operator type and tile-shape constraints for arch filtering. |
| projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py | Adds Stream-K variant generation, naming, headers, launch paths (internal & external workspace), and wrapper/backend selection. |
| projects/composablekernel/dispatcher/examples/gemm/cpp/03_streamk_gemm_driver.cpp | New direct-launch Stream-K driver for perf/correctness (bypasses dispatcher). |
| projects/composablekernel/dispatcher/examples/gemm/cpp/04_streamk_registry_driver.cpp | New deep-core driver exercising Registry → Dispatcher → Stream-K backend with split-K-aware verification tolerance. |
| projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend_streamk.hpp | New KernelInstance wrapper for Stream-K kernels (supports, workspace sizing, workspace-aware run). |
| projects/composablekernel/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp | Adds dispatcher-owned Stream-K workspace state + mutex and declares ensure_workspace + destructor. |
| projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp | Adds workspace-sizing and workspace-aware run() virtuals (defaulting to no-workspace behavior). |
| projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp | Adds ReductionStrategy and Stream-K algorithm fields; updates identifier encoding and hashing. |
| projects/composablekernel/dispatcher/include/ck_tile/dispatcher/problem.hpp | Adds Stream-K request fields and ProblemBuilder::stream_k(). |
| projects/composablekernel/dispatcher/src/dispatcher.cpp | Implements workspace allocation/zeroing and serialized workspace use in run paths; adds destructor. |
| projects/composablekernel/dispatcher/tests/CMakeLists.txt | Adds a GPU+hipcc-gated CTest target for Stream-K registry testing (SKIP_RETURN_CODE=77). |
| projects/composablekernel/dispatcher/tests/test_streamk_registry.py | New Python test that codegens/builds/runs/verifies Stream-K across datatype/layout/strategy matrix. |
| projects/composablekernel/dispatcher/README.md | Links to Stream-K documentation. |
| projects/composablekernel/dispatcher/STREAMK.md | New documentation for generating, running, validating, and testing Stream-K deep-core path. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| out = r.stdout | ||
| ok_verify = "Verification: PASS" in out | ||
| ok_suffix = f"identifier={dtype}_{layout}" in out and want_suffix in out.split( | ||
| "identifier=" | ||
| )[1].split()[0] |
| // Zero the region the kernel will use. Linear/Tree reductions accumulate into | ||
| // this buffer and read it before writing, so a stale/garbage buffer corrupts | ||
| // results. Doing it here makes correctness independent of whether the backend's | ||
| // per-iteration preprocess reset runs (e.g. on the non-benchmarking nrepeat=1 | ||
| // path), mirroring the internal DeviceMem::SetZero() the standalone launch does. | ||
| if(bytes > 0 && hipMemset(workspace_, 0, bytes) != hipSuccess) | ||
| { | ||
| throw std::runtime_error("Dispatcher: failed to zero Stream-K reduction workspace"); | ||
| } |
|
LGTM! |
Add stream-K variant to the GEMM Dispatcher codegen (the dispatcher way)
This is the next slice of the Tile Engine → Dispatcher consolidation, following the same pattern as the grouped_gemm PR (#8075). It adds the stream-K GEMM variant to the unified GEMM codegen, implemented the dispatcher way (workspace owned internally via
DeviceMem, cleanlaunch(args, stream)signature), and proves numeric + performance parity against Tile Engine.Branch is based on
developand contains only the stream-K work (no grouped_gemm commits).What I did
codegen/arch_filter.py— addedOperatorType.GEMM_STREAMKand its tile constraints.codegen/unified_gemm_codegen.py:GemmVariant.STREAM_K, made it reachable from the CLI (--variants stream_k), wired naming (_streamksuffix), includes, and the variant→operator map._launch_function_streamk: builds a singleStreamKHostArgs,MakeKernelArgs→GetWorkSpaceSize→ allocateDeviceMemworkspace internally +SetZero→SetWorkSpacePointer→IsSupportedArgumentcheck →make_kernelvialaunch_kernel_time_maskwith an Atomic-reduction preprocess that zeros C between timed iterations. No externalkargs_ptr(not the Tile Engine way).A/B/CLayoutin theCK_TILE_SINGLE_KERNEL_INCLUDEblock so a single-kernel driver is layout-generic.cshuffleepilogue (only one the kernel supports).examples/gemm/cpp/03_streamk_gemm_driver.cpp(NEW) — minimal standalone driver:-includes one generated stream-K header, builds a single A/B/C tensor, callsSelectedKernel::launch(args, stream), verifies againstck_tile::reference_gemm, prints TFLOPS/GB/s.Problem tried (config + shape)
fp16_rcr_compv3_cshuffle_intrawave_..._128x128x64_2x2x1_32x32x16(atomic reduction; exists identically in TE and the dispatcher).M=3840, N=4096, K=2048,warmup=10,repeat=50, MI300X (gfx942), ROCm 7.1.1.Performance + numerical verification (Dispatcher vs Tile Engine)
The generated GPU kernel (
StreamKKernel<StreamKTilePartitioner, GemmPipeline, GemmEpilogue>) is identical to TE's; only host-side workspace ownership differs (internalDeviceMemvs TE's external pointer). Numerics match.Next
tile_engine/ops/gemm_streamk/.dispatcherGEMM interface folder (roadmap step 5).Deep-core integration (PR-A…E) — accepted design deviations
The deep-core commits make Stream-K a first-class registry citizen (selectable through
Dispatcher::run()byProblem::reduction_strategy). Two deliberate deviations from the literal deep-core spec are worth calling out for reviewers:Per-iteration reset lives in the backend, not in
Dispatcher::run(). The Dispatcher owns/sizes/frees the reduction workspace (ensure_workspace, grow-on-demand, freed in dtor), but the strategy-aware reset stays inside the generated launch (generated_tile_backend_streamk/_launch_function_streamk). Reason: the reset is per-repeat (it runs insidelaunch_kernel_time_mask's preprocess) and dtype-dependent (atomic C-reset needssizeof(CDataType)), which the dtype-erased Dispatcher does not have. Net: workspace owned by Dispatcher, reset owned by backend.Hardware grid is delegated to the ck_tile partitioner. Grid sizing uses
StreamKGemmKernel::GridSize(tile_partitioner)rather than a dispatcher-sideNumCU/Occupancy/get_num_xccscalculation — matching the bridge and keeping ck_tile as the single source of truth for Stream-K work partitioning.Both are sound and additive: non-Stream-K kernels are byte-identical (the
encode_identifier()Stream-K suffix is guarded byalgorithm.streamk), and the 2-arg internal launch is preserved for the bridge /03driver.See the PR comments for the per-commit detail and the gfx942/MI300X validation table.
Update (2026-06-27) — multi-datatype support: fp16 / bf16 / fp8 / bf8
The Stream-K dispatcher path now supports every float datatype Tile Engine builds for Stream-K, and then some:
fp16,bf16,fp8, andbf8. fp8/bf8 inputs accumulate in fp32 and write an fp16 C tensor (get_output_dtype), exactly matching Tile Engine.What was fixed: the codegen (
unified_gemm_codegen.py,arch_filter.py) was already datatype-generic; onlyfp16had been proven. The one real lock-in was in04_streamk_registry_driver.cpp, which hardcoded theKernelKeysignature toDataType::FP16+ rcr layout — so fp8/bf8/bf16 kernels registered under the wrong key and failed dispatch. It now derivesdtype_a/b/c/accand layout tags from the generated kernel's actual A/B/C types via compile-timedtype_enum_of<T>()/layout_tag_of<Layout>()helpers.test_streamk_registry.pyis parametrized over all four datatypes (dtype-independent core objects built once; per-dtype codegen + build + verify + identifier assertions).Validation (gfx942 / MI300X, M=3840 N=4096 K=2048, all verify against
ck_tile::reference_gemm):All four register with the correct identifier (
fp8_rcr…,bf8_rcr…, etc.), are selected byProblem::reduction_strategy, and verify. (int8 deliberately out of scope — atomic integer reduction is unproven in TE as well.)Update (2026-06-27, #2) — full Old-TE functional equivalence: + all layouts
Following the multi-datatype work above, Stream-K is now equivalent to legacy Tile Engine across both axes TE builds for: datatypes {fp16, bf16, fp8, bf8} × layouts {rcr, rrr, ccr, crr} × reduction strategies {atomic, linear, tree}.
Bug found + fixed (caught by an independent review pass): the Stream-K backend's
make_args(generated_tile_backend_streamk.hpp) hardcoded rcr leading dims (stride_a=K, stride_b=K, stride_c=N) for every layout. The04driver fills host tensors with correct per-layout strides viaget_default_strideand callsDispatcher::run, butmake_argsthen overrode them — so rrr/ccr/crr ran with wrong strides and would fail verification. Fixed by deriving leading dims from the kernel key's layouts: A is MxK (row→K, col→M), B is KxN (row→N, col→K), C is MxN (row→N, col→M). rcr is unchanged; all four TE layouts keep C row-major, so the atomic C-reset assumption still holds.Validation — full 48-combo matrix, gfx942/MI300X, M3840 N4096 K2048, all
Verification: PASS:(TFLOPS; all 48 verify against
ck_tile::reference_gemm.) The registry testtest_streamk_registry.pynow covers this full matrix (with--datatypes/--layoutsflags to trim for faster CI). int8/fp32/fp64 remain out of scope (TE builds no Stream-K configs for them).Commits:
9d033fde99(dtypes),7190a19edd(test layout coverage),8595984d5c(backend layout-stride fix).Update (2026-06-30, #3) — scope correction + review blockers addressed
Scope of "equivalence" (correcting the earlier "full functional equivalence" / "~0% perf delta" framing). Numeric correctness (verify vs
ck_tile::reference_gemm) is validated across the full {fp16,bf16,fp8,bf8} × {rcr,rrr,ccr,crr} × {atomic,linear,tree} matrix on matched tile configs. Performance parity vs Tile Engine is demonstrated on one config (fp16 rcr atomic, 3840×4096×2048) — the 48-combo tables above report dispatcher-only TFLOPS, not a TE comparison. Tile coverage is narrower than TE (e.g. fp16 rcr: TE=180 vs DISP=73 tiles), so "functional equivalence" should be read as per matched tile config, not over TE's whole tile surface.Review blockers fixed (correctness):
ensure_workspace), so correctness no longer depends on the backend's per-iteration preprocess running (the non-benchmarkingnrepeat=1path could otherwise hand the kernel a garbage buffer).DeviceMemthat PR-D replaced was concurrency-safe; the owned buffer needed the guard back).static_asserts a row-major C — thehipMemset2DC-reset assumes row-major and would silently miszero a column-major C under atomic accumulation.(void)-discarded.Commits:
111e1f48ce8(workspace zero + mutex + row-major assert),aa0181a8906(memset return checks). Re-verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch, and verify PASS across 3840×4096×2048 and 128×128×16384.Next
tile_engine/ops/gemm_streamk/.dispatcherGEMM interface folder (roadmap step 5).validate()a numeric path or an explicit "no-reference" return instead of reusingsupports(); de-duplicate the reset lambda across the two generated launch overloads.