Skip to content

Commit f22e3a9

Browse files
authored
webgpu: Optimize DP4A SmallM MatMulNBits tiling (#27910)
This pull request adjusts the tiling strategy for small matrix sizes in the DP4A matmul kernel. The changes are aimed at improving performance and compatibility, especially for specific GPU vendors. On Qualcomm, improving token generation from ~20 tps to ~25 tps.
1 parent 048e7dc commit f22e3a9

2 files changed

Lines changed: 3 additions & 7 deletions

File tree

onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,9 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
128128
const bool has_weight_idx_indirect = weight_index_indirect != nullptr;
129129
const bool single_scale_weights = (block_size == K * N);
130130
if (M < min_M_for_tile_optimization) {
131-
uint32_t tile_size_k_vec = 16;
132-
uint32_t tile_size_n = 32;
131+
uint32_t tile_size_k_vec = 32;
132+
uint32_t tile_size_n = 4;
133133

134-
if (context.AdapterInfo().vendor == std::string_view{"intel"}) {
135-
tile_size_k_vec = 32;
136-
tile_size_n = 4;
137-
}
138134
const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components);
139135
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights};
140136
uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n;

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
230230
}
231231
#endif
232232

233-
// On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
233+
// On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
234234
// DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values).
235235
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
236236
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {

0 commit comments

Comments
 (0)