Skip to content

feat(ck-tile): add stream_k variant to GEMM Dispatcher codegen#8985

Open
ozturkosu wants to merge 21 commits into
developfrom
users/muozturk/ck-tile/dispatcher-streamk-gemm
Open

feat(ck-tile): add stream_k variant to GEMM Dispatcher codegen#8985
ozturkosu wants to merge 21 commits into
developfrom
users/muozturk/ck-tile/dispatcher-streamk-gemm

Conversation

@ozturkosu

Copy link
Copy Markdown
Contributor

Supersedes #8094 (closed when its branch was renamed to a policy-compliant path). Same commits, same head SHA.

Add stream-K variant to the GEMM Dispatcher codegen (the dispatcher way)

This is the next slice of the Tile Engine → Dispatcher consolidation, following the same pattern as the grouped_gemm PR (#8075). It adds the stream-K GEMM variant to the unified GEMM codegen, implemented the dispatcher way (workspace owned internally via DeviceMem, clean launch(args, stream) signature), and proves numeric + performance parity against Tile Engine.

Branch is based on develop and contains only the stream-K work (no grouped_gemm commits).

What I did

  • codegen/arch_filter.py — added OperatorType.GEMM_STREAMK and its tile constraints.
  • codegen/unified_gemm_codegen.py:
    • Added GemmVariant.STREAM_K, made it reachable from the CLI (--variants stream_k), wired naming (_streamk suffix), includes, and the variant→operator map.
    • New _launch_function_streamk: builds a single StreamKHostArgs, MakeKernelArgsGetWorkSpaceSize → allocate DeviceMem workspace internally + SetZeroSetWorkSpacePointerIsSupportedArgument check → make_kernel via launch_kernel_time_mask with an Atomic-reduction preprocess that zeros C between timed iterations. No external kargs_ptr (not the Tile Engine way).
    • Exported A/B/CLayout in the CK_TILE_SINGLE_KERNEL_INCLUDE block so a single-kernel driver is layout-generic.
    • Restricted stream_k configs to the cshuffle epilogue (only one the kernel supports).
  • examples/gemm/cpp/03_streamk_gemm_driver.cpp (NEW) — minimal standalone driver: -includes one generated stream-K header, builds a single A/B/C tensor, calls SelectedKernel::launch(args, stream), verifies against ck_tile::reference_gemm, prints TFLOPS/GB/s.

Problem tried (config + shape)

  • Config: fp16_rcr_compv3_cshuffle_intrawave_..._128x128x64_2x2x1_32x32x16 (atomic reduction; exists identically in TE and the dispatcher).
  • Shape: M=3840, N=4096, K=2048, warmup=10, repeat=50, MI300X (gfx942), ROCm 7.1.1.

Performance + numerical verification (Dispatcher vs Tile Engine)

latency (ms) TFLOPS GB/s verify
Tile Engine (warmup=10, repeat=50) 0.24 266.7 264.8 correct
Dispatcher (warmup=10, repeat=50) 0.242 266.1 264.2 PASS
Δ ~0% ~0% ~0% identical

Methodology note: TE's benchmark forces repeat=1, warmup=0 whenever verify=1 (the atomic kernel accumulates into C, so it can only verify a single run). A verify=1 invocation therefore reports a single cold iteration (~0.30 ms), which is not a representative perf number. The table above uses TE verify=0 (so warmup/repeat are honored) for the perf row and a separate TE verify=1 run for correctness. The dispatcher driver times (warmup=10/repeat=50) and verifies in the same run because it re-zeros C between timed iterations via the masked preprocess.

The generated GPU kernel (StreamKKernel<StreamKTilePartitioner, GemmPipeline, GemmEpilogue>) is identical to TE's; only host-side workspace ownership differs (internal DeviceMem vs TE's external pointer). Numerics match.

Next

  • Once signed off, delete tile_engine/ops/gemm_streamk/.
  • Continue toward a first-class dispatcher GEMM interface folder (roadmap step 5).

Deep-core integration (PR-A…E) — accepted design deviations

