-
Notifications
You must be signed in to change notification settings - Fork 980
[ET-VK] Add VK_KHR_cooperative_matrix dispatch for linear/matmul #19009
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xuyanwen2012
wants to merge
8
commits into
pytorch:main
Choose a base branch
from
sarc-acl:yanwen/pr-amend-staging
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
08d748d
[ET-VK] Add Adapter::supports_cooperative_matrix() helper
xuyanwen2012 b26728a
Add linear_coopmat + matmul_coopmat drop-in shader variants
xuyanwen2012 02faae5
[ET-VK] Address coopmat dispatch review feedback
xuyanwen2012 9605ece
[ET-VK] Restructure coopmat dispatch and tests per review feedback
xuyanwen2012 ab96e47
[ET-VK] Tighten coopmat eligibility gate and address bot review nits
xuyanwen2012 695824f
[ET-VK] Address review: gate coopmat on subgroup-size and discrete-GPU
xuyanwen2012 9d1e49d
[ET-VK] Address bot review: bias prepack passthrough and test sweep gate
xuyanwen2012 e921d65
[ET-VK] Add coopmat bounds-check + document dispatch-math idiom
xuyanwen2012 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,304 @@ | ||
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
 * KHR Cooperative Matrix MM kernel — unified linear + matmul.
 *
 * Variants (set in coopmat_mm.yaml):
 *   matmul_coopmat       row_major weight, no bias (aten.mm runtime mat2)
 *   linear_coopmat       prepacked weight, no bias (aten.linear)
 *   linear_coopmat_bias  prepacked weight, +bias   (aten.linear w/ bias)
 *
 * Computes: D = A * B[ + bias]
 *   A is [M, K] (row-major).
 *   B is either [K, N] row-major (matmul), or 4OC x 4IC blocked prepacked
 *   with t_weight[(k4 * N4 + n4) * 4 + dk] returning a vec4 of 4 N-elements
 *   at K-row k4*4+dk (linear).
 *   D is [M, N], buffer storage.
 *
 * fp16 x fp16 -> fp32 MMA. When DTYPE=half, inputs/outputs are native fp16
 * (no conversion, half the bandwidth). When DTYPE=float, inputs are fp32
 * with on-the-fly packHalf2x16 conversion at the shared-memory load.
 *
 * When HAS_BIAS, bias is staged once into shared memory and broadcast into
 * each accumulator tile (stride-0 coopMatLoad) before coopMatStore, so
 * t_output is write-only.
 *
 * Tile hierarchy (configurable via yaml; defaults shown for Adreno):
 *   MMA_*          per-MMA-instruction shape (16x16x16 fp16)
 *   WG_TILE_*      output tile produced per workgroup (64x64; K-step 32)
 *   SG_GRID_*      subgroup grid inside the workgroup (2x2 = 4 subgroups)
 *   SG_TILE_*      per-subgroup output tile (= WG_TILE / SG_GRID; 32x32)
 *   SUBGROUP_SIZE  hardware subgroup width (64 on Adreno)
 *   WG_SIZE        threads per workgroup (= NUM_SUBGROUPS * SUBGROUP_SIZE)
 */

#version 450 core

#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_control_flow_attributes : enable

#define PRECISION ${PRECISION}

$if DTYPE == "half":
  #define IS_FP16_INPUT
$if DTYPE == "float":
  #define IS_FP32_INPUT

$if HAS_BIAS:
  #define HAS_BIAS

$if WEIGHT_LAYOUT == "prepacked":
  #define WEIGHT_PREPACKED

layout(std430) buffer;

#include "common.glslh"

// Bindings: output(0), mat1(1), weight(2), [bias(3)]
${layout_declare_tensor(B, "w", "t_output", DTYPE, "buffer", is_scalar_array=True)}
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer", is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer", is_scalar_array=False)}
$if HAS_BIAS:
  ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=True)}

