Skip to content

Commit 02faae5

Browse files
committed
[ET-VK] Address coopmat dispatch review feedback
Three correctness fixes flagged on PR #19009. 1. The linear_coopmat / matmul_coopmat dispatch gate previously only checked `M >= 64`. The gates in `Linear.cpp` and `Matmul.cpp` are now tightened to require `M % TILE_M == 0 && N % TILE_N == 0 && K % TILE_K == 0`, so misaligned shapes correctly fall back to the tiled shader. 2. The bias path in `linear_coopmat.glsl` previously read the just-written output buffer back, added bias, and wrote it again. Bias is now folded into the fp32 accumulator before `coopMatStore`, so the output binding becomes `w` (write-only) instead of `rw`. 3. In both `linear_coopmat.glsl` and `matmul_coopmat.glsl`, the fp16 shared-memory staging now uses `packFloat2x16` on the `f16vec2` halves directly, replacing `packHalf2x16(vec2(...))` and avoiding the fp16 -> fp32 -> fp16 round trip.
1 parent b26728a commit 02faae5

6 files changed

Lines changed: 281 additions & 140 deletions

File tree

backends/vulkan/runtime/graph/ops/glsl/linear_coopmat.glsl

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
* Output is always fp32 (fp32 accumulator -> fp32 store) when DTYPE=float,
2424
* or fp16 when DTYPE=half.
2525
*
26-
* Optional bias: when HAS_BIAS is defined, bias is added post-store via
27-
* read-modify-write on the output buffer (one pass over the tile).
26+
* Optional bias: when HAS_BIAS is defined, bias is staged once into shared
27+
* memory and broadcast into each accumulator tile (stride-0 coopMatLoad)
28+
* before the store, so t_output is write-only.
2829
*/
2930

3031
#version 450 core
@@ -51,10 +52,7 @@ layout(std430) buffer;
5152
#include "common.glslh"
5253

5354
// Bindings: output(0), mat1(1), weight_packed(2), [bias(3)]
54-
$if HAS_BIAS:
55-
${layout_declare_tensor(B, "rw", "t_output", DTYPE, "buffer", is_scalar_array=True)}
56-
$else:
57-
${layout_declare_tensor(B, "w", "t_output", DTYPE, "buffer", is_scalar_array=True)}
55+
${layout_declare_tensor(B, "w", "t_output", DTYPE, "buffer", is_scalar_array=True)}
5856
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer", is_scalar_array=False)}
5957
${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, "buffer", is_scalar_array=False)}
6058
$if HAS_BIAS:
@@ -94,6 +92,12 @@ const uint B_STRIDE_VEC4 = (TILE_N + FP16_PER_VEC4) / FP16_PER_VEC4; // 9
9492
shared uvec4 Ash[TILE_M * A_STRIDE_VEC4]; // 5KB
9593
shared uvec4 Bsh[TILE_K * B_STRIDE_VEC4]; // 4.5KB
9694

95+
#ifdef HAS_BIAS
96+
// fp32 staging buffer so coopMatLoad can broadcast directly into the
97+
// fp32 accumulator coopmat without a type conversion at the load.
98+
shared float bias_sh[TILE_N]; // 256B
99+
#endif
100+
97101
// Accumulator tiles (fp32)
98102
coopmat<float, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> result[C_ROWS][C_COLS];
99103