The deep-core commits make Stream-K a first-class registry citizen (selectable through Dispatcher::run() by Problem::reduction_strategy). Two deliberate deviations from the literal deep-core spec are worth calling out for reviewers:

  1. Per-iteration reset lives in the backend, not in Dispatcher::run(). The Dispatcher owns/sizes/frees the reduction workspace (ensure_workspace, grow-on-demand, freed in dtor), but the strategy-aware reset stays inside the generated launch (generated_tile_backend_streamk / _launch_function_streamk). Reason: the reset is per-repeat (it runs inside launch_kernel_time_mask's preprocess) and dtype-dependent (atomic C-reset needs sizeof(CDataType)), which the dtype-erased Dispatcher does not have. Net: workspace owned by Dispatcher, reset owned by backend.

  2. Hardware grid is delegated to the ck_tile partitioner. Grid sizing uses StreamKGemmKernel::GridSize(tile_partitioner) rather than a dispatcher-side NumCU/Occupancy/get_num_xccs calculation — matching the bridge and keeping ck_tile as the single source of truth for Stream-K work partitioning.

Both are sound and additive: non-Stream-K kernels are byte-identical (the encode_identifier() Stream-K suffix is guarded by algorithm.streamk), and the 2-arg internal launch is preserved for the bridge / 03 driver.

See the PR comments for the per-commit detail and the gfx942/MI300X validation table.


Update (2026-06-27) — multi-datatype support: fp16 / bf16 / fp8 / bf8

The Stream-K dispatcher path now supports every float datatype Tile Engine builds for Stream-K, and then some: fp16, bf16, fp8, and bf8. fp8/bf8 inputs accumulate in fp32 and write an fp16 C tensor (get_output_dtype), exactly matching Tile Engine.

What was fixed: the codegen (unified_gemm_codegen.py, arch_filter.py) was already datatype-generic; only fp16 had been proven. The one real lock-in was in 04_streamk_registry_driver.cpp, which hardcoded the KernelKey signature to DataType::FP16 + rcr layout — so fp8/bf8/bf16 kernels registered under the wrong key and failed dispatch. It now derives dtype_a/b/c/acc and layout tags from the generated kernel's actual A/B/C types via compile-time dtype_enum_of<T>() / layout_tag_of<Layout>() helpers. test_streamk_registry.py is parametrized over all four datatypes (dtype-independent core objects built once; per-dtype codegen + build + verify + identifier assertions).

Validation (gfx942 / MI300X, M=3840 N=4096 K=2048, all verify against ck_tile::reference_gemm):

dtype atomic (TFLOPS) linear (TFLOPS) tree (TFLOPS) verify
fp16 275.8 304.7 299.0 PASS
bf16 284.0 298.9 307.6 PASS
fp8 316.2 371.1 373.3 PASS
bf8 353.0 394.6 370.9 PASS

All four register with the correct identifier (fp8_rcr…, bf8_rcr…, etc.), are selected by Problem::reduction_strategy, and verify. (int8 deliberately out of scope — atomic integer reduction is unproven in TE as well.)


Update (2026-06-27, #2) — full Old-TE functional equivalence: + all layouts

Following the multi-datatype work above, Stream-K is now equivalent to legacy Tile Engine across both axes TE builds for: datatypes {fp16, bf16, fp8, bf8} × layouts {rcr, rrr, ccr, crr} × reduction strategies {atomic, linear, tree}.

Bug found + fixed (caught by an independent review pass): the Stream-K backend's make_args (generated_tile_backend_streamk.hpp) hardcoded rcr leading dims (stride_a=K, stride_b=K, stride_c=N) for every layout. The 04 driver fills host tensors with correct per-layout strides via get_default_stride and calls Dispatcher::run, but make_args then overrode them — so rrr/ccr/crr ran with wrong strides and would fail verification. Fixed by deriving leading dims from the kernel key's layouts: A is MxK (row→K, col→M), B is KxN (row→N, col→K), C is MxN (row→N, col→M). rcr is unchanged; all four TE layouts keep C row-major, so the atomic C-reset assumption still holds.

Validation — full 48-combo matrix, gfx942/MI300X, M3840 N4096 K2048, all Verification: PASS:

dtype layout atomic linear tree
fp16 rcr 282.6 302.2 308.1
fp16 rrr 235.1 288.9 298.9
fp16 ccr 250.8 272.0 278.6
fp16 crr 208.8 272.5 236.7
bf16 rcr 281.2 306.8 309.1
bf16 rrr 276.3 298.3 296.0
bf16 ccr 266.0 285.4 288.3
bf16 crr 234.7 282.6 244.7
fp8 rcr 322.3 372.3 372.7
fp8 rrr 244.4 264.2 264.6
fp8 ccr 267.5 298.4 295.1
fp8 crr 253.0 271.7 266.9
bf8 rcr 311.4 378.5 376.1
bf8 rrr 240.3 269.2 265.3
bf8 ccr 271.5 305.9 295.3
bf8 crr 240.2 264.9 272.0

(TFLOPS; all 48 verify against ck_tile::reference_gemm.) The registry test test_streamk_registry.py now covers this full matrix (with --datatypes/--layouts flags to trim for faster CI). int8/fp32/fp64 remain out of scope (TE builds no Stream-K configs for them).

Commits: 9d033fde99 (dtypes), 7190a19edd (test layout coverage), 8595984d5c (backend layout-stride fix).


Update (2026-06-30, #3) — scope correction + review blockers addressed

Scope of "equivalence" (correcting the earlier "full functional equivalence" / "~0% perf delta" framing). Numeric correctness (verify vs ck_tile::reference_gemm) is validated across the full {fp16,bf16,fp8,bf8} × {rcr,rrr,ccr,crr} × {atomic,linear,tree} matrix on matched tile configs. Performance parity vs Tile Engine is demonstrated on one config (fp16 rcr atomic, 3840×4096×2048) — the 48-combo tables above report dispatcher-only TFLOPS, not a TE comparison. Tile coverage is narrower than TE (e.g. fp16 rcr: TE=180 vs DISP=73 tiles), so "functional equivalence" should be read as per matched tile config, not over TE's whole tile surface.

Review blockers fixed (correctness):

  • Reduction workspace is now zeroed by the Dispatcher before every linear/tree dispatch (ensure_workspace), so correctness no longer depends on the backend's per-iteration preprocess running (the non-benchmarking nrepeat=1 path could otherwise hand the kernel a garbage buffer).
  • Workspace access is serialized with a mutex spanning size→zero→launch, so concurrent linear/tree dispatches on different streams cannot corrupt each other's reduction (the per-call DeviceMem that PR-D replaced was concurrency-safe; the owned buffer needed the guard back).
  • Atomic launch now static_asserts a row-major C — the hipMemset2D C-reset assumes row-major and would silently miszero a column-major C under atomic accumulation.
  • HIP memset return codes are checked in the reset preprocess (both launch overloads); a failed reset now throws instead of being (void)-discarded.

Commits: 111e1f48ce8 (workspace zero + mutex + row-major assert), aa0181a8906 (memset return checks). Re-verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch, and verify PASS across 3840×4096×2048 and 128×128×16384.

Next

  • The GEMM bridge for StreamK is being developed in a separate PR**; this PR is scoped to the Stream-K deep-core dispatcher only.
  • Land a true TE-vs-dispatcher perf-parity sweep across the matched tile set (not just the single fp16 rcr atomic config) so the perf claim matches the breadth of the correctness claim.
  • Close the tile-coverage gap by feeding the missing TE tiles into the codegen tile list (fp16/bf16 rcr 124 TE-only, ccr similar; fp8/bf8 closer but still short).
  • Once signed off, delete tile_engine/ops/gemm_streamk/.
  • Continue toward a first-class dispatcher GEMM interface folder (roadmap step 5).
  • Optional follow-ups from the review (non-blocking): give validate() a numeric path or an explicit "no-reference" return instead of reusing supports(); de-duplicate the reset lambda across the two generated launch overloads.

Muhammed Ozturk and others added 21 commits June 5, 2026 02:48
Add the stream-K GEMM variant to the unified GEMM dispatcher codegen the
dispatcher way: a single-GEMM launch(args, stream) that allocates the
reduction workspace internally via DeviceMem (GetWorkSpaceSize /
SetWorkSpacePointer), zeroes it, and launches StreamKKernel with an
atomic-reduction preprocess that resets C between timed iterations. No
external workspace pointer (not the Tile Engine way).

- arch_filter.py: add OperatorType.GEMM_STREAMK + tile constraints.
- unified_gemm_codegen.py: add GemmVariant.STREAM_K, CLI --variants
  stream_k, naming, includes, _launch_function_streamk, variant->operator
  map, cshuffle-only config selection, and A/B/CLayout export in the
  CK_TILE_SINGLE_KERNEL_INCLUDE block.
- examples/gemm/cpp/03_streamk_gemm_driver.cpp: standalone single-kernel
  driver that calls SelectedKernel::launch and verifies vs reference_gemm.

Parity vs Tile Engine on MI300X (gfx942), fp16 rcr atomic
128x128x64_2x2x1_32x32x16, 3840x4096x2048, warmup=10/repeat=50:
dispatcher 0.242 ms / 266 TFLOPS PASS vs TE 0.24 ms / 266 TFLOPS correct.
…rategy fields

First slice of moving Stream-K into the dispatcher core (registry-addressable),
per the deep-core checklist. Additive and inert by default:

- KernelKey: new ReductionStrategy enum {None,Atomic,Linear,Tree}; Algorithm
  gains streamk / reduction_strategy / workspace. tie() includes them so the
  three strategies are distinct keys. encode_identifier() appends the Stream-K
  suffix ("_streamk" / "_streamk_linear" / "_streamk_tree") byte-for-byte with
  unified_gemm_codegen.py KernelNaming.generate(), guarded by algorithm.streamk
  so non-Stream-K identifiers are unchanged.
- Problem: streamk / reduction_strategy request fields + ProblemBuilder::stream_k().

Validated on gfx90a (hipcc 7.12): non-SK encode_identifier byte-identical;
atomic/linear/tree suffixes correct; tie() distinguishes strategies.
Add two non-pure virtuals so existing GEMM/FMHA/Conv instances compile
unchanged:
- get_workspace_size(Problem) -> bytes (default 0)
- run(a,b,c,d_ptrs, void* workspace, problem, stream) overload whose default
  forwards to the existing no-workspace run().

The Dispatcher invokes these through a base KernelInstance* pointer (so the new
overload is visible despite derived 6-arg run() overrides). The Stream-K backend
(PR-C) overrides both to size and bind the reduction workspace.

Validated on gfx90a (hipcc 7.12): a concrete instance overriding only the
pre-existing pure virtuals compiles; default get_workspace_size==0 and the
workspace-run forwards correctly via base pointer.
…tree codegen

Adds the C++ backend that lets Stream-K ride the registry, plus the reduction
strategy codegen needed to generate the three variants on this branch.

- generated_tile_backend_streamk.hpp (NEW): GeneratedStreamKKernelInstance wraps
  a generated Stream-K kernel and builds ck_tile::StreamKHostArgs (the
  ABI-incompatible args the GemmHostArgs path could not). supports() gates on
  Problem.streamk + reduction_strategy so atomic/linear/tree coexist in the
  registry and the Dispatcher's first-fit selection picks the requested one.
  create_generated_streamk_kernel<> mirrors create_generated_tile_kernel<>.
- codegen: reduction_strategy axis (atomic/linear/tree) -> KernelConfig field,
  key_name redux_*, KernelNaming "_streamk"/"_streamk_linear"/"_streamk_tree"
  (matches KernelKey::encode_identifier from PR-A), per-strategy
  StreamKReductionStrategy in the generated launch, and a streamk_config sweep
  axis. (Ported from the bridge branch reduction-strategy work.)

PR-C keeps the generated launch's internal workspace/reset; PR-D relocates those
to Dispatcher::run() via get_workspace_size()/the workspace-aware run().

Validated on gfx90a (hipcc 7.12): codegen emits 584 atomic + 584 linear + 584
tree headers with correct names; the backend device-compiles (22s) against a
generated header and supports() accepts the matching strategy while rejecting
the others and non-Stream-K problems.
…pace

Relocate the Stream-K reduction-workspace buffer from the per-call generated
launch() to a grow-on-demand buffer owned by the Dispatcher, so a long-lived
dispatcher stops paying a hipMalloc/hipFree on every invocation.

- codegen: hoist the StreamKGemmKernel type to struct scope and add
  GetWorkSpaceSize() + an external-workspace launch(args, cfg, workspace)
  overload. The existing 2-arg launch (internal DeviceMem) is unchanged so the
  bridge ctypes lib and the standalone 03 driver keep working.
- backend: override get_workspace_size() and the workspace-aware run(); the
  no-workspace run() delegates with a null buffer. The per-iteration reset stays
  in the backend (it needs CDataType + the reduction strategy).
- dispatcher: own a grow-on-demand workspace (raw void*/size_t to keep HIP out
  of the public header), size it via get_workspace_size(), and pass it through
  run_fused()/run_explicit(); free it in the destructor. Atomic needs none
  (size 0 -> null -> internal path); linear/tree consume the owned buffer.

Validated on MI210/gfx90a: atomic/linear/tree all verify vs reference_gemm at
unchanged perf, with linear/tree now running on the dispatcher-owned workspace.
…river

Add 04_streamk_registry_driver.cpp: a runnable proof of the full deep-core path
(Registry::register_kernel -> Dispatcher::run -> first-fit supports() gate on
reduction_strategy -> GeneratedStreamKKernelInstance::run -> generated launch ->
verify vs reference_gemm). Unlike 03_streamk_gemm_driver.cpp, which calls
SelectedKernel::launch() directly and bypasses the dispatcher, this exercises the
registry selection and the Dispatcher-owned workspace.

Selectable strategy via --strategy {atomic,linear,tree}. Validated on
MI210/gfx90a for all three (distinct registry identifiers, each PASS).
…K backend

The dispatcher-wrapper generator emitted ONE template for every variant:
backends::GeneratedKernelInstance<KernelStruct> with no streamk/reduction_strategy
on the key. For Stream-K that is wrong twice over -- the regular backend calls
launch(GemmHostArgs,...) which the SK kernel struct does not have (so the
aggregate register_all_kernels.hpp would not compile against SK), and the key
omits the SK fields so encode_identifier() emits no _streamk suffix and
atomic/linear/tree collide in the registry.

Make the wrapper variant-aware: for STREAM_K configs include
generated_tile_backend_streamk.hpp, set key.algorithm.streamk +
reduction_strategy + workspace (and pad flags for identifier parity), and return
create_generated_streamk_kernel<KernelStruct, KernelStruct::ADataType, ...>.
All other variants are unchanged.

Validated on MI210/gfx90a: a registry populated via the generated wrappers holds
atomic+linear+tree side by side; Dispatcher::run() selects each by
Problem::reduction_strategy and all three verify vs reference_gemm.
…are atomic reset

P2: GeneratedStreamKKernelInstance::supports() now ends with
SelectedKernel::IsSupported(make_args(problem)) (a new generated static that runs
MakeKernelArgs + IsSupportedArgument). A problem too small to partition across CUs
is rejected during selection, so first-fit falls back to a non-Stream-K kernel
instead of throwing std::runtime_error at launch.

P3: the atomic reduction reset zeroes C with a stride-aware hipMemset2DAsync
(pitch = stride_E * sizeof(C), width = N * sizeof(C), height = M) instead of a
flat hipMemsetAsync over M*N. Correct for a padded/strided C; identical coverage
for the contiguous rcr case. Applied to both the internal and external-workspace
launch overloads.

Validated on MI210/gfx90a: atomic/linear/tree still select + run + verify from a
multi-kernel registry; valid small problems are accepted (no false-negatives).
…eductionStrategy)

Close two review nits on the Stream-K drivers:
- Parse M/N/K with std::stoll instead of std::stoi in the 03/04 drivers so
  large GEMM dimensions no longer overflow/throw int range (Copilot nit).
- Add inline to_string(ReductionStrategy) in kernel_key.hpp and route the 04
  driver through it, removing the driver-local strategy_name() duplicate so
  callers share one spelling that matches the codegen suffix scheme.
Adds dispatcher_test_streamk_registry, a GPU test that generates the three
reduction-strategy kernels (atomic/linear/tree) from one tile config, builds the
04 registry driver once per strategy (each force-including its own header, since
SkReductionStrategy is a compile-time constexpr), and asserts for each that the
encode_identifier() suffix matches, the Dispatcher selects it by
Problem::reduction_strategy, and the result verifies against the reference.

This converts the previously manual deep-core validation into a regression-
guarded CTest. It SKIPs (return 77) when no GPU or hipcc is present, so CPU-only
CI is unaffected.
… driver

The standalone stream-K driver verified atomic results with the single-pass
GEMM tolerance get_*_threshold<...>(K). Atomic reduction accumulates K-split
partials directly into low-precision C (workspace size 0), incurring rounding
error that grows with the split factor -- correct results were flagged FAIL on
small-M/N, large-K shapes (e.g. 512x512x8192) where tiles < CUs.

Mirror tile_engine's calculate_rtol_atol (validation.hpp): derive kbatch from
the kernel's tile partitioner (estimate_num_wgs_per_tile), widen atol/rtol with
the split-K CDataType accumulation term, and take the max with the per-split
tolerance. The driver and tile_engine now verify identically; the kernel is
unchanged.
…gine

The standalone stream-K driver built its stream_config as {stream, true, 0,
warmup, repeat}, leaving is_gpu_timer/flush_cache/rotating_count at defaults
(flush_cache=false, rotating_count=1). The tile_engine benchmark instead times
with flush_cache=true and rotating_count=1000, so the driver measured a
warm-cache best case while tile_engine measured cold-cache -- the entire source
of the reported dispatcher-vs-TE "performance gap" at low tile counts.

Add --timer/--flush_cache/--rotating_count (defaulting to the tile_engine
values) and pass them through stream_config so both sides use identical timing
methodology. A validating run still times a single cold shot, mirroring
tile_engine's repeat_once_if_verify(); collect perf with a separate --validate 0
pass.
The 04 registry driver hardcoded the KernelKey signature to DataType::FP16
and an rcr layout, so fp8/bf8/bf16 Stream-K kernels registered under the
wrong key and failed dispatch/identifier checks. Derive dtype_a/b/c/acc and
layout tags from the generated kernel's actual A/B/C types via compile-time
dtype_enum_of<T>()/layout_tag_of<Layout>() helpers (fp8/bf8 inputs accumulate
in fp32 and write fp16 C, matching Tile Engine).

Parametrize test_streamk_registry.py over fp16/bf16/fp8/bf8 (dtype-independent
core objects built once; per-dtype codegen + build + verify with per-dtype
identifier assertions). All four datatypes register, dispatch, and verify
across atomic/linear/tree on gfx942 (MI300X).
Parametrize the registry test over all four layouts Tile Engine builds
Stream-K for (rcr/rrr/ccr/crr) in addition to datatypes and reduction
strategies. Full coverage is now {fp16,bf16,fp8,bf8} x {rcr,rrr,ccr,crr}
x {atomic,linear,tree}; all four layouts keep C row-major, which the
atomic C-reset relies on. The encode_identifier assertion is generalized
to {dtype}_{layout}, and --datatypes/--layouts flags allow trimming the
matrix for faster CI runs.
make_args() hardcoded rcr leading dims (stride_a=K, stride_b=K, stride_c=N)
for every layout, so non-rcr Stream-K kernels (rrr/ccr/crr) ran with wrong
strides and failed verification. Derive the leading dims from the kernel
key's layouts instead: A is MxK (row->K, col->M), B is KxN (row->N, col->K),
C is MxN (row->N, col->M). rcr is unchanged; rrr/ccr/crr now match the host
tensor strides the driver builds via get_default_stride.

Verified on gfx942/MI300X: full {fp16,bf16,fp8,bf8} x {rcr,rrr,ccr,crr} x
{atomic,linear,tree} matrix (48/48) registers, dispatches, and verifies.
…wing, timing, coverage)

- 04 registry driver: port the split-K-aware verification tolerance from the
  03 driver (a0ff521) so the deep-core path uses kbatch-derived rtol/atol
  instead of the plain single-pass threshold that spuriously FAILs correct
  atomic results on small-M/N, large-K shapes.
- Stream-K backend make_args: guard the int64->int32 (index_t) narrowing of
  M/N/K and derived leading dims; throw on overflow instead of silently
  wrapping (the parser was widened to std::stoll for exactly this reason).
- Stream-K backend run(): flush the L2 between timed iterations so the
  measurement is cold like tile_engine/the 03 driver (warm cache over-reported
  TFlops); document that this path is cold-but-non-rotated and not the
  calibrated apple-to-apple perf surface.
- codegen: document the Stream-K tile-coverage limitation (dispatcher emits a
  narrower tile surface than TE; equivalence is per matched tile config).

Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch,
and verify PASS.
…tadata, arch-filter, test)