// UBOs — N comes from out_sizes (linear) or mat2_sizes (matmul).
${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
$if WEIGHT_LAYOUT == "prepacked":
  ${layout_declare_ubo(B, "ivec4", "out_sizes")}
$else:
  ${layout_declare_ubo(B, "ivec4", "mat2_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Cooperative-matrix instruction shape (must match a property enumerated by
// vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR for this device).
const uint MMA_M = ${MMA_M};
const uint MMA_N = ${MMA_N};
const uint MMA_K = ${MMA_K};

// Output tile produced per workgroup.
const uint WG_TILE_M = ${WG_TILE_M};
const uint WG_TILE_N = ${WG_TILE_N};
const uint WG_TILE_K = ${WG_TILE_K};

// Subgroup grid inside the workgroup.
const uint SG_GRID_X = ${SG_GRID_X};
const uint SG_GRID_Y = ${SG_GRID_Y};
const uint SUBGROUP_SIZE = ${SUBGROUP_SIZE};
const uint NUM_SUBGROUPS = SG_GRID_X * SG_GRID_Y;
const uint WG_SIZE = NUM_SUBGROUPS * SUBGROUP_SIZE;

// Derived: per-subgroup tile and MMAs per subgroup tile.
const uint SG_TILE_M = WG_TILE_M / SG_GRID_Y;
const uint SG_TILE_N = WG_TILE_N / SG_GRID_X;
const uint MMAS_PER_SG_M = SG_TILE_M / MMA_M;
const uint MMAS_PER_SG_N = SG_TILE_N / MMA_N;

// fp16: 8 elements per uvec4 (128-bit)
const uint FP16_PER_VEC4 = 8;

// Shared-memory row strides in uvec4 units. Each row gets one extra uvec4
// of skew padding (the "+ FP16_PER_VEC4" before the divide), presumably to
// avoid shared-memory bank conflicts — TODO confirm against the target GPU.
const uint A_STRIDE_VEC4 = (WG_TILE_K + FP16_PER_VEC4) / FP16_PER_VEC4;
const uint B_STRIDE_VEC4 = (WG_TILE_N + FP16_PER_VEC4) / FP16_PER_VEC4;

// Staging tiles: A is WG_TILE_M x WG_TILE_K halfs, B is WG_TILE_K x
// WG_TILE_N halfs; both stored as packed fp16 pairs, 8 halfs per uvec4.
shared uvec4 Ash[WG_TILE_M * A_STRIDE_VEC4];
shared uvec4 Bsh[WG_TILE_K * B_STRIDE_VEC4];

#ifdef HAS_BIAS
// fp32 staging buffer so coopMatLoad can broadcast directly into the
// fp32 accumulator coopmat without a type conversion at the load.
shared float bias_sh[WG_TILE_N];
#endif

// Accumulator tiles (fp32) — one per (i, j) MMA position of this subgroup's
// SG_TILE_M x SG_TILE_N output tile. Declared at global scope.
coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> result[MMAS_PER_SG_M][MMAS_PER_SG_N];

#ifdef IS_FP32_INPUT
// Pack four fp32 values into four fp16 values (two halfs per uint), used at
// the shared-memory staging step for the DTYPE=float variants.
uvec2 f32x4_to_f16x4(vec4 v) {
  return uvec2(packHalf2x16(v.xy), packHalf2x16(v.zw));
}
#endif
|
|
||
void main() {
  // Which WG_TILE_M x WG_TILE_N output tile this workgroup produces.
  const uvec2 tileID = uvec2(gl_WorkGroupID.xy);
  // This subgroup's (x, y) position within the SG_GRID_X x SG_GRID_Y grid.
  const uvec2 warpInTile = uvec2(
      gl_SubgroupID % SG_GRID_X,
      gl_SubgroupID / SG_GRID_X);

  // Problem sizes from UBOs; mat1 is [M, K], output is [M, N].
  const uint K = uint(mat1_sizes.x);
  const uint M = uint(mat1_sizes.y);
#ifdef WEIGHT_PREPACKED
  const uint N = uint(out_sizes.x);
#else
  const uint N = uint(mat2_sizes.x);
#endif
  // vec4-texel counts along K and N (inputs are read 4 elements at a time).
  const uint K4 = (K + 3u) / 4u;
  const uint N4 = (N + 3u) / 4u;

  // Zero-initialize all accumulator tiles.
  [[unroll]] for (uint i = 0; i < MMAS_PER_SG_M; ++i) {
    [[unroll]] for (uint j = 0; j < MMAS_PER_SG_N; ++j) {
      result[i][j] = coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator>(0.0);
    }
  }

  // Thread assignment for A tile (WG_TILE_M rows x INVS_PER_ROW_A uvec4/row)
  const uint INVS_PER_ROW_A = WG_TILE_K / FP16_PER_VEC4;
  const uint a_col = gl_LocalInvocationID.x % INVS_PER_ROW_A;
  const uint a_row_offset = gl_LocalInvocationID.x / INVS_PER_ROW_A;

  // Thread assignment for B tile (WG_TILE_K rows x INVS_PER_ROW_B uvec4/row)
  const uint INVS_PER_ROW_B = WG_TILE_N / FP16_PER_VEC4;
  const uint b_col = gl_LocalInvocationID.x % INVS_PER_ROW_B;
  const uint b_row_offset = gl_LocalInvocationID.x / INVS_PER_ROW_B;

  const uint a_row_base = WG_TILE_M * tileID.y;
  const uint b_col_base = WG_TILE_N * tileID.x;

  // NOTE(review): the global loads below have no per-element bounds checks;
  // they assume M, N, K are gated/padded to tile multiples by the C++
  // dispatch logic (stated explicitly only for N in the bias path below) —
  // confirm the same invariant holds for M and K at the dispatch site.
  for (uint chunkK = 0; chunkK < K; chunkK += WG_TILE_K) {

    // --- Load A tile -> shared (single pass) ---
    // Each thread stages 8 halfs (two input vec4s) of one A row.
    {
      uint row = a_row_base + a_row_offset;
      uint k_elem = chunkK + a_col * FP16_PER_VEC4;

#ifdef IS_FP16_INPUT
      // Native fp16 input: repack two f16vec4 loads into one uvec4.
      uint k_hv4 = k_elem / 4;
      f16vec4 v0 = t_mat1[row * K4 + k_hv4];
      f16vec4 v1 = t_mat1[row * K4 + k_hv4 + 1];
      Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
          packFloat2x16(v0.xy), packFloat2x16(v0.zw),
          packFloat2x16(v1.xy), packFloat2x16(v1.zw));
#else
      // fp32 input: convert to packed fp16 on the fly while staging.
      uint k_vec4 = k_elem / 4;
      vec4 v0 = t_mat1[row * K4 + k_vec4];
      vec4 v1 = t_mat1[row * K4 + k_vec4 + 1];
      uvec2 h0 = f32x4_to_f16x4(v0);
      uvec2 h1 = f32x4_to_f16x4(v1);
      Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(h0, h1);
#endif
    }

    // --- Load B tile -> shared (single pass) ---
    // Each thread stages 8 halfs (two vec4s of N-elements) of one K-row.
    {
      uint k_row = chunkK + b_row_offset;
      uint n_elem = b_col_base + b_col * FP16_PER_VEC4;
      uint n4_0 = n_elem >> 2u;
#ifdef WEIGHT_PREPACKED
      // Prepacked: t_weight[(k4 * N4 + n4) * 4 + dk] yields vec4 of
      // 4 N-elements at K-row (k4*4+dk).
      uint k4 = k_row >> 2u;
      uint dk = k_row & 3u;
      uint b_idx0 = (k4 * N4 + n4_0) * 4u + dk;
      uint b_idx1 = (k4 * N4 + n4_0 + 1u) * 4u + dk;
#else
      // Row-major: t_weight[k_row * N4 + n4] yields vec4 of 4 N-elements.
      uint b_idx0 = k_row * N4 + n4_0;
      uint b_idx1 = k_row * N4 + n4_0 + 1u;
#endif

#ifdef IS_FP16_INPUT
      f16vec4 v0 = t_weight[b_idx0];
      f16vec4 v1 = t_weight[b_idx1];
      Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(
          packFloat2x16(v0.xy), packFloat2x16(v0.zw),
          packFloat2x16(v1.xy), packFloat2x16(v1.zw));
#else
      vec4 v0 = t_weight[b_idx0];
      vec4 v1 = t_weight[b_idx1];
      uvec2 h0 = f32x4_to_f16x4(v0);
      uvec2 h1 = f32x4_to_f16x4(v1);
      Bsh[b_row_offset * B_STRIDE_VEC4 + b_col] = uvec4(h0, h1);
#endif
    }

    // All threads must finish staging Ash/Bsh before any subgroup's
    // coopMatLoad reads them.
    barrier();

    // --- Cooperative matrix MMA ---
    // Sweep WG_TILE_K in MMA_K steps; offsets into Ash/Bsh are expressed in
    // uvec4 elements (hence the divisions by FP16_PER_VEC4).
    [[unroll]] for (uint k = 0; k < WG_TILE_K / MMA_K; ++k) {
      uint k_start = MMA_K * k;

      coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_K, gl_MatrixUseA> matA[MMAS_PER_SG_M];
      [[unroll]] for (uint i = 0; i < MMAS_PER_SG_M; ++i) {
        uint row_a = MMA_M * (MMAS_PER_SG_M * warpInTile.y + i);
        coopMatLoad(
            matA[i], Ash,
            row_a * A_STRIDE_VEC4 + k_start / FP16_PER_VEC4,
            A_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);
      }

      // B tiles are loaded one at a time and reused across all M positions,
      // so each B load feeds MMAS_PER_SG_M MMAs.
      coopmat<float16_t, gl_ScopeSubgroup, MMA_K, MMA_N, gl_MatrixUseB> matB;
      [[unroll]] for (uint j = 0; j < MMAS_PER_SG_N; ++j) {
        uint col_b = MMA_N * (MMAS_PER_SG_N * warpInTile.x + j) / FP16_PER_VEC4;
        coopMatLoad(
            matB, Bsh,
            k_start * B_STRIDE_VEC4 + col_b,
            B_STRIDE_VEC4,
            gl_CooperativeMatrixLayoutRowMajor);

        [[unroll]] for (uint i = 0; i < MMAS_PER_SG_M; ++i) {
          result[i][j] = coopMatMulAdd(matA[i], matB, result[i][j]);
        }
      }
    }

    // Ensure all subgroups are done reading Ash/Bsh before the next chunk
    // overwrites them.
    barrier();
  }

#ifdef HAS_BIAS
  // Stage one WG_TILE_N-wide row of bias into shared memory. The C++ dispatch
  // gate ensures N % WG_TILE_N == 0, so no per-element bounds check is needed.
  {
    const uint tile_n_start = WG_TILE_N * tileID.x;
    for (uint t = gl_LocalInvocationID.x; t < WG_TILE_N; t += WG_SIZE) {
      bias_sh[t] = float(t_bias[tile_n_start + t]);
    }
  }
  memoryBarrierShared();
  barrier();
#endif

  // --- Store result (with bias folded in pre-store, if present) ---
  [[unroll]] for (uint i = 0; i < MMAS_PER_SG_M; ++i) {
    [[unroll]] for (uint j = 0; j < MMAS_PER_SG_N; ++j) {
      // Global (row, col) origin of this MMA tile in the output matrix.
      uint gi = WG_TILE_M * tileID.y + MMA_M * (MMAS_PER_SG_M * warpInTile.y + i);
      uint gj = WG_TILE_N * tileID.x + MMA_N * (MMAS_PER_SG_N * warpInTile.x + j);

#ifdef HAS_BIAS
      // Stride-0 row-major load broadcasts MMA_N bias values across all
      // MMA_M rows of the accumulator tile.
      uint local_n = MMA_N * (MMAS_PER_SG_N * warpInTile.x + j);
      coopmat<float, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> bias_tile;
      coopMatLoad(
          bias_tile, bias_sh,
          local_n, /*stride=*/0u,
          gl_CooperativeMatrixLayoutRowMajor);
      result[i][j] += bias_tile;
#endif

#ifdef IS_FP16_INPUT
      // Narrow the fp32 accumulator to fp16 before the store so t_output
      // stays in the input dtype.
      coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator> out_tile =
          coopmat<float16_t, gl_ScopeSubgroup, MMA_M, MMA_N, gl_MatrixUseAccumulator>(result[i][j]);
      coopMatStore(
          out_tile, t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#else
      coopMatStore(
          result[i][j], t_output,
          gi * N + gj, N,
          gl_CooperativeMatrixLayoutRowMajor);
#endif
    }
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Unified KHR Cooperative Matrix MM kernel (linear + matmul).
# Three shader variants over two weight layouts:
#   matmul_coopmat       row_major weight, no bias (aten.mm runtime mat2)
#   linear_coopmat       prepacked weight, no bias (aten.linear)
#   linear_coopmat_bias  prepacked weight, +bias   (aten.linear with bias)

coopmat_mm:
  parameter_names_with_default_values:
    DTYPE: float
    PRECISION: highp
    WEIGHT_LAYOUT: row_major
    HAS_BIAS: false
    # Per-MMA-instruction shape; must match a cooperative-matrix property
    # enumerated by the device (16x16x16 fp16 is the common baseline).
    MMA_M: 16
    MMA_N: 16
    MMA_K: 16
    # Output tile produced per workgroup; WG_TILE_K is the K-step staged
    # into shared memory per iteration.
    WG_TILE_M: 64
    WG_TILE_N: 64
    WG_TILE_K: 32
    # Subgroup grid inside the workgroup (2x2 = 4 subgroups); defaults are
    # tuned for Adreno (SUBGROUP_SIZE 64) — presumably retuned per vendor.
    SG_GRID_X: 2
    SG_GRID_Y: 2
    SUBGROUP_SIZE: 64
  # Each shader variant is generated for both dtypes.
  generate_variant_forall:
    DTYPE:
      - VALUE: float
      - VALUE: half
  shader_variants:
    - NAME: matmul_coopmat
      WEIGHT_LAYOUT: row_major
    - NAME: linear_coopmat
      WEIGHT_LAYOUT: prepacked
    - NAME: linear_coopmat_bias
      WEIGHT_LAYOUT: prepacked
      HAS_BIAS: true
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@xuyanwen2012 fair comment. Could we also add bounds checking?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed. Done in e921d65