Skip to content

Commit 3985cfd

Browse files
committed
[CK_TILE] Stream-K bridge: derive strides from layout (no rcr hardcoding)
Make the Stream-K bridge layout-generic instead of rcr-hardcoded, so that all four A/B/C layouts (rcr/rrr/ccr/crr) work end to end:
- streamk_gemm_ctypes_lib.cpp: derive stride_A/B/C at compile time from the kernel's ALayout/BLayout/CLayout (a RowMajor R x C matrix has ld = C, a ColumnMajor one has ld = R) instead of the hardcoded K/K/N.
- generated_tile_backend_streamk.hpp (registry path): same layout-derived strides.
- GpuGemmRunner: read dtype AND layout off the kernel name; arrange each operand per layout (RowMajor = C-contiguous, ColumnMajor = F-contiguous); the bf16 encode is now memory-order-preserving, so column-major operands stay column-major.
- run_one_streamk_gemm_kernel.py: dtype/layout-aware A/B and reference (was fp16-only).
- streamk_gemm_full_benchmark.py: SUPPORTED_LAYOUTS is now rcr/rrr/ccr/crr and SUPPORTED_DTYPES is fp16 + bf16 (fp8/bf8/int8 still need runner codecs).
1 parent 68c7d39 commit 3985cfd

5 files changed

Lines changed: 93 additions & 44 deletions

File tree

projects/composablekernel/dispatcher/bindings/ctypes/streamk_gemm_ctypes_lib.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <cstring>
3737
#include <exception>
3838
#include <string>
39+
#include <type_traits>
3940

4041
// Kernel header included via -include compiler flag (with CK_TILE_SINGLE_KERNEL_INCLUDE).
4142
// Defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME
@@ -100,11 +101,13 @@ int dispatcher_init() { return dispatcher_initialize(); }
100101
*
101102
* hipMalloc A/B/C, copy A and B host->device, memset C (the Atomic reduction
102103
* strategy accumulates into C, so it must start zeroed), build a
103-
* ck_tile::StreamKHostArgs with rcr default strides (stride_A=K, stride_B=K,
104-
* stride_C=N) and launch. The launch allocates the reduction workspace
105-
* internally and resets C between timed iterations. C is then copied back.
104+
* ck_tile::StreamKHostArgs whose strides are derived from the kernel's actual
105+
* ALayout/BLayout/CLayout (no layout hardcoding) and launch. The launch
106+
* allocates the reduction workspace internally and resets C between timed
107+
* iterations. C is then copied back.
106108
*
107-
* Layout contract (rcr): A row-major MxK, B col-major KxN, C row-major MxN.
109+
* The host buffers must be laid out to match each operand's layout (the Python
110+
* runner arranges A/B/C as RowMajor=C-contiguous, ColumnMajor=F-contiguous).
108111
*
109112
* Returns: 0 on success, -1 on HIP error / generic throw, -2 if the kernel
110113
* reports the arguments are unsupported.
@@ -166,17 +169,28 @@ int dispatcher_run_gemm(
166169
return -1;
167170
}
168171

169-
// rcr default strides: A row-major (stride=K), B col-major (stride=K),
170-
// C row-major (stride=N). k_batch is fixed to 1 inside StreamKHostArgs.
172+
// Strides are DERIVED from the kernel's actual layouts (ALayout/BLayout/CLayout
173+
// come from the force-included generated header) -- nothing layout-specific is
174+
// hardcoded, so every layout (rcr/rrr/ccr/crr/...) works. A RowMajor R x C
175+
// matrix has leading dim C; a ColumnMajor one has leading dim R.
176+
// A is M x K, B is K x N, C is M x N.
177+
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
178+
const ck_tile::index_t lda = static_cast<ck_tile::index_t>(
179+
std::is_same_v<ALayout, RowMajor> ? K : M);
180+
const ck_tile::index_t ldb = static_cast<ck_tile::index_t>(
181+
std::is_same_v<BLayout, RowMajor> ? N : K);
182+
const ck_tile::index_t ldc = static_cast<ck_tile::index_t>(
183+
std::is_same_v<CLayout, RowMajor> ? N : M);
184+
// k_batch is fixed to 1 inside StreamKHostArgs.
171185
ck_tile::StreamKHostArgs args(static_cast<const void*>(A_dev),
172186
static_cast<const void*>(B_dev),
173187
static_cast<void*>(C_dev),
174188
static_cast<ck_tile::index_t>(M),
175189
static_cast<ck_tile::index_t>(N),
176190
static_cast<ck_tile::index_t>(K),
177-
/*stride_A=*/static_cast<ck_tile::index_t>(K),
178-
/*stride_B=*/static_cast<ck_tile::index_t>(K),
179-
/*stride_C=*/static_cast<ck_tile::index_t>(N));
191+
/*stride_A=*/lda,
192+
/*stride_B=*/ldb,
193+
/*stride_C=*/ldc);
180194