- Backend validate(): stop returning a blind "true". Without a host reference it
  cannot do a numeric check, but it now validates what it can -- non-null
  operands, a well-formed problem, and that this instance supports() it -- so an
  unrunnable config is not mis-reported as valid.
- 04 registry driver make_streamk_key(): derive wave_shape/block_size/
  transpose_c/double_buffer/persistent/preshuffle/num_wave_groups from the
  generated kernel's own static traits instead of hardcoding (the wave_shape was
  wrong, e.g. {2,2,1} vs the actual {1,4,1}); the registry identifier now
  describes the kernel that was built. pipeline/scheduler stay fixed (baked into
  the kernel type, not on the SK selection axis) -- now documented.
- arch_filter GEMM_STREAMK: document that the min_tile constraints are copied
  from plain GEMM and the real feasibility gate (enough tiles to partition K
  across CUs) is runtime IsSupportedArgument/supports(), not these numbers.
- test_streamk_registry.py: replace the fragile rocminfo line-splitting in
  detect_arch with a regex; verify each built kernel against the CLI shape AND a
  small-M/large-K (128x128x16384) shape that stresses the split-K tolerance.

Compile-validated for gfx942; small-M/large-K shape verified PASS on MI300X.
Document the Stream-K deep-core path (the feature this branch adds): generate
kernels via unified_gemm_codegen --variants stream_k, build/run the 03 standalone
driver (perf surface) and the 04 registry driver (Registry -> Dispatcher ->
verify), run the CTest, plus the reduction strategies, supported dtypes/layouts,
the cold-cache perf methodology, the split-K-aware tolerance, and known
limitations (gfx950 fp8/bf8, tile-coverage gap). Linked from the dispatcher README.
…, thread-safety, row-major C assert)

