|
9 | 9 | #include <executorch/backends/webgpu/runtime/WebGPUGraph.h> |
10 | 10 | #include <executorch/backends/webgpu/runtime/WebGPUUtils.h> |
11 | 11 | #include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h> |
| 12 | +#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h> |
12 | 13 | #include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h> |
13 | 14 |
|
14 | 15 | #include <webgpu/webgpu.h> |
@@ -89,18 +90,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) { |
89 | 90 | "WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)"); |
90 | 91 | } |
91 | 92 |
|
92 | | - // Register-tiled GEMM: one thread per TM x TN tile; validate before alloc. |
93 | | - const uint32_t wg_size = |
94 | | - utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); |
95 | | - const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) * |
96 | | - utils::div_up<int64_t>(N, kQ4gswTileN); |
97 | | - if (total_tiles > static_cast<int64_t>(UINT32_MAX)) { |
98 | | - throw std::runtime_error( |
99 | | - "WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit"); |
100 | | - } |
101 | | - const uint32_t workgroup_count = utils::compute_1d_workgroup_count( |
102 | | - device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw"); |
103 | | - |
104 | 93 | // fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail. |
105 | 94 | const uint64_t scales_numel = |
106 | 95 | static_cast<uint64_t>(num_groups) * static_cast<uint64_t>(padded_N); |
@@ -128,6 +117,35 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) { |
128 | 117 | "WebGPU linear_q4gsw: scales dims too small for K/N"); |
129 | 118 | } |
130 | 119 |
|
| 120 | + // M==1 decode -> coop4 GEMV (needs K%8==0 && gs%8==0); else tiled GEMM. |
| 121 | + const uint32_t wg_size = |
| 122 | + utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); |
| 123 | + const bool use_gemv = (M == 1u && K % 8u == 0u && gs % 8u == 0u); |
| 124 | + const char* shader_src = use_gemv ? kQ4gswLinearCoop4WGSL : kQ4gswLinearWGSL; |
| 125 | + uint32_t workgroup_count; |
| 126 | + if (use_gemv) { |
| 127 | + // coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N. |
| 128 | + const uint64_t outputs = |
| 129 | + static_cast<uint64_t>(M) * static_cast<uint64_t>(N); |
| 130 | + if (outputs == 0u || outputs > UINT32_MAX) { |
| 131 | + throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range"); |
| 132 | + } |
| 133 | + workgroup_count = |
| 134 | + utils::clamp_workgroup_count(device, static_cast<uint32_t>(outputs)); |
| 135 | + if (workgroup_count == 0u) { |
| 136 | + throw std::runtime_error("WebGPU linear_q4gsw: zero GEMV dispatch"); |
| 137 | + } |
| 138 | + } else { |
| 139 | + const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) * |
| 140 | + utils::div_up<int64_t>(N, kQ4gswTileN); |
| 141 | + if (total_tiles > static_cast<int64_t>(UINT32_MAX)) { |
| 142 | + throw std::runtime_error( |
| 143 | + "WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit"); |
| 144 | + } |
| 145 | + workgroup_count = utils::compute_1d_workgroup_count( |
| 146 | + device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw"); |
| 147 | + } |
| 148 | + |
131 | 149 | // Optional bias: real buffer if present, else a dummy for the fixed layout. |
132 | 150 | uint32_t has_bias = 0; |
133 | 151 | WGPUBuffer bias_buffer = nullptr; |
@@ -168,7 +186,7 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) { |
168 | 186 |
|
169 | 187 | WGPUShaderSourceWGSL wgsl_desc = {}; |
170 | 188 | wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; |
171 | | - wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN}; |
| 189 | + wgsl_desc.code = {shader_src, WGPU_STRLEN}; |
172 | 190 | WGPUShaderModuleDescriptor shader_desc = {}; |
173 | 191 | shader_desc.nextInChain = &wgsl_desc.chain; |
174 | 192 | WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); |
@@ -206,8 +224,9 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) { |
206 | 224 | pipeline_desc.layout = pipeline_layout; |
207 | 225 | pipeline_desc.compute.module = shader; |
208 | 226 | pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; |
209 | | - pipeline_desc.compute.constantCount = 1; |
210 | | - pipeline_desc.compute.constants = &wg_size_constant; |
| 227 | + // coop4 GEMV uses fixed @workgroup_size(64); only the GEMM has an override. |
| 228 | + pipeline_desc.compute.constantCount = use_gemv ? 0u : 1u; |
| 229 | + pipeline_desc.compute.constants = use_gemv ? nullptr : &wg_size_constant; |
211 | 230 | WGPUComputePipeline pipeline = |
212 | 231 | wgpuDeviceCreateComputePipeline(device, &pipeline_desc); |
213 | 232 |
|
|
0 commit comments