181195
// Benchmark parameters. warmup/repeat default to old Tile Engine's values
182196
// (warmup=50, repeat=100); a generous warmup keeps the GPU clock ramped, and

projects/composablekernel/dispatcher/include/ck_tile/dispatcher/backends/generated_tile_backend_streamk.hpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
1111
#include <hip/hip_runtime.h>
1212
#include <string>
13+
#include <type_traits>
1314

1415
namespace ck_tile {
1516
namespace dispatcher {
@@ -155,26 +156,35 @@ class GeneratedStreamKKernelInstance : public KernelInstance
155156
}
156157

157158
private:
158-
/// Build StreamKHostArgs for `problem`. rcr strides: row-major A (K),
159-
/// column-major B (K), row-major C (N). k_batch is owned by the Stream-K tile
160-
/// partitioner, not passed here. Pointers default to null for sizing-only use
161-
/// (GetWorkSpaceSize). StreamKHostArgs uses ck_tile::index_t (int32); cast
162-
/// from Problem's int64.
159+
/// Build StreamKHostArgs for `problem`. Strides are DERIVED from the kernel's
160+
/// actual layouts (the force-included single kernel exposes the global
161+
/// ALayout/BLayout/CLayout), not hardcoded to rcr. k_batch is owned by the
162+
/// Stream-K tile partitioner, not passed here. Pointers default to null for
163+
/// sizing-only use (GetWorkSpaceSize). StreamKHostArgs uses ck_tile::index_t
164+
/// (int32); cast from Problem's int64.
163165
ck_tile::StreamKHostArgs make_args(const Problem& problem,
164166
const void* a_ptr = nullptr,
165167
const void* b_ptr = nullptr,
166168
void* c_ptr = nullptr) const
167169
{
168-
using idx = ck_tile::index_t;
170+
using idx = ck_tile::index_t;
171+
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
172+
// A is MxK, B is KxN, C is MxN; RowMajor RxC has leading dim C, else R.
173+
const idx lda = static_cast<idx>(
174+
std::is_same_v<::ALayout, RowMajor> ? problem.K : problem.M);
175+
const idx ldb = static_cast<idx>(
176+
std::is_same_v<::BLayout, RowMajor> ? problem.N : problem.K);
177+
const idx ldc = static_cast<idx>(
178+
std::is_same_v<::CLayout, RowMajor> ? problem.N : problem.M);
169179
return ck_tile::StreamKHostArgs{a_ptr,
170180
b_ptr,
171181
c_ptr,
172182
static_cast<idx>(problem.M),
173183
static_cast<idx>(problem.N),
174184
static_cast<idx>(problem.K),
175-
static_cast<idx>(problem.K),
176-
static_cast<idx>(problem.K),
177-
static_cast<idx>(problem.N)};
185+
lda,
186+
ldb,
187+
ldc};
178188
}
179189

180190
KernelKey key_;