Three correctness fixes from a tough review pass on the deep-core path:

1. Zero the reduction workspace in Dispatcher::ensure_workspace() before every
   Linear/Tree dispatch. The dispatcher previously only hipMalloc'd the buffer
   and relied on the backend's per-iteration preprocess reset to zero it, so on
   the non-benchmarking (nrepeat=1) path -- or with set_benchmarking(false) --
   the buffer could be garbage and corrupt the reduction. Now correctness is
   independent of whether the preprocess runs, mirroring the internal
   DeviceMem::SetZero() the standalone launch already does.

2. Serialize the shared workspace_ buffer with a mutex. workspace_ is mutable
   state reused across const run() calls; two concurrent linear/tree dispatches
   on different streams could share one buffer and corrupt each other. The lock
   spans size->zero->launch (the buffer is in use for the whole run). Atomic /
   non-Stream-K paths use no workspace and take no lock.

3. Emit a static_assert(is_row_major<CLayout>) in the generated Atomic launch.
   The atomic C-reset uses a row-major hipMemset2D; a column-major C would be
   zeroed incorrectly and silently corrupt atomic accumulation. The assert is
   atomic-only (Linear/Tree zero the workspace, not C).

Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree register, dispatch,
and verify PASS across 3840x4096x2048 and 128x128x16384.
…t preprocess

