@@ -664,7 +664,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
664664 }
665665 const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v ) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
666666 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES ;
667- size_t bytes_per_kv = 0 ;
667+ size_t bytes_per_kv = 0 ;
668668 if (!key.kv_direct ) {
669669 bytes_per_kv += std::max (key.head_dim_qk , key.head_dim_v );
670670 }
@@ -701,10 +701,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
701701 (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u );
702702 const bool kv_vec_type_supported =
703703 K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 ;
704- const bool use_vec = context.supports_subgroups && (context.src0 ->ne [1 ] < 20 ) && (context.src0 ->ne [0 ] % 32 == 0 ) &&
705- (context.src2 ->ne [0 ] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0 ) &&
706- kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
707- (context.src2 ->type == K->type );
704+ const bool use_vec = context.supports_subgroups && (context.src0 ->ne [1 ] < 20 ) && (context.src0 ->ne [0 ] % 32 == 0 ) &&
705+ (context.src2 ->ne [0 ] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0 ) &&
706+ kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
707+ (context.src2 ->type == K->type );
708708 const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 &&
709709 V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
710710 (context.src0 ->ne [0 ] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0 ) &&
@@ -862,9 +862,12 @@ struct ggml_webgpu_mul_mat_shader_decisions {
862862struct ggml_webgpu_mul_mat_id_pipeline_key {
863863 ggml_type src0_type;
864864 ggml_type src1_type;
865+ uint32_t n_experts;
866+ int vectorized;
865867
866868 bool operator ==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
867- return src0_type == other.src0_type && src1_type == other.src1_type ;
869+ return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
870+ vectorized == other.vectorized ;
868871 }
869872};
870873
@@ -873,6 +876,8 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
873876 size_t seed = 0 ;
874877 ggml_webgpu_hash_combine (seed, key.src0_type );
875878 ggml_webgpu_hash_combine (seed, key.src1_type );
879+ ggml_webgpu_hash_combine (seed, key.n_experts );
880+ ggml_webgpu_hash_combine (seed, key.vectorized );
876881 return seed;
877882 }
878883};
@@ -1023,6 +1028,8 @@ class ggml_webgpu_shader_lib {
10231028 std::unordered_map<int , webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
10241029 std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
10251030 mul_mat_id_pipelines; // src0_type/src1_type
1031+ std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
1032+ mul_mat_id_vec_pipelines; // src0_type/src1_type
10261033
10271034 std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
10281035 set_rows_pipelines;
@@ -1516,7 +1523,7 @@ class ggml_webgpu_shader_lib {
15161523 key.type = context.dst ->type ;
15171524 key.d_state = (int ) context.src0 ->ne [0 ];
15181525 key.xbc_overlap = ggml_webgpu_tensor_overlap (context.src1 , context.src4 ) &&
1519- ggml_webgpu_tensor_overlap (context.src1 , context.src5 );
1526+ ggml_webgpu_tensor_overlap (context.src1 , context.src5 );
15201527
15211528 auto it = ssm_scan_pipelines.find (key);
15221529 if (it != ssm_scan_pipelines.end ()) {
@@ -1633,10 +1640,10 @@ class ggml_webgpu_shader_lib {
16331640 ggml_webgpu_mul_mat_vec_pipeline_key key = {};
16341641 key.src0_type = context.src0 ->type ;
16351642 key.src1_type = context.src1 ->type ;
1636- key.vectorized = (context.src0 ->ne [0 ] % 4 == 0 &&
1643+ key.vectorized = (context.src0 ->ne [0 ] % 4 == 0 &&
16371644 (context.src0 ->type == GGML_TYPE_F32 || context.src0 ->type == GGML_TYPE_F16 )) ?
1638- 1 :
1639- 0 ;
1645+ 1 :
1646+ 0 ;
16401647
16411648 auto it = mul_mat_vec_pipelines.find (key);
16421649 if (it != mul_mat_vec_pipelines.end ()) {
@@ -2012,6 +2019,11 @@ class ggml_webgpu_shader_lib {
20122019 ggml_webgpu_mul_mat_id_pipeline_key key = {};
20132020 key.src0_type = context.src0 ->type ;
20142021 key.src1_type = context.src1 ->type ;
2022+ key.n_experts = context.src0 ->ne [2 ];
2023+ key.vectorized = (context.src0 ->ne [0 ] % 4 == 0 && context.src0 ->ne [1 ] % 4 == 0 &&
2024+ (context.src0 ->type == GGML_TYPE_F32 || context.src0 ->type == GGML_TYPE_F16 )) ?
2025+ 1 :
2026+ 0 ;
20152027
20162028 auto it = mul_mat_id_pipelines.find (key);
20172029 if (it != mul_mat_id_pipelines.end ()) {
@@ -2041,14 +2053,12 @@ class ggml_webgpu_shader_lib {
20412053 switch (context.src0 ->type ) {
20422054 case GGML_TYPE_F32 :
20432055 defines.push_back (" SRC0_INNER_TYPE=f32" );
2044- defines.push_back (" FLOAT" );
20452056 defines.push_back (" INIT_SRC0_SHMEM_FLOAT" );
20462057 defines.push_back (" INIT_SRC1_SHMEM_FLOAT" );
20472058 variant += " _f32" ;
20482059 break ;
20492060 case GGML_TYPE_F16 :
20502061 defines.push_back (" SRC0_INNER_TYPE=f16" );
2051- defines.push_back (" FLOAT" );
20522062 defines.push_back (" INIT_SRC0_SHMEM_FLOAT" );
20532063 defines.push_back (" INIT_SRC1_SHMEM_FLOAT" );
20542064 variant += " _f16" ;
@@ -2064,12 +2074,32 @@ class ggml_webgpu_shader_lib {
20642074 defines.push_back (" U32_DEQUANT_HELPERS" );
20652075 defines.push_back (" SRC0_INNER_TYPE=u32" );
20662076
2077+ switch (context.src0 ->type ) {
2078+ case GGML_TYPE_IQ1_S :
2079+ case GGML_TYPE_IQ1_M :
2080+ case GGML_TYPE_IQ4_NL :
2081+ case GGML_TYPE_IQ4_XS :
2082+ defines.push_back (type_upper + " _GRID" );
2083+ break ;
2084+ case GGML_TYPE_IQ2_XXS :
2085+ case GGML_TYPE_IQ2_XS :
2086+ case GGML_TYPE_IQ2_S :
2087+ case GGML_TYPE_IQ3_XXS :
2088+ case GGML_TYPE_IQ3_S :
2089+ defines.push_back (type_upper + " _GRID" );
2090+ defines.push_back (type_upper + " _TABLES" );
2091+ break ;
2092+ default :
2093+ break ;
2094+ }
2095+
20672096 variant += std::string (" _" ) + src0_name;
20682097 break ;
20692098 }
20702099 }
20712100
2072- defines.push_back (" SCALAR" );
2101+ // VEC/SCALAR controls
2102+ defines.push_back (key.vectorized ? " VEC" : " SCALAR" );
20732103
20742104 // mul_mat_id is register-tile only.
20752105 const uint32_t tile_k =
@@ -2102,6 +2132,123 @@ class ggml_webgpu_shader_lib {
21022132 return mul_mat_id_pipelines[key];
21032133 }
21042134
2135+ webgpu_pipeline get_mul_mat_id_vec_pipeline (const ggml_webgpu_shader_lib_context & context) {
2136+ ggml_webgpu_mul_mat_id_pipeline_key key = {};
2137+ key.src0_type = context.src0 ->type ;
2138+ key.src1_type = context.src1 ->type ;
2139+ key.n_experts = context.src0 ->ne [2 ];
2140+ key.vectorized = (context.src0 ->ne [0 ] % 4 == 0 &&
2141+ (context.src0 ->type == GGML_TYPE_F32 || context.src0 ->type == GGML_TYPE_F16 )) ?
2142+ 1 :
2143+ 0 ;
2144+
2145+ auto it = mul_mat_id_vec_pipelines.find (key);
2146+ if (it != mul_mat_id_vec_pipelines.end ()) {
2147+ return it->second ;
2148+ }
2149+
2150+ std::vector<std::string> defines;
2151+ std::string variant = " mul_mat_id_vec" ;
2152+ const char * shader_src = wgsl_mul_mat_id_vec;
2153+
2154+ // src1 type
2155+ switch (context.src1 ->type ) {
2156+ case GGML_TYPE_F32 :
2157+ defines.push_back (" SRC1_INNER_TYPE=f32" );
2158+ break ;
2159+ case GGML_TYPE_F16 :
2160+ defines.push_back (" SRC1_INNER_TYPE=f16" );
2161+ break ;
2162+ default :
2163+ GGML_ABORT (" Unsupported src1 type for mul_mat fast shader" );
2164+ }
2165+
2166+ // src0 type
2167+ switch (context.src0 ->type ) {
2168+ case GGML_TYPE_F32 :
2169+ defines.push_back (" SRC0_INNER_TYPE=f32" );
2170+ defines.push_back (" MUL_ACC_FLOAT" );
2171+ variant += " _f32" ;
2172+ break ;
2173+ case GGML_TYPE_F16 :
2174+ defines.push_back (" SRC0_INNER_TYPE=f16" );
2175+ defines.push_back (" MUL_ACC_FLOAT" );
2176+ variant += " _f16" ;
2177+ break ;
2178+ default :
2179+ {
2180+ // Quantized types: use helpers but accumulate in f16
2181+ const struct ggml_type_traits * src0_traits = ggml_get_type_traits (context.src0 ->type );
2182+ std::string src0_name = src0_traits->type_name ;
2183+ std::string type_upper = src0_name;
2184+ variant += " _" + src0_name;
2185+ std::transform (type_upper.begin (), type_upper.end (), type_upper.begin (), ::toupper);
2186+
2187+ defines.push_back (" BYTE_HELPERS" );
2188+ defines.push_back (" MUL_ACC_" + type_upper);
2189+ defines.push_back (" U32_DEQUANT_HELPERS" );
2190+ defines.push_back (" SRC0_INNER_TYPE=u32" );
2191+ switch (context.src0 ->type ) {
2192+ case GGML_TYPE_IQ1_S :
2193+ case GGML_TYPE_IQ1_M :
2194+ case GGML_TYPE_IQ2_S :
2195+ case GGML_TYPE_IQ3_S :
2196+ case GGML_TYPE_IQ4_NL :
2197+ case GGML_TYPE_IQ4_XS :
2198+ defines.push_back (type_upper + " _GRID" );
2199+ break ;
2200+ case GGML_TYPE_IQ2_XXS :
2201+ case GGML_TYPE_IQ2_XS :
2202+ case GGML_TYPE_IQ3_XXS :
2203+ defines.push_back (type_upper + " _GRID" );
2204+ defines.push_back (type_upper + " _TABLES" );
2205+ break ;
2206+ default :
2207+ break ;
2208+ }
2209+ break ;
2210+ }
2211+ }
2212+
2213+ // VEC/SCALAR controls
2214+ defines.push_back (key.vectorized ? " VEC" : " SCALAR" );
2215+
2216+ uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE ;
2217+ uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG ;
2218+
2219+ if (key.src0_type == GGML_TYPE_Q1_0 ) {
2220+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG ;
2221+ } else if (key.src0_type >= GGML_TYPE_Q2_K ) {
2222+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG ;
2223+ } else if (key.src0_type >= GGML_TYPE_Q4_0 ) {
2224+ outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG ;
2225+ }
2226+
2227+ // variant suffix for src1 type
2228+ variant += std::string (" _" ) + (context.src1 ->type == GGML_TYPE_F32 ? " f32" : " f16" );
2229+
2230+ defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (wg_size));
2231+ defines.push_back (std::string (" OUTPUTS_PER_WG=" ) + std::to_string (outputs_per_wg));
2232+ defines.push_back (context.supports_subgroups ? " USE_SUBGROUP_REDUCTION" : " USE_WORKGROUP_REDUCTION" );
2233+ variant += context.supports_subgroups ? " _sg_reduce" : " _wg_reduce" ;
2234+ if (key.vectorized ) {
2235+ variant += " _vectorized" ;
2236+ }
2237+
2238+ defines.push_back (std::string (" N_EXPERTS=" ) + std::to_string (key.n_experts ));
2239+
2240+ auto processed = preprocessor.preprocess (shader_src, defines);
2241+
2242+ auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
2243+ decisions->wg_size = wg_size;
2244+ decisions->outputs_per_wg = outputs_per_wg;
2245+
2246+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed, variant);
2247+ pipeline.context = decisions;
2248+ mul_mat_id_vec_pipelines[key] = pipeline;
2249+ return mul_mat_id_vec_pipelines[key];
2250+ }
2251+
21052252 webgpu_pipeline get_unary_pipeline (const ggml_webgpu_shader_lib_context & context) {
21062253 const bool is_unary = context.dst ->op == GGML_OP_UNARY ;
21072254 const int op = is_unary ? (int ) ggml_get_unary_op (context.dst ) : context.dst ->op ;
0 commit comments