projects/composablekernel/dispatcher/python/gemm_utils.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -405,44 +405,57 @@ def __init__(self, lib_path: Path):
405405
raise RuntimeError(f"Failed to initialize dispatcher .so: {lib_path}")
406406
names = self.lib.kernel_names
407407
self._kernel_name = names[0] if names else "unknown"
408-
# Input dtype is encoded in the kernel name: gemm_<dtype>_<layout>_...
408+
# dtype and layout are encoded in the kernel name: gemm_<dtype>_<layout>_...
409+
# layout is the 3-char A/B/C major code (e.g. 'rcr'). Nothing layout- or
410+
# dtype-specific is hardcoded -- both are read off the compiled kernel (falling back to fp16/rcr only if the name cannot be parsed).
409411
parts = self._kernel_name.split("_")
410412
self._dtype = parts[1] if len(parts) > 1 else "fp16"
413+
lay = parts[2] if len(parts) > 2 and len(parts[2]) == 3 else "rcr"
414+
self._layout = lay if set(lay) <= {"r", "c"} else "rcr"
411415

412416
@property
413417
def kernel_name(self) -> str:
414418
return self._kernel_name
415419

416420
@staticmethod
417421
def _bf16_encode(x: np.ndarray) -> np.ndarray:
418-
"""float -> bfloat16 bits (uint16), round-to-nearest-even. ENCODE need only
419-
be nearest-representable; DECODE must be bit-exact to device bf16_t so the
420-
numpy reference multiplies the same values the GPU does."""
421-
u = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
422+
"""float -> bfloat16 bits (uint16), round-to-nearest-even, PRESERVING the
423+
input's memory order (C or F) so column-major operands stay column-major.
424+
ENCODE need only be nearest-representable; DECODE must be bit-exact to
425+
device bf16_t so the numpy reference multiplies what the GPU does."""
426+
f = np.asarray(x, dtype=np.float32)
427+
if not (f.flags["C_CONTIGUOUS"] or f.flags["F_CONTIGUOUS"]):
428+
f = np.ascontiguousarray(f)
429+
u = f.view(np.uint32)
422430
rounded = (u + 0x7FFF + ((u >> 16) & 1)) >> 16
423431
return rounded.astype(np.uint16)
424432

425433
@staticmethod
426434
def _bf16_decode(u16: np.ndarray) -> np.ndarray:
427435
return (u16.astype(np.uint32) << 16).view(np.float32)
428436

437+
def _to_buf(self, X: np.ndarray, major: str) -> np.ndarray:
438+
"""Lay out an operand in the order its layout implies: RowMajor ->
439+
C-contiguous, ColumnMajor -> F-contiguous. The .so reads a flat buffer
440+
with the matching stride, so the raw byte order is what matters."""
441+
arr = np.ascontiguousarray(X) if major == "r" else np.asfortranarray(X)
442+
if self._dtype == "bf16":
443+
return self._bf16_encode(arr)
444+
return arr.astype(np.float16, order="K")
445+
429446
def run(
430447
self, A: np.ndarray, B: np.ndarray, problem: GemmProblem
431448
) -> GemmResult:
432449
M, N, K = problem.M, problem.N, problem.K
433450

434-
# A is row-major MxK; B is supplied KxN and stored column-major (the
435-
# 'c' in rcr), matching how the kernel expects its operands. bf16 is passed
436-
# as raw uint16 bits (the ctypes ABI is void* + sizeof, so 2-byte bf16 and
437-
# fp16 share the path; only the bit pattern differs).
438-
if self._dtype == "bf16":
439-
A_h = self._bf16_encode(A)
440-
B_h = self._bf16_encode(np.ascontiguousarray(B.T))
441-
C_h = np.zeros((M, N), dtype=np.uint16)
442-
else:
443-
A_h = np.ascontiguousarray(A, dtype=np.float16)
444-
B_h = np.ascontiguousarray(B.T, dtype=np.float16)
445-
C_h = np.zeros((M, N), dtype=np.float16)
451+
# Arrange A (MxK), B (KxN), C (MxN) per the kernel's actual layout. bf16 is
452+
# passed as raw uint16 bits (the ctypes ABI is void*+sizeof, so 2-byte bf16
453+
# and fp16 share the path; only the bit pattern differs).
454+
la, lb, lc = self._layout[0], self._layout[1], self._layout[2]
455+
A_h = self._to_buf(A, la)
456+
B_h = self._to_buf(B, lb)
457+
cdt = np.uint16 if self._dtype == "bf16" else np.float16
458+
C_h = np.zeros((M, N), dtype=cdt, order=("C" if lc == "r" else "F"))
446459