The atomic C-reset (hipMemset2DAsync) and the linear/tree workspace reset
(hipMemsetAsync) in the generated launch() were (void)-cast, so a failed reset
silently corrupted atomic accumulation / the reduction buffer with no signal.
Both now throw std::runtime_error on non-hipSuccess, consistent with the
IsSupportedArgument failure path in the same function. Applied to both the
internal-workspace and external-workspace launch overloads.

Verified on MI300X (gfx942): fp16 rcr atomic/linear/tree compile and verify PASS.
@therock-pr-bot

therock-pr-bot Bot commented Jun 30, 2026

Copy link
Copy Markdown

✅ All Checks Passed — Ready for Review

Check Status Details
🌿 Branch Name ✅ Pass
📝 PR Title/Description ✅ Pass
Forbidden Files ✅ Pass
🧪 Unit Test ✅ Pass
🔎 pre-commit ✅ Pass
🚫 Draft PR 🔜 To Be Enabled
🚩 Feature Flag 🔜 To Be Enabled
📊 Code Coverage 🔜 To Be Enabled
🤖 therock-pr-bot ✅ Pass

🎉 All checks passed! This PR is ready for review.

📖 Need help? See the Policy FAQ for details on every check and how to fix failures.

@therock-pr-bot

therock-pr-bot Bot commented Jun 30, 2026