@@ -146,8 +150,8 @@ void main() {
146150
f16vec4 v0 = t_mat1[row * K4 + k_hv4];
147151
f16vec4 v1 = t_mat1[row * K4 + k_hv4 + 1];
148152
Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
149-
packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
150-
packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
153+
packFloat2x16(v0.xy), packFloat2x16(v0.zw),
154+
packFloat2x16(v1.xy), packFloat2x16(v1.zw));
151155
#else
152156
uint k_vec4 = k_elem / 4;
153157
vec4 v0 = t_mat1[row * K4 + k_vec4];
@@ -173,8 +177,8 @@ void main() {
173177
f16vec4 v0 = t_weight_packed[(k4 * N4 + n4_0) * 4u + dk];
174178
f16vec4 v1 = t_weight_packed[(k4 * N4 + n4_0 + 1u) * 4u + dk];
175179
Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(
176-
packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
177-
packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
180+
packFloat2x16(v0.xy), packFloat2x16(v0.zw),
181+
packFloat2x16(v1.xy), packFloat2x16(v1.zw));
178182
#else
179183
vec4 v0 = t_weight_packed[(k4 * N4 + n4_0) * 4u + dk];
180184
vec4 v1 = t_weight_packed[(k4 * N4 + n4_0 + 1u) * 4u + dk];
@@ -218,11 +222,37 @@ void main() {
218222
barrier();
219223
}
220224

221-
// --- Store result ---
225+
#ifdef HAS_BIAS
226+
// Stage one TILE_N-wide row of bias into shared memory. The C++ dispatch
227+
// gate ensures N % TILE_N == 0, so no per-element bounds check is needed.
228+
{
229+
const uint tile_n_start = TILE_N * tileID.x;
230+
for (uint t = gl_LocalInvocationID.x; t < TILE_N; t += INVOCATIONS) {
231+
bias_sh[t] = float(t_bias[tile_n_start + t]);
232+
}
233+
}
234+
memoryBarrierShared();
235+
barrier();
236+
#endif
237+
238+
// --- Store result (with bias folded in pre-store, if present) ---
222239
[[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
223240
[[unroll]] for (uint j = 0; j < C_COLS; ++j) {
224241
uint gi = TILE_M * tileID.y + lM * (C_ROWS * warpInTile.y + i);
225242
uint gj = TILE_N * tileID.x + lN * (C_COLS * warpInTile.x + j);
243+
244+
#ifdef HAS_BIAS
245+
// Stride-0 row-major load broadcasts lN bias values across all
246+
// lM rows of the accumulator tile.
247+
uint local_n = lN * (C_COLS * warpInTile.x + j);
248+
coopmat<float, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> bias_tile;
249+
coopMatLoad(
250+
bias_tile, bias_sh,
251+
local_n, /*stride=*/0u,
252+
gl_CooperativeMatrixLayoutRowMajor);
253+
result[i][j] += bias_tile;
254+
#endif
255+
226256
#ifdef IS_FP16_INPUT
227257
coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> out_tile =
228258
coopmat<float16_t, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator>(result[i][j]);
@@ -238,24 +268,4 @@ void main() {
238268
#endif
239269
}
240270
}
241-
242-
#ifdef HAS_BIAS
243-
// Add bias via read-modify-write on the output buffer.
244-
// barrier() ensures all coopMatStore writes within this workgroup are visible.
245-
barrier();
246-
247-
const uint tile_m_start = TILE_M * tileID.y;
248-
const uint tile_n_start = TILE_N * tileID.x;
249-
// 64x64 tile = 4096 elements, 256 threads -> 16 elements per thread
250-
for (uint idx = gl_LocalInvocationID.x; idx < TILE_M * TILE_N; idx += INVOCATIONS) {
251-
uint local_m = idx / TILE_N;
252-
uint local_n = idx % TILE_N;
253-
uint gm = tile_m_start + local_m;
254-
uint gn = tile_n_start + local_n;
255-
if (gm < M && gn < N) {
256-
uint out_idx = gm * N + gn;
257-
t_output[out_idx] = t_output[out_idx] + t_bias[gn];
258-
}
259-
}
260-
#endif
261271
}

backends/vulkan/runtime/graph/ops/glsl/matmul_coopmat.glsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ void main() {
134134
f16vec4 v0 = t_mat1[row * K4 + k_hv4];
135135
f16vec4 v1 = t_mat1[row * K4 + k_hv4 + 1];
136136
Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
137-
packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
138-
packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
137+
packFloat2x16(v0.xy), packFloat2x16(v0.zw),
138+
packFloat2x16(v1.xy), packFloat2x16(v1.zw));
139139
#else
140140
// fp32 inputs: load two vec4 (8 fp32), convert to 8 fp16
141141
uint k_vec4 = k_elem / 4;
@@ -157,8 +157,8 @@ void main() {
157157
f16vec4 v0 = t_mat2[row * N4 + n_hv4];
158158
f16vec4 v1 = t_mat2[row * N4 + n_hv4 + 1];
159159
Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(
160-
packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
161-
packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
160+
packFloat2x16(v0.xy), packFloat2x16(v0.zw),
161+
packFloat2x16(v1.xy), packFloat2x16(v1.zw));
162162
#else
163163
uint n_vec4 = n_elem / 4;
164164
vec4 v0 = t_mat2[row * N4 + n_vec4];

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ void add_linear_tiled_node(
243243

244244
static constexpr uint32_t kLinearCoopMatTileM = 64;
245245
static constexpr uint32_t kLinearCoopMatTileN = 64;
246+
static constexpr uint32_t kLinearCoopMatTileK = 32;
246247
static constexpr uint32_t kLinearCoopMatInvocations = 256; // 4 subgroups x 64
247248

248249
vkapi::ShaderInfo pick_linear_coopmat_shader(
@@ -251,8 +252,7 @@ vkapi::ShaderInfo pick_linear_coopmat_shader(
251252
const std::vector<ValueRef>& resize_args) {
252253
const ValueRef out = args.at(0).refs.at(0);
253254
bool has_bias = graph->get_bool(resize_args.at(1));
254-
std::string kernel_name =
255-
has_bias ? "linear_coopmat_bias" : "linear_coopmat";
255+
std::string kernel_name = has_bias ? "linear_coopmat_bias" : "linear_coopmat";
256256
kernel_name.reserve(kShaderNameReserve);
257257
add_dtype_suffix(kernel_name, graph->dtype_of(out));
258258
return VK_KERNEL_FROM_STR(kernel_name);
@@ -342,27 +342,38 @@ void linear_packed_weight(
342342
ValueRef out = args.at(3);
343343

344344
bool has_bias = graph.val_is_not_none(bias);
345-
// Coopmat shader assumes M is a multiple of TILE_M (64) because the store
346-
// does not bounds-check. Fall back to the tiled shader otherwise.
347-
// TODO: remove this guard once the coopmat shader gains partial-tile
348-
// bounds checking.
345+
// Coopmat shader has no partial-tile / K-tail handling: the store overruns
346+
// unless M and N are multiples of the output tile, and the K-loop reads past
347+
// the end unless K is a multiple of TILE_K. Fall back to the tiled shader
348+
// when alignment is not met.
349+
// TODO: remove this guard once the coopmat shader gains partial-tile +
350+
// K-tail bounds checking.
349351
auto input_sizes = graph.sizes_of(input);
350-
int64_t M = input_sizes.size() >= 2
351-
? input_sizes.at(input_sizes.size() - 2)
352-
: 1;
352+
auto out_sizes_vec = graph.sizes_of(out);
353+
int64_t M =
354+
input_sizes.size() >= 2 ? input_sizes.at(input_sizes.size() - 2) : 1;
355+
int64_t K = input_sizes.back();
356+
int64_t N = out_sizes_vec.back();
353357
bool use_coopmat =
354358
graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
355359
graph.storage_type_of(out) == utils::kBuffer &&
356-
M >= 64;
360+
M % kLinearCoopMatTileM == 0 && N % kLinearCoopMatTileN == 0 &&
361+
K % kLinearCoopMatTileK == 0;
357362

358363
ValueRef packed_weight = prepack_fp_linear_weight(
359-
graph, weight_data, /*is_transposed=*/true, /*B=*/1,
364+
graph,
365+
weight_data,
366+
/*is_transposed=*/true,
367+
/*B=*/1,
360368
/*force_buffer=*/use_coopmat);
361369

362370
ValueRef packed_bias = kDummyValueRef;
363371
if (has_bias) {
364372
packed_bias = prepack_standard(
365-
graph, bias, graph.storage_type_of(out), utils::kWidthPacked,
373+
graph,
374+
bias,
375+
graph.storage_type_of(out),
376+
utils::kWidthPacked,
366377
/*passthrough=*/use_coopmat);
367378
}
368379

backends/vulkan/runtime/graph/ops/impl/Matmul.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ void resize_matmul_tiled_node(
2929

3030
static constexpr uint32_t kCoopMatTileM = 64;
3131
static constexpr uint32_t kCoopMatTileN = 64;
32+
static constexpr uint32_t kCoopMatTileK = 32;
3233
static constexpr uint32_t kCoopMatInvocations = 256; // 4 subgroups × 64
3334

3435
vkapi::ShaderInfo pick_matmul_coopmat_shader(
@@ -275,18 +276,35 @@ void matmul_tiled(ComputeGraph& graph, const std::vector<ValueRef>& args) {
275276
ValueRef mat2 = args[1];
276277
ValueRef out = args[2];
277278

279+
// Coopmat path requires M%TILE_M==0, N%TILE_N==0, K%TILE_K==0 — the shader
280+
// has no partial-tile or K-tail handling.
281+
auto mat1_sizes = graph.sizes_of(mat1);
282+
int64_t M = mat1_sizes.at(mat1_sizes.size() - 2);
283+
int64_t K = mat1_sizes.back();
284+
int64_t N = graph.sizes_of(out).back();
285+
const bool coopmat_aligned = M % kCoopMatTileM == 0 &&
286+
N % kCoopMatTileN == 0 && K % kCoopMatTileK == 0;
287+
278288
if (graph.val_is_tref(mat2)) {
279289
auto mat2_sizes = graph.sizes_of(mat2);
280290
int64_t B = mat2_sizes.size() >= 3 ? mat2_sizes.at(0) : 1;
281291
bool use_coopmat =
282292
graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
283-
graph.storage_type_of(out) == utils::kBuffer;
293+
graph.storage_type_of(out) == utils::kBuffer && coopmat_aligned;
284294
ValueRef packed = prepack_fp_linear_weight(
285-
graph, mat2, /*is_transposed=*/false, B,
295+
graph,
296+
mat2,
297+
/*is_transposed=*/false,
298+
B,
286299
/*force_buffer=*/use_coopmat);
287300
if (use_coopmat) {
288301
add_linear_coopmat_node(
289-
graph, mat1, packed, kDummyValueRef, false, out,
302+
graph,
303+
mat1,
304+
packed,
305+
kDummyValueRef,
306+
false,
307+
out,
290308
utils::safe_downcast<int32_t>(B));
291309
} else {
292310
add_linear_tiled_node(
@@ -300,7 +318,7 @@ void matmul_tiled(ComputeGraph& graph, const std::vector<ValueRef>& args) {
300318
}
301319
} else if (
302320
graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
303-
graph.storage_type_of(out) == utils::kBuffer) {
321+
graph.storage_type_of(out) == utils::kBuffer && coopmat_aligned) {
304322
add_matmul_coopmat_node(graph, mat1, mat2, out);
305323
} else {
306324
add_matmul_tiled_node(graph, mat1, mat2, out);

0 commit comments

Comments
 (0)