447460
status, time_ms = self.lib.run(A_h, B_h, C_h, M, N, K)
448461

projects/composablekernel/tile_engine/ops/gemm/run_one_streamk_gemm_kernel.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ def _run_one(idx, so_path, prob_dict, kernel_name, verify=False, verify_tol=2e-2
5555
problem = GemmProblem.from_dict(prob_dict)
5656

5757
np.random.seed(42)
58-
A = (np.random.randn(problem.M, problem.K) * 0.1).astype(np.float16)
59-
B = (np.random.randn(problem.K, problem.N) * 0.1).astype(np.float16)
58+
A = (np.random.randn(problem.M, problem.K) * 0.1).astype(np.float32)
59+
B = (np.random.randn(problem.K, problem.N) * 0.1).astype(np.float32)
6060

61-
# CRITICAL: load the library ONLY inside this subprocess.
61+
# CRITICAL: load the library ONLY inside this subprocess. The runner reads
62+
# dtype + layout off the kernel name and arranges/encodes A/B accordingly.
6263
runner = GpuGemmRunner(lib_path=so_path)
6364
result = runner.run(A, B, problem)
6465

@@ -77,7 +78,16 @@ def _run_one(idx, so_path, prob_dict, kernel_name, verify=False, verify_tol=2e-2
7778
"kernel": kernel_name,
7879
}
7980
if verify:
80-
ref = A.astype(np.float32) @ B.astype(np.float32)
81+
# Reference uses the SAME quantized inputs the device sees, per the
82+
# kernel's dtype (bf16 round-to-nearest-even vs fp16), so the metric isolates
83+
# compute error from input quantization.
84+
if getattr(runner, "_dtype", "fp16") == "bf16":
85+
Aq = GpuGemmRunner._bf16_decode(GpuGemmRunner._bf16_encode(A))
86+
Bq = GpuGemmRunner._bf16_decode(GpuGemmRunner._bf16_encode(B))
87+
else:
88+
Aq = A.astype(np.float16).astype(np.float32)
89+
Bq = B.astype(np.float16).astype(np.float32)
90+
ref = Aq @ Bq
8191
got = result.output.astype(np.float32)
8292
denom = float(np.max(np.abs(ref))) or 1.0
8393
max_rel = float(np.max(np.abs(got - ref)) / denom)

projects/composablekernel/tile_engine/ops/gemm/streamk_gemm_full_benchmark.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,13 @@
6060
{"M": 512, "N": 512, "K": 8192},
6161
]
6262

63-
# Bridge surface for Stream-K: fp16/rcr only, matching the dispatcher host path
64-
# in streamk_gemm_ctypes_lib.cpp and the fp16 worker in
65-
# run_one_streamk_gemm_kernel.py.
66-
SUPPORTED_DTYPES = ("fp16",)
67-
SUPPORTED_LAYOUTS = ("rcr",)
63+
# Bridge surface for Stream-K. The dispatcher host path
64+
# (streamk_gemm_ctypes_lib.cpp) derives strides from the kernel's layouts and the
65+
# worker (run_one_streamk_gemm_kernel.py) reads dtype/layout off the kernel name,
66+
# so all 4 A/B/C layouts are supported; dtypes cover fp16 + bf16 (the codecs the
67+
# bridge runner implements). fp8/bf8/int8 await runner codecs.
68+
SUPPORTED_DTYPES = ("fp16", "bf16")
69+
SUPPORTED_LAYOUTS = ("rcr", "rrr", "ccr", "crr")
6870

6971

7072
def detect_devices():

0 commit comments

Comments
 (0)