Copy link
Copy Markdown

🎉 All checks passed! This PR is ready for review.

@ozturkosu ozturkosu changed the title feat(ck_tile): add stream_k variant to GEMM Dispatcher codegen feat(ck-tile): add stream_k variant to GEMM Dispatcher codegen Jun 30, 2026
// tile, taken from the kernel's own tile partitioner so the driver and
// tile_engine agree on the split factor.
auto kargs = SelectedKernel::StreamKGemmKernel::MakeKernelArgs(args);
const ck_tile::index_t kbatch =

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest altering the variable name of kbatch because kbatch refers to the split-k value in normal CK Tile terminology. Stream-K does not use a "split-k" since there may not be the same number of workgroups contributing to a tile. I recommend the variable name num_wgs_per_tile or something similar. Please update your comments accordingly as well.

auto kargs = SelectedKernel::StreamKGemmKernel::MakeKernelArgs(args);
const ck_tile::index_t kbatch =
std::max<ck_tile::index_t>(1, kargs.tile_partitioner.estimate_num_wgs_per_tile());

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this code common to other operators? If so, we should probably avoid duplicating it here. I recommend using the existing helper function (if present) or placing one in a common file.

(If this isn't common code, then disregard this comment)

// instead of assuming fp16. Keeps the registry identifier and selection correct
// across every datatype the codegen emits.
template <typename T>
static constexpr DataType dtype_enum_of()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this code common to other ops? If so, can we avoid duplication?

}

