
[ET-VK] Add VK_KHR_cooperative_matrix dispatch for linear/matmul#19009

Open
xuyanwen2012 wants to merge 4 commits into pytorch:main from sarc-acl:yanwen/pr-amend-staging

Conversation

@xuyanwen2012 commented:

Summary

Adds cooperative-matrix (WMMA) drop-in variants of the existing tiled linear_vec / matmul_vec shaders, dispatched automatically when two conditions hold:

  1. The device exposes VK_KHR_cooperative_matrix (checked via a new Adapter::supports_cooperative_matrix() helper)
  2. The output tensor is in buffer storage

When either condition fails, dispatch falls back to the existing tiled shader — no change in behavior for any existing user.
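A minimal sketch of the gate (names as introduced in this PR; argument lists elided; the merged version also adds M/N/K tile-alignment checks per the review below):

bool use_coopmat =
    graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
    graph.storage_type_of(out) == utils::kBuffer;
if (use_coopmat) {
  add_linear_coopmat_node(/* graph, input, packed weight/bias, out */);
} else {
  add_linear_tiled_node(/* same arguments */); // existing path, behavior unchanged
}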

Why

Modern discrete and mobile GPUs (AMD RDNA3+, NVIDIA Turing+) expose hardware matrix-multiply-accumulate tiles through the VK_KHR_cooperative_matrix extension, typically delivering 3–4x throughput on compute-bound GEMM vs software tiling. ExecuTorch's Vulkan backend currently uses linear_vec / matmul_vec (scalar/vector compute tiles) uniformly regardless of device capability, leaving WMMA throughput on the table on capable hardware.

What changes

| Area | Change |
| --- | --- |
| Adapter.h | +9 LOC. Adds Adapter::supports_cooperative_matrix(), querying the cooperative_matrix_features physical-device field already populated in Device.cpp |
| New shaders | linear_coopmat.glsl (+261) and matmul_coopmat.glsl (+227): fp16×fp16→fp32 cooperative-matrix MMA on 16×16×16 tiles; 64×64 output tile per 256-thread workgroup targeting subgroupSize=64 |
| Linear.cpp / Linear.h | Adds add_linear_coopmat_node + pickers; prepack_fp_linear_weight gains a force_buffer parameter so the coopmat path can obtain buffer-stored weights |
| Matmul.cpp | Dispatch branch for both runtime-mat2 and constant-mat2 cases; routes constant-mat2 through the linear path, runtime-mat2 through add_matmul_coopmat_node |
| cm_utils.{h,cpp} | queryCooperativeMatrixProperties() helper that prints the device's supported coopmat configs at startup (diagnostic only; see the sketch below) |
| linear_coopmat_bench.cpp / matmul_coopmat_bench.cpp | GPU-timestamp microbenchmarks comparing coopmat vs tiled across BERT/LLM/square shapes |
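For reference, a hedged sketch of what such a property query involves (standard VK_KHR_cooperative_matrix API; the function and variable names below are illustrative, not the PR's exact code):

#include <vulkan/vulkan.h>
#include <cstdio>
#include <vector>

// Enumerate and print every (M, N, K, type, scope) combination the driver
// supports for cooperative-matrix MMA.
void print_coopmat_configs(VkInstance instance, VkPhysicalDevice pd) {
  auto fn = reinterpret_cast<PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR>(
      vkGetInstanceProcAddr(
          instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"));
  if (fn == nullptr) {
    return; // extension not present on this device
  }
  uint32_t count = 0;
  fn(pd, &count, nullptr);
  std::vector<VkCooperativeMatrixPropertiesKHR> props(
      count, {VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR});
  fn(pd, &count, props.data());
  for (const auto& p : props) {
    // e.g. 16x16x16 with A/B fp16, C/Result fp32, subgroup scope
    std::printf("MxNxK = %ux%ux%u, scope = %d\n",
        p.MSize, p.NSize, p.KSize, static_cast<int>(p.scope));
  }
}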

How to test

1. Configure and build the core runtime

cmake . \
    -Bcmake-out-vk \
    --preset "linux" \
    -DCMAKE_INSTALL_PREFIX=cmake-out-vk \
    -DCMAKE_BUILD_TYPE=Release \
    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
    -DEXECUTORCH_PAL_DEFAULT=posix \
    -DEXECUTORCH_BUILD_VULKAN=ON \
    -DEXECUTORCH_BUILD_TESTS=ON \
    -DCMAKE_C_COMPILER_LAUNCHER=ccache \
    -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
    -DCMAKE_CXX_FLAGS="-include algorithm"

cmake --build cmake-out-vk -j$(nproc) --target install --config Release

2. Configure and build the Vulkan custom ops (GEMM tests and benchmarks)

cmake backends/vulkan/test/custom_ops/ \
    -Bcmake-out-vk/backends/vulkan/test/custom_ops \
    -DCMAKE_INSTALL_PREFIX=cmake-out-vk \
    -DCMAKE_BUILD_TYPE=Release \
    -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
    -DEXECUTORCH_ROOT=$(pwd) \
    -DCMAKE_C_COMPILER_LAUNCHER=ccache \
    -DCMAKE_CXX_COMPILER_LAUNCHER=ccache

cmake --build cmake-out-vk/backends/vulkan/test/custom_ops -j$(nproc)

3. Run the benchmarks on a device supporting VK_KHR_cooperative_matrix

./cmake-out-vk/backends/vulkan/test/custom_ops/linear_coopmat_bench
./cmake-out-vk/backends/vulkan/test/custom_ops/matmul_coopmat_bench

Commits:

Convenience helper that queries VK_KHR_cooperative_matrix feature support on the physical device. Used by the drop-in coopmat shader variants to gate dispatch onto the tiled fallback when unsupported.

Adds VK_KHR_cooperative_matrix GLSL variants of the tiled linear and matmul shaders. Dispatch is gated by Adapter::supports_cooperative_matrix() and buffer output storage, with automatic fallback to the tiled shader when unsupported. An M >= 64 guard avoids a known OOB in the current coopmat store; that guard will be removed once partial-tile bounds checking is added to the shader.

Includes linear_coopmat_bench and matmul_coopmat_bench microbenchmarks that compare against linear_vec / matmul_vec across BERT and LLM-sized shapes using Vulkan query-pool timestamps.
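For reference, a hedged sketch of the query-pool timestamp pattern such benchmarks typically use (standard Vulkan API; handle names are illustrative):

// Create a pool with two timestamp slots: one before, one after the dispatch.
VkQueryPoolCreateInfo qp_info{VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO};
qp_info.queryType = VK_QUERY_TYPE_TIMESTAMP;
qp_info.queryCount = 2;
VkQueryPool pool;
vkCreateQueryPool(device, &qp_info, nullptr, &pool);

// In the command buffer, bracket the GEMM dispatch with two timestamps.
vkCmdResetQueryPool(cmd, pool, 0, 2);
vkCmdWriteTimestamp(cmd, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, pool, 0);
vkCmdDispatch(cmd, groups_x, groups_y, 1);
vkCmdWriteTimestamp(cmd, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, pool, 1);

// After the submission completes, convert the tick delta to nanoseconds using
// the device's timestampPeriod (from VkPhysicalDeviceProperties::limits).
uint64_t ticks[2];
vkGetQueryPoolResults(
    device, pool, 0, 2, sizeof(ticks), ticks, sizeof(uint64_t),
    VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT);
double gpu_ns = double(ticks[1] - ticks[0]) * props.limits.timestampPeriod;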
Copilot AI review requested due to automatic review settings April 20, 2026 21:59

pytorch-bot Bot commented Apr 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19009

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 11 Awaiting Approval, 1 Unrelated Failure

As of commit 9605ece with merge base 8ed6e85:

AWAITING APPROVAL - The following workflows need approval before CI can run:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


meta-cla Bot commented Apr 20, 2026

Hi @xuyanwen2012!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

Copilot AI left a comment

Pull request overview

Adds a Vulkan cooperative-matrix (VK_KHR_cooperative_matrix / WMMA-style) fast path for linear/matmul when the device supports it and the output is buffer-backed, plus diagnostic tooling and microbenchmarks to compare against the existing tiled (*_vec) shaders.

Changes:

  • Introduces cooperative-matrix GLSL shaders and shader variants for linear and matmul.
  • Adds runtime dispatch branching to select coopmat vs tiled implementations, plus a supports_cooperative_matrix() adapter helper.
  • Adds coopmat diagnostics (cm_utils) and two microbenchmarks for linear/matmul.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| backends/vulkan/runtime/vk_api/Adapter.h | Adds supports_cooperative_matrix() feature check. |
| backends/vulkan/runtime/graph/ops/impl/Matmul.cpp | Adds coopmat node and dispatch selection for matmul (including constant-mat2 route via linear). |
| backends/vulkan/runtime/graph/ops/impl/Linear.h | Extends prepack API and declares add_linear_coopmat_node. |
| backends/vulkan/runtime/graph/ops/impl/Linear.cpp | Adds coopmat linear node + selection logic and a force_buffer prepack option. |
| backends/vulkan/runtime/graph/ops/glsl/matmul_coopmat.yaml | Registers matmul coopmat shader variants (dtype). |
| backends/vulkan/runtime/graph/ops/glsl/matmul_coopmat.glsl | New cooperative-matrix matmul shader (buffer-only). |
| backends/vulkan/runtime/graph/ops/glsl/linear_coopmat.yaml | Registers linear coopmat shader variants (dtype, bias). |
| backends/vulkan/runtime/graph/ops/glsl/linear_coopmat.glsl | New cooperative-matrix linear shader for prepacked weights (buffer-only). |
| backends/vulkan/test/custom_ops/cm_utils.h | Declares cooperative-matrix property query helper for benchmarks/diagnostics. |
| backends/vulkan/test/custom_ops/cm_utils.cpp | Implements cooperative-matrix property enumeration/printing. |
| backends/vulkan/test/custom_ops/linear_coopmat_bench.cpp | Adds linear coopmat vs vec microbenchmark. |
| backends/vulkan/test/custom_ops/matmul_coopmat_bench.cpp | Adds matmul coopmat vs vec microbenchmark. |
| backends/vulkan/test/custom_ops/CMakeLists.txt | Wires new cm_utils + benchmark targets into the custom_ops build. |


meta-cla Bot added the CLA Signed label Apr 21, 2026
Comment on lines +241 to +259

#ifdef HAS_BIAS
// Add bias via read-modify-write on the output buffer.
// barrier() ensures all coopMatStore writes within this workgroup are visible.
barrier();

const uint tile_m_start = TILE_M * tileID.y;
const uint tile_n_start = TILE_N * tileID.x;
// 64x64 tile = 4096 elements, 256 threads -> 16 elements per thread
for (uint idx = gl_LocalInvocationID.x; idx < TILE_M * TILE_N; idx += INVOCATIONS) {
  uint local_m = idx / TILE_N;
  uint local_n = idx % TILE_N;
  uint gm = tile_m_start + local_m;
  uint gn = tile_n_start + local_n;
  if (gm < M && gn < N) {
    uint out_idx = gm * N + gn;
    t_output[out_idx] = t_output[out_idx] + t_bias[gn];
  }
}
@SS-JIA (Contributor) commented Apr 30, 2026:

Suggestion: fold bias into the accumulator before coopMatStore instead of doing a post-store Read-Modify-Write.

The current RMW pattern couples three issues that all go away if the bias is added to the accumulator before the store:

  1. t_output no longer needs rw — it can be declared w only. Today it's rw purely to make the load on line 257 valid.
  2. The cross-subgroup buffer barrier becomes unnecessary — and is currently missing. With GL_KHR_memory_scope_semantics : require (line 33), barrier() only orders workgroup-shared memory, not SSBOs. The RMW reads t_output[out_idx] written by other subgroups via coopMatStore, so as written this is a cross-subgroup buffer RAW with no buffer-memory ordering. The fix would be memoryBarrierBuffer(); barrier(); or, since the extension is enabled, a controlBarrier(...) with gl_StorageSemanticsBuffer | gl_SemanticsAcquireRelease. Folding bias in earlier sidesteps this entirely.
  3. The 16-iter-per-thread RMW loop disappears — replaced by one coopMatLoad per accumulator tile (a stride-0 broadcast load).
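(For reference, a sketch of the buffer-scope barrier the RMW path would need if it were kept; these are the GL_KHR_memory_scope_semantics built-ins named in point 2:

controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup,
               gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);

Folding the bias into the accumulator makes this unnecessary.)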

Sketch

Stage one row of bias into shared memory once, then broadcast it across all lM rows of every accumulator tile via a stride-0 row-major load:

#ifdef HAS_BIAS
shared float16_t bias_sh[TILE_N];

// One-time bias staging (before the store loop; could be before the K-loop too).
for (uint t = gl_LocalInvocationID.x; t < TILE_N; t += INVOCATIONS) {
    uint gn = TILE_N * tileID.x + t;
    bias_sh[t] = (gn < N) ? float16_t(t_bias[gn]) : float16_t(0);
}
memoryBarrierShared();
barrier();
#endif

[[unroll]] for (uint i = 0; i < C_ROWS; ++i) {
    [[unroll]] for (uint j = 0; j < C_COLS; ++j) {
        uint gi = TILE_M * tileID.y + lM * (C_ROWS * warpInTile.y + i);
        uint gj = TILE_N * tileID.x + lN * (C_COLS * warpInTile.x + j);

#ifdef HAS_BIAS
        // Stride-0 row-major load broadcasts lN bias values across all lM
        // rows of an accumulator-shaped coopmat.
        uint local_n = lN * (C_COLS * warpInTile.x + j);
        coopmat<float, gl_ScopeSubgroup, lM, lN, gl_MatrixUseAccumulator> bias_tile;
        coopMatLoad(bias_tile, bias_sh, local_n, /*stride=*/0,
                    gl_CooperativeMatrixLayoutRowMajor);
        result[i][j] += bias_tile;
#endif

        // ... existing coopMatStore(result[i][j], t_output, gi * N + gj, N, ...)
    }
}

Cost: TILE_N × dtype of extra shared memory (128 B fp16 / 256 B fp32) — rounding error against the ~9.5 KB already used by Ash/Bsh. Removes a workgroup-wide buffer barrier, makes t_output write-only, and eliminates a 4096-element RMW pass.

(Comment authored with Claude.)

Comment on lines +68 to +85

// Tile dimensions (same as matmul_coopmat)
const uint lM = 16;
const uint lN = 16;
const uint lK = 16;
const uint TILE_M = 64;
const uint TILE_N = 64;
const uint TILE_K = 32;

// Workgroup: 4 subgroups in 2x2 grid, 64 threads each = 256 total
const uint WG_WIDTH = 2;
const uint WG_HEIGHT = 2;
const uint NUM_SUBGROUPS = 4;
const uint INVOCATIONS = 64 * NUM_SUBGROUPS;

// Result tiles per subgroup: 2x2
const uint C_ROWS = TILE_M / WG_HEIGHT / lM; // 2
const uint C_COLS = TILE_N / WG_WIDTH / lN; // 2

Suggestion: rename these constants and lift them into linear_coopmat.yaml as variant parameters.

Naming

The shader actually has four tiling levels (global problem → workgroup tile → subgroup tile → MMA instruction), but only the smallest and largest are named here. The middle level (per-subgroup tile = 32×32) is implicit in C_ROWS / C_COLS, the magic 64 inside INVOCATIONS silently encodes "Adreno subgroup size", and WG_WIDTH / WG_HEIGHT are actually the subgroup-grid dims, not the workgroup's. Proposed renames borrow CUTLASS / Triton conventions so the hierarchy is visible from the names:

| Current | Proposed | Meaning |
| --- | --- | --- |
| lM, lN, lK | MMA_M, MMA_N, MMA_K | One cooperative-matrix instruction shape (per-subgroup, hardware-enumerated) |
| TILE_M, TILE_N, TILE_K | WG_TILE_M, WG_TILE_N, WG_TILE_K | Output tile produced per workgroup |
| (implicit) | SG_TILE_M, SG_TILE_N | Per-subgroup output tile (= WG_TILE / SG_GRID) |
| WG_WIDTH, WG_HEIGHT | SG_GRID_X, SG_GRID_Y | Subgroup grid inside the workgroup |
| (magic 64) | SUBGROUP_SIZE | Assumed Adreno subgroup width |
| INVOCATIONS | WG_SIZE | Total threads per workgroup |
| C_ROWS, C_COLS | MMAS_PER_SG_M, MMAS_PER_SG_N | MMA instructions per subgroup tile |

With derived constants made explicit, the block becomes:

// Cooperative-matrix instruction shape (must match a property enumerated
// by vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR).
const uint MMA_M = ${MMA_M};
const uint MMA_N = ${MMA_N};
const uint MMA_K = ${MMA_K};

// Output tile produced per workgroup.
const uint WG_TILE_M = ${WG_TILE_M};
const uint WG_TILE_N = ${WG_TILE_N};
const uint WG_TILE_K = ${WG_TILE_K};

// Subgroup grid inside the workgroup; each subgroup owns a
// (WG_TILE_M / SG_GRID_Y) x (WG_TILE_N / SG_GRID_X) output region.
const uint SG_GRID_X     = ${SG_GRID_X};
const uint SG_GRID_Y     = ${SG_GRID_Y};
const uint NUM_SUBGROUPS = SG_GRID_X * SG_GRID_Y;
const uint SUBGROUP_SIZE = ${SUBGROUP_SIZE};
const uint WG_SIZE       = NUM_SUBGROUPS * SUBGROUP_SIZE;

// Derived: per-subgroup tile and MMAs per subgroup tile.
const uint SG_TILE_M     = WG_TILE_M / SG_GRID_Y;
const uint SG_TILE_N     = WG_TILE_N / SG_GRID_X;
const uint MMAS_PER_SG_M = SG_TILE_M / MMA_M;
const uint MMAS_PER_SG_N = SG_TILE_N / MMA_N;

Index math like lM * (C_ROWS * warpInTile.y + i) becomes MMA_M * (MMAS_PER_SG_M * warpInTile.y + i) — longer, but the names tell the reader what level each factor lives at.

Lift the constants into linear_coopmat.yaml

The four numbers here — MMA shape, workgroup-tile shape, subgroup grid, and subgroup size — are the autotuning knobs for this kernel. Different devices want different values:

  • Adreno enumerates 16×16×16 fp16, but Mali / NV / Intel may only expose 8×8×32 or 16×8×16.
  • A bigger WG_TILE_* improves arithmetic intensity but uses more shared memory and registers — sweet spot is per-device.
  • SUBGROUP_SIZE is hardware-fixed (64 Adreno, 32 everyone else) and the index math currently assumes 64 — making this a yaml parameter forces every variant to declare what it expects.

Lift them into yaml so the codegen can produce drop-in variants:

linear_coopmat:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
    PRECISION: highp
    MMA_M: 16
    MMA_N: 16
    MMA_K: 16
    WG_TILE_M: 64
    WG_TILE_N: 64
    WG_TILE_K: 32
    SG_GRID_X: 2
    SG_GRID_Y: 2
    SUBGROUP_SIZE: 64
    HAS_BIAS: false
  generate_variant_forall:
    DTYPE:
      - VALUE: float
      - VALUE: half
  shader_variants:
    - NAME: linear_coopmat
    - NAME: linear_coopmat_bias
      HAS_BIAS: true
    # Future-friendly: drop-in variants for other device tiers without
    # touching the GLSL.
    # - NAME: linear_coopmat_8x8x32
    #   MMA_M: 8
    #   MMA_N: 8
    #   MMA_K: 32
    # - NAME: linear_coopmat_sg32_128x128
    #   SUBGROUP_SIZE: 32
    #   SG_GRID_X: 4
    #   SG_GRID_Y: 2
    #   WG_TILE_M: 128
    #   WG_TILE_N: 128

The C++ pick function (pick_linear_coopmat_shader at Linear.cpp:248-258) can then select the variant whose (MMA_M, MMA_N, MMA_K) matches a property returned by vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR and whose SUBGROUP_SIZE matches adapter->subgroup_size(). That eliminates the silent-miscompute path on non-Adreno devices and makes the kernel portable.

Caveats

  • Validity constraints must be enforced by the dispatch (or static_assert-style failures at codegen time): WG_TILE_M % (SG_GRID_Y * MMA_M) == 0, WG_TILE_N % (SG_GRID_X * MMA_N) == 0, WG_TILE_K % MMA_K == 0, and WG_TILE_M * (WG_TILE_K + FP16_PER_VEC4) / FP16_PER_VEC4 + WG_TILE_K * (WG_TILE_N + FP16_PER_VEC4) / FP16_PER_VEC4 ≤ device shared-memory budget.
  • Binary-size cost: every variant is a separate SPIR-V blob. Keep the explicit set small; don't generate_variant_forall over the tile knobs.
  • The 64-element coopMatStore row stride (currently passed as N) doesn't depend on these constants, so renames don't ripple into the store path.

(Comment authored with Claude.)

Comment on lines +148 to +150
Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
    packHalf2x16(vec2(v0.xy)), packHalf2x16(vec2(v0.zw)),
    packHalf2x16(vec2(v1.xy)), packHalf2x16(vec2(v1.zw)));
@SS-JIA (Contributor) commented Apr 30, 2026:

Suggestion: use packFloat2x16 directly to avoid an fp16 → fp32 → fp16 roundtrip on the fp16-input path.

packHalf2x16(vec2) takes fp32 inputs and converts them to fp16 before packing. Here v0 is already f16vec4, so vec2(v0.xy) upcasts each fp16 to fp32 just so packHalf2x16 can downcast it right back. The data flow is:

   fp16 (loaded from t_mat1)
     │
     │  vec2(...)         ← upcast to fp32 (wasteful)
     ▼
   fp32 (intermediate)
     │
     │  packHalf2x16      ← downcast back to fp16 + pack
     ▼
   fp16 packed in uint32

Since GL_EXT_shader_explicit_arithmetic_types_float16 is enabled (line 36), the direct form is packFloat2x16(f16vec2) — packs an fp16 pair into a uint32 with no float-width conversion:

Ash[a_row_offset * A_STRIDE_VEC4 + a_col] = uvec4(
    packFloat2x16(v0.xy), packFloat2x16(v0.zw),
    packFloat2x16(v1.xy), packFloat2x16(v1.zw));

Same pattern appears in the B loader at lines 175-177 — should be updated together.

(Comment authored with Claude.)

const uint a_row_base = TILE_M * tileID.y;
const uint b_col_base = TILE_N * tileID.x;

for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) {

Bounds bugs: the K-loop and the load math both silently overrun when K or N are not multiples of specific sizes.

This K-loop body has two latent OOB conditions, neither guarded in the shader and neither guaranteed by the C++ pick gate (Linear.cpp:354 only checks M >= 64). Both fire on common production shapes.

1. K-loop has no tail handling (this line)

for (uint chunkK = 0; chunkK < K; chunkK += TILE_K) { ... }

When K is not a multiple of TILE_K = 32, the last iteration runs past the end of K. Inside it:

  • The A loader reads t_mat1[row * K4 + k_hv4 + 1] (line 147) → reads from the next row (or past the buffer).
  • The B loader reads t_weight_packed[(k4 * N4 + n4_0 + 1u) * 4u + dk] (line 174) → reads off the K-block.
  • The MMAs then accumulate garbage K-columns into the result.

Concrete: K = 33, chunkK = 32 starts a second iteration with k_elem = 32..63, but K = 33 → positions 33..63 are out-of-bounds.

2. N%8 OOB on the B-loader's second vec4 (line 174)

Independent of the K-loop issue:

v1 = t_weight_packed[(k4 * N4 + n4_0 + 1u) * 4u + dk];

When N is not a multiple of 8, the thread with b_col = INVS_PER_ROW_B - 1 = 7 (last column group) reads n4_0 + 1 == N4, off the end of the K-row. The +1 advances unconditionally.

What the shader actually requires

Full-tile alignment on every dimension: M and N must be multiples of the 64-element tile and K a multiple of TILE_K = 32 (which also subsumes the N % 8 loader constraint). None of these requirements are stated in a comment, and only M >= 64 (which is incorrect — should be M % 64 == 0) is checked at dispatch.

Suggested fix

Tighten the C++ gate at Linear.cpp:353-356:

bool use_coopmat = adapter->supports_cooperative_matrix()
                && M % kLinearCoopMatTileM == 0
                && N % kLinearCoopMatTileN == 0
                && K % kLinearCoopMatTileK == 0;

That eliminates the silent miscompute. Adding partial-tile / K-tail handling to the shader is a follow-up; without it the kernel won't apply to common Llama shapes (e.g. N = 11008, N = 14336 — both non-multiples of 64).

The TODO at Linear.cpp:347 flags this but the current implementation is incomplete.

(Comment authored with Claude.)

@xuyanwen2012 (Author) replied:

Done in 02faae5. Implemented the suggestion.


Suggestion (follow-up): consolidate matmul_coopmat.glsl and linear_coopmat.glsl into a single shader.

These two shaders are ~95% identical — the same extensions, constants, A-loader, MMA inner loop, store, and shared-memory layout. Only 5 things differ:

  1. Output access mode (rw for linear+bias, w otherwise)
  2. Bias binding + post-store RMW pass (linear-bias only)
  3. B binding name (t_weight_packed vs t_mat2)
  4. B-load address math (4OC×4IC prepacked vs row-major)
  5. N source UBO (out_sizes.x vs mat2_sizes.x)

All five are switchable via existing yaml-parameter / Jinja conventions in this codebase (see linear_qcsnw.glsl, q_8w_linear*.glsl for precedent).

Sketch

One shader coopmat_mm.glsl with WEIGHT_LAYOUT and HAS_BIAS as variant parameters:

# coopmat_mm.yaml
coopmat_mm:
  parameter_names_with_default_values:
    DTYPE: float
    PRECISION: highp
    WEIGHT_LAYOUT: row_major
    HAS_BIAS: false
  generate_variant_forall:
    DTYPE: [{VALUE: float}, {VALUE: half}]
  shader_variants:
    - NAME: matmul_coopmat
      WEIGHT_LAYOUT: row_major
    - NAME: linear_coopmat
      WEIGHT_LAYOUT: prepacked
    - NAME: linear_coopmat_bias
      WEIGHT_LAYOUT: prepacked
      HAS_BIAS: true
// In the B-loader inside the K-loop:
$if WEIGHT_LAYOUT == "prepacked":
    uint k4 = k_row >> 2u;
    uint dk = k_row & 3u;
    v0 = t_b[(k4 * N4 + n4_0)     * 4u + dk];
    v1 = t_b[(k4 * N4 + n4_0 + 1) * 4u + dk];
$else:
    v0 = t_b[k_row * N4 + n_hv4];
    v1 = t_b[k_row * N4 + n_hv4 + 1];

Diff cost

delete   matmul_coopmat.{glsl,yaml}              -243 LOC
modify   linear → coopmat_mm.{glsl,yaml}         +~15 LOC
──────────────────────────────────────────────────────────
net                                              ~-225 LOC

Zero C++ changes if the existing shader variant names are kept — Linear.cpp and Matmul.cpp look up shaders by string name, and the names stay the same; only the source file changes.

Why it matters beyond cleanup

Every bug fix currently has to be applied in two places. All four inline comments on this PR (K-loop tail, fp16 → packFloat2x16 roundtrip, naming/yaml-lift, bias-RMW restructure) are duplicated work between the two shaders. Consolidating means one fix, one verification, one place for future MMA-property gating + partial-tile handling.

Recommendation

Land this PR as-is with the inline bug fixes, then file a follow-up for consolidation. Doing the refactor in this PR muddies review; doing it after the fixes means the consolidation migrates correct code rather than buggy code that needs fixing twice during the merge.

(Comment authored with Claude.)


// ── Cooperative matrix linear ──

static constexpr uint32_t kLinearCoopMatTileM = 64;

I recommend moving the C++ dispatch logic for the cooperative matrix shaders to its own file:

GemmCoopmat.[h|cpp] -- we can put all dispatch logic for linear / matmul shaders that use cooperative matrices here

GemmCommon.[h|cpp] -- we can put functions currently defined in Linear.cpp and Matmul.cpp that can be shared with GemmCoopmat.cpp, such as the resize functions, here
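A hedged sketch of what that split could look like (layout and signatures are illustrative, assembled from declarations quoted elsewhere on this page; not a definitive design):

// GemmCoopmat.h: all coopmat dispatch for linear / matmul
namespace vkcompute {

// Gate: extension support + buffer output + full-tile M/N/K alignment.
bool is_coopmat_eligible(
    ComputeGraph& graph, const ValueRef out, int64_t M, int64_t N, int64_t K);

void add_linear_coopmat_node(
    ComputeGraph& graph,
    const ValueRef input,
    const ValueRef packed_weight,
    const ValueRef packed_bias,
    bool has_bias,
    const ValueRef out);

void add_matmul_coopmat_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2,
    const ValueRef out);

} // namespace vkcompute

// GemmCommon.h would hold the shared resize functions and other helpers
// currently duplicated between Linear.cpp and Matmul.cpp.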


// Microbenchmark: linear_coopmat vs linear_vec.

Suggestion (follow-up): drop linear_coopmat_bench.cpp (+ matmul) and extend test_mm.cpp with an impl_selector argument — the same pattern that test_conv2d_dw.cpp already uses.

test_mm.cpp (465 LOC, already in this directory) covers the full matmul/linear test surface — linear, linear-bias, addmm, bmm, mm, with prepacked + non-prepacked mat2, batched + non-batched, fp16 + fp32, all storage types and memory layouts. The two new bench files duplicate ~450 LOC of scaffolding only to drive a different shader path.

Existing precedent: test_conv2d_dw.cpp + impl/TestConv2dPw.cpp

The conv-dw test already plumbs an impl_selector string through the test op:

// test_conv2d_dw.cpp:55  — default-empty string parameter
static TestCase create_conv2d_dw_test_case(
    const Conv2dDwConfig& config, vkapi::ScalarType dtype,
    utils::StorageType storage_type, utils::GPUMemoryLayout memory_layout,
    const std::string& impl_selector = "");

// test_conv2d_dw.cpp:144  — wire through as ValueSpec
test_case.add_input_spec(ValueSpec::make_string(impl_selector));

// test_conv2d_dw.cpp:374-385  — sweep variants for the same shape
test_cases.push_back(create_conv2d_dw_test_case(cfg, dtype, st, layout));            // auto
test_cases.push_back(create_conv2d_dw_test_case(cfg, dtype, st, layout, "b4x2"));    // force
test_cases.push_back(create_conv2d_dw_test_case(cfg, dtype, st, layout, "b1x1"));    // force

Receiving end (scaffold from impl/TestConv2dPw.cpp:24-28):

const ValueRef impl_selector_str = args.at(3);
std::string impl_selector = graph.extract_string(impl_selector_str);
// -- branch on impl_selector to force a specific add_*_node call --

For matmul/linear, the parallel wiring lives in impl/TestMatmulLinear.cpp (already registers test_etvk.test_mm.default). Adding the same selector arg + a branch to add_linear_coopmat_node / add_linear_tiled_node / add_matmul_coopmat_node / add_matmul_tiled_node is the natural extension.

Sketch

// test_mm.cpp — extend MmConfig
struct MmConfig {
  int64_t B, M, K, N;
  bool has_bias, mat2_is_transposed, mat2_is_constant;
  std::string impl_selector;     // ← "" = auto, "coopmat" or "tiled" to force
};

// create_mm_test_case(...)
test_case.add_input_spec(ValueSpec::make_string(config.impl_selector));
// impl/TestMatmulLinear.cpp — branch in the op handler
std::string impl = graph.extract_string(args.at(/* selector idx */));
if (impl == "coopmat") {
    // bypass auto-router; call add_linear_coopmat_node / add_matmul_coopmat_node
} else if (impl == "tiled") {
    // call add_linear_tiled_node / add_matmul_tiled_node
} else {
    // existing aten.mm / aten.linear routing (preserve current behavior)
}

Diff

delete linear_coopmat_bench.cpp                  -196 LOC
delete matmul_coopmat_bench.cpp                  -254 LOC
delete CMakeLists.txt entries (2)                  ...
modify test_mm.cpp                               +~20 LOC
modify impl/TestMatmulLinear.cpp                 +~25 LOC
────────────────────────────────────────────────────────
net                                              ~-410 LOC

Why

  1. Consistent with the existing impl_selector pattern in test_conv2d_dw.cpp + impl/TestConv2dPw.cpp — no new convention.
  2. A/B comparisons in one binary — sweep the same shape across "", "tiled", "coopmat" in a single run for directly-comparable timings.
  3. Consistent shape coverage — every shape currently tested for the tiled path automatically gets coopmat coverage, including the unaligned shapes that would have caught the K-tail / N-tail bugs flagged at #r3171002711.
  4. Tighter tolerances — the new bench files set abs=rel=5e-1; test_mm.cpp already uses dtype-appropriate tolerances. Reusing means free correctness checks at production-grade precision.
  5. Future shader variants get free coverage — the GLSL consolidation at #r3171100736 plus a future "coopmat_8x8x32" (Mali tier) plug straight into the same test cases.

Recommendation

Same as the other refactors: land this PR with the bench files as-is, then file a follow-up to consolidate into test_mm.cpp once the bug fixes have landed.

(Comment authored with Claude.)

int64_t M = input_sizes.size() >= 2
    ? input_sizes.at(input_sizes.size() - 2)
    : 1;
bool use_coopmat =
@SS-JIA (Contributor) commented May 1, 2026:

Suggestion: gate the coopmat path on non-integrated GPUs using VkPhysicalDeviceType.

The coopmat shader's design — 256-thread workgroups, 4 subgroups, ~9.5 KB of shared memory per workgroup, fp32 accumulator coopmats — is tuned for desktop-class GPUs. On integrated mobile GPUs (Adreno, Mali) the cost profile inverts:

  • Large workgroups stress mobile register files. Adreno/Mali per-wave register files are smaller than desktop NV/AMD; 256-thread workgroups with fp32 accumulators and shared-mem staging spill more easily, hurting occupancy.
  • Heavy shared-memory use is less of a win on TBDR architectures. Adreno's shared-mem peak doesn't match desktop-class L1 throughput; the tiled linear_vec shader's lighter shared-mem footprint generally beats coopmat on Adreno 7xx.
  • Cooperative-matrix MMA throughput on mobile is modest. The headline FLOPs win that justifies coopmat on NV doesn't materialize at the same scale on Adreno.

Today the gate (supports_cooperative_matrix() + buffer storage) routes coopmat only to Adreno, since Adreno is currently the only mobile GPU in the codebase that exposes the extension. That's the device where the design is least well-suited.

Reliable detection: VkPhysicalDeviceType

VkPhysicalDeviceProperties::deviceType is driver-reported, not name-inferred:

VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU = 1,   // Adreno, Mali, Intel iGPU, AMD APU
VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU   = 2,   // NVIDIA desktop, AMD desktop, Intel Arc
VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU    = 3,
VK_PHYSICAL_DEVICE_TYPE_CPU            = 4,   // SwiftShader, lavapipe

More reliable than the existing vkapi::DeviceType enum ({NVIDIA, MALI, ADRENO, SWIFTSHADER} from Device.h:21-27), which infers vendor from deviceName string parsing:

| Concern | Name-based vkapi::DeviceType | VkPhysicalDeviceType |
| --- | --- | --- |
| Future vendors | Add enum entry + update parser | Automatic — driver reports |
| Same-vendor variants (e.g. Tegra integrated vs RTX discrete) | Cannot distinguish | Distinguishes correctly |
| Authoritativeness | Inferred from string | Driver-reported |
| Software emulators | Needs explicit SWIFTSHADER handling | Reports CPU automatically |
The data is already populated in the codebase — Adapter.cpp:398 reads properties.deviceType for logging via Adapter::stringize(). We just don't expose a programmatic accessor.

Suggested change

1. Add accessors to Adapter.h (near line 138, alongside the existing device_type()):

inline VkPhysicalDeviceType physical_device_type() const {
  return physical_device_.properties.deviceType;
}

inline bool is_integrated_gpu() const {
  return physical_device_.properties.deviceType ==
         VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU;
}

inline bool is_discrete_gpu() const {
  return physical_device_.properties.deviceType ==
         VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU;
}

Keep the existing device_type() for vendor-specific workarounds ("apply this Adreno bug fix"); use the new helpers for capability-tier decisions like this one.

2. Add passthrough wrappers to ComputeGraph.h (near lines 681-687, alongside the existing device_is_adreno() / device_is_mali()):

inline bool is_integrated_gpu() {
  return context_->adapter_ptr()->is_integrated_gpu();
}

inline bool is_discrete_gpu() {
  return context_->adapter_ptr()->is_discrete_gpu();
}

This matches the existing convention of surfacing adapter capabilities directly on the graph (e.g., device_is_adreno() / device_is_mali() at lines 681-687, plus int16_shader_types_enabled() / float16_buffers_enabled() / int8_buffers_enabled() at lines 1158-1186), so call sites don't need to drill through graph.context()->adapter_ptr().

3. Update the gate at Linear.cpp:353-356 — now reads cleanly via the graph passthrough:

bool use_coopmat =
    graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
    !graph.is_integrated_gpu() &&                  // ← new gate, via ComputeGraph passthrough
    graph.storage_type_of(out) == utils::kBuffer &&
    M % kLinearCoopMatTileM == 0 &&                // alignment guards
    N % kLinearCoopMatTileN == 0 &&                // (per #r3171002711)
    K % kLinearCoopMatTileK == 0;

!graph.is_integrated_gpu() is the most precise expression of intent: any discrete or virtual GPU is fine; integrated (mobile, iGPU, APU) gets the tiled path.

(Optional bonus) supports_cooperative_matrix() itself could get the same passthrough treatment so the whole gate reads as graph.supports_cooperative_matrix() && !graph.is_integrated_gpu() && ... — natural extension of the same convention.

4. Same change at Matmul.cpp:281-283 and :301-303. Folding all three call sites' gating into a single Adapter::can_use_coopmat_mm_16x16x16_fp16_fp32() helper (which also handles the MMA-property check from the Adapter.h comment we discussed) — exposed via ComputeGraph passthrough — avoids triplicating the gate logic.

Why now

Without this gate, the coopmat path will likely regress performance on Adreno (the primary ETVK target) once the alignment-guard bugs are fixed and the kernel actually runs on more shapes. The intended outcome of the PR — coopmat as a perf win — is realized only on devices it's tuned for. On Adreno, the existing linear_vec is already shipping and tuned.

(Comment authored with Claude.)

Three correctness fixes flagged on PR pytorch#19009.

1. The linear_coopmat / matmul_coopmat dispatch gate previously only checked `M >= 64`. We now tighten the gates in `Linear.cpp` and `Matmul.cpp` to require `M % TILE_M == 0 && N % TILE_N == 0 && K % TILE_K == 0`; misaligned shapes correctly fall back to the tiled shader.

2. The bias path in `linear_coopmat.glsl` previously read the just-written output buffer back, added bias, and wrote it again. We now fold bias into the fp32 accumulator before `coopMatStore`. The binding now becomes `w` instead of `rw`.

3. We now use `packFloat2x16` directly to avoid fp16 -> fp32 -> fp16 round trip.
@xuyanwen2012 (Author) commented:

@SS-JIA
Hi Stephen, thank you for your feedback. I have applied the immediate fixes for the bias RMW, packFloat2x16, and K-loop + N-loader OOB issues in 02faae5255.

I will address the rest of the consolidation fixes as suggested in a followup PR.

xuyanwen2012 requested a review from SS-JIA May 1, 2026 21:38
@xuyanwen2012 (Author) commented:

@pytorchbot label "release notes: vulkan"

pytorch-bot Bot added the release notes: vulkan label May 1, 2026
Three follow-ups to PR pytorch#19009 review.

1. Consolidated linear_coopmat.glsl + matmul_coopmat.glsl into a single
coopmat_mm.glsl (previously 95% identical). Lifted tile constants into yaml.

2. Split coopmat dispatch out of Linear.cpp / Matmul.cpp into
GemmCoopmat.{h,cpp} and shared helpers into GemmCommon.{h,cpp}
(Linear.cpp / Matmul.cpp shrink ~311 LOC combined).

3. Dropped linear_coopmat_bench.cpp and matmul_coopmat_bench.cpp (~450 LOC
of duplicated scaffolding). Coopmat A/B coverage moved into test_mm.cpp via
impl_selector.
Copilot AI review requested due to automatic review settings May 1, 2026 23:01
@xuyanwen2012 (Author) commented:

@SS-JIA

Following up on the review: I have now consolidated the code. Here are the changes.

  1. Consolidated linear_coopmat.glsl + matmul_coopmat.glsl into a single
    coopmat_mm.glsl (previously 95% identical). Lifted tile constants into yaml.

  2. Split coopmat dispatch out of Linear.cpp / Matmul.cpp into
    GemmCoopmat.{h,cpp} and shared helpers into GemmCommon.{h,cpp}
    (Linear.cpp / Matmul.cpp shrink ~311 LOC combined).

  3. Dropped linear_coopmat_bench.cpp and matmul_coopmat_bench.cpp (~450 LOC
    of duplicated scaffolding). Coopmat A/B coverage moved into test_mm.cpp via
    impl_selector.

Copilot AI left a comment

Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated 5 comments.



Comment on lines +64 to +77
void add_linear_coopmat_node(
    ComputeGraph& graph,
    const ValueRef input,
    const ValueRef packed_weight,
    const ValueRef packed_bias,
    bool has_bias,
    const ValueRef out,
    int32_t weight_B) {
  (void)weight_B;
  VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
  VK_CHECK_COND(
      graph.storage_type_of(out) == utils::kBuffer,
      "linear_coopmat requires buffer storage");
Comment on lines +255 to +262
inline bool supports_cooperative_matrix() {
#ifdef VK_KHR_cooperative_matrix
  return physical_device_.cooperative_matrix_features.cooperativeMatrix ==
      VK_TRUE;
#else
  return false;
#endif /* VK_KHR_cooperative_matrix */
}
Comment on lines +191 to +205
// The coopmat shader uses fp16 intermediates regardless of input dtype
// (inputs are packHalf2x16-converted before entering the MMA), so the
// achievable precision is fp16-bounded for any path that dispatches to
// it. A coopmat dispatch occurs when impl_selector forces it, or when
// the default routing's gate (buffer + M/N/K alignment) is met.
bool routes_to_coopmat = false;
if (config.impl_selector == "coopmat") {
  routes_to_coopmat = true;
} else if (
    config.impl_selector == "default" && storage_type == utils::kBuffer &&
    config.M % 64 == 0 && config.N % 64 == 0 && config.K % 32 == 0) {
  routes_to_coopmat = true;
}

if (dtype == vkapi::kHalf || routes_to_coopmat) {
Comment on lines +388 to +390
const std::vector<std::string> impl_selectors_default = {"default"};
const std::vector<std::string> impl_selectors_aligned = {
    "default", "tiled", "coopmat"};
Comment on lines +31 to +40
inline bool is_coopmat_eligible(
    ComputeGraph& graph,
    const ValueRef out,
    int64_t M,
    int64_t N,
    int64_t K) {
  return graph.context()->adapter_ptr()->supports_cooperative_matrix() &&
      graph.storage_type_of(out) == utils::kBuffer && M % kCoopmatTileM == 0 &&
      N % kCoopmatTileN == 0 && K % kCoopmatTileK == 0;
}
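For context, a hedged sketch of how this helper would be consumed at a dispatch site (call shape inferred from the gates quoted above; not the PR's exact code):

if (is_coopmat_eligible(graph, out, M, N, K)) {
  add_linear_coopmat_node(
      graph, input, packed_weight, packed_bias, has_bias, out, weight_B);
} else {
  add_linear_tiled_node(/* existing tiled fallback, behavior unchanged */);
}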

Labels

CLA Signed (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)
release notes: vulkan (Changes to the Vulkan backend delegate)