template <typename Layout>
static constexpr LayoutTag layout_tag_of()

@ecamartins ecamartins Jun 30, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing about duplication here.

key.signature.transpose_a = false;
key.signature.transpose_b = false;
key.signature.grouped = false;
key.signature.split_k = 1;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment here that for Stream-K, split-k must be 1?

// concurrent Stream-K linear/tree dispatches on different streams cannot
// corrupt each other's reduction. Atomic / non-Stream-K paths use no
// workspace and take no lock.
mutable std::mutex workspace_mutex_;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why we are using a mutex here? Are concurrent kernel runs using the same workspace? If so, we likely want to avoid this.


/// Stream-K partial-sum reduction strategy. `None` = not a Stream-K kernel.
/// Mirrors ck_tile::StreamKReductionStrategy (Atomic/Linear/Tree).
enum class ReductionStrategy : std::uint8_t

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid duplicating this Reduction strategy? Can we use the one defined in ck_tile::StreamKReductionStrategy instead? I just fear this adds unnecessary maintenance.

LAYOUTS = ["rcr", "rrr", "ccr", "crr"]


def detect_arch(fallback=None):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to avoid reliance on rocm-info. If possible, can you remove this dependency?

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds Stream-K GEMM as a first-class “deep-core” citizen in the CK Tile Dispatcher, including codegen support, registry/dispatcher selection by reduction strategy, dispatcher-owned workspace management, and end-to-end drivers + CTest coverage.

Changes:

  • Extend unified GEMM codegen to generate Stream-K kernels (atomic/linear/tree) and register them via a Stream-K-specific backend wrapper.
  • Add dispatcher-owned, reusable reduction workspace (with serialization) for Stream-K linear/tree dispatch.
  • Add standalone + registry drivers, documentation, and a GPU/hipcc-gated CTest to validate dispatch + correctness across datatype/layout/strategy combinations.

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
projects/composablekernel/dispatcher/codegen/arch_filter.py Adds GEMM_STREAMK operator type and tile-shape constraints for arch filtering.
projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py Adds Stream-K variant generation, naming, headers, launch paths (internal & external workspace), and wrapper/backend selection.
projects/composablekernel/dispatcher/examples/gemm/cpp/03_streamk_gemm_driver.cpp New direct-launch Stream-K driver for perf/correctness (bypasses dispatcher).
projects/composablekernel/dispatcher/examples/gemm/cpp/04_streamk_registry_driver.cpp New deep-core driver exercising Registry → Dispatcher → Stream-K backend with split-K-aware verification tolerance.
projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend_streamk.hpp New KernelInstance wrapper for Stream-K kernels (supports, workspace sizing, workspace-aware run).
projects/composablekernel/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp Adds dispatcher-owned Stream-K workspace state + mutex and declares ensure_workspace + destructor.
projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_instance.hpp Adds workspace-sizing and workspace-aware run() virtuals (defaulting to no-workspace behavior).
projects/composablekernel/dispatcher/include/ck_tile/dispatcher/kernel_key.hpp Adds ReductionStrategy and Stream-K algorithm fields; updates identifier encoding and hashing.
projects/composablekernel/dispatcher/include/ck_tile/dispatcher/problem.hpp Adds Stream-K request fields and ProblemBuilder::stream_k().
projects/composablekernel/dispatcher/src/dispatcher.cpp Implements workspace allocation/zeroing and serialized workspace use in run paths; adds destructor.
projects/composablekernel/dispatcher/tests/CMakeLists.txt Adds a GPU+hipcc-gated CTest target for Stream-K registry testing (SKIP_RETURN_CODE=77).
projects/composablekernel/dispatcher/tests/test_streamk_registry.py New Python test that codegens/builds/runs/verifies Stream-K across datatype/layout/strategy matrix.
projects/composablekernel/dispatcher/README.md Links to Stream-K documentation.
projects/composablekernel/dispatcher/STREAMK.md New documentation for generating, running, validating, and testing Stream-K deep-core path.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +236 to +240
out = r.stdout
ok_verify = "Verification: PASS" in out
ok_suffix = f"identifier={dtype}_{layout}" in out and want_suffix in out.split(
"identifier="
)[1].split()[0]
Comment on lines +53 to +61
// Zero the region the kernel will use. Linear/Tree reductions accumulate into
// this buffer and read it before writing, so a stale/garbage buffer corrupts
// results. Doing it here makes correctness independent of whether the backend's
// per-iteration preprocess reset runs (e.g. on the non-benchmarking nrepeat=1
// path), mirroring the internal DeviceMem::SetZero() the standalone launch does.
if(bytes > 0 && hipMemset(workspace_, 0, bytes) != hipSuccess)
{
throw std::runtime_error("Dispatcher: failed to zero Stream-K reduction workspace");
}
@ThruptiRajLakshmanaGowda

Copy link
Copy Markdown
Contributor

LGTM!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants