@@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions {
9595 uint32_t wg_size = 0 ;
9696};
9797
98+ struct ggml_webgpu_processed_shader {
99+ std::string wgsl;
100+ std::string variant;
101+ std::shared_ptr<void > decisions;
102+ };
103+
98104struct ggml_webgpu_ssm_conv_shader_decisions {
99105 uint32_t block_size;
100106 uint32_t tokens_per_wg;
@@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
384390 bool has_mask;
385391 bool has_sinks;
386392 bool uses_logit_softcap;
393+ bool use_vec;
387394
388395 bool operator ==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
389396 return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
390397 kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
391- uses_logit_softcap == other.uses_logit_softcap ;
398+ uses_logit_softcap == other.uses_logit_softcap && use_vec == other. use_vec ;
392399 }
393400};
394401
@@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
402409 ggml_webgpu_hash_combine (seed, key.has_mask );
403410 ggml_webgpu_hash_combine (seed, key.has_sinks );
404411 ggml_webgpu_hash_combine (seed, key.uses_logit_softcap );
412+ ggml_webgpu_hash_combine (seed, key.use_vec );
405413 return seed;
406414 }
407415};
@@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions {
421429 uint32_t wg_size = 0 ;
422430};
423431
432+ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne (const ggml_webgpu_flash_attn_pipeline_key & key) {
433+ // Keep conservative defaults unless this is the f16 vec-split shape family.
434+ if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v ) {
435+ return 1u ;
436+ }
437+
438+ // Head-dim specializations used by the tuned vec f16 path.
439+ switch (key.head_dim_qk ) {
440+ case 64 : return 2u ;
441+ case 96 : return 4u ;
442+ case 128 : return 1u ;
443+ case 192 : return 2u ;
444+ case 576 : return 2u ;
445+ default : return 1u ;
446+ }
447+ }
448+
449+ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
450+ uint32_t head_dim_v;
451+ uint32_t wg_size;
452+ };
453+
454+ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
455+ size_t operator ()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
456+ size_t seed = 0 ;
457+ ggml_webgpu_hash_combine (seed, key.head_dim_v );
458+ ggml_webgpu_hash_combine (seed, key.wg_size );
459+ return seed;
460+ }
461+ };
462+
463+ inline bool operator ==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
464+ const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
465+ return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size ;
466+ }
467+
468+ struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
469+ ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
470+ uint32_t max_wg_size;
471+ };
472+
473+ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader (
474+ pre_wgsl::Preprocessor & preprocessor,
475+ const char * shader_src,
476+ const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
477+ std::vector<std::string> defines;
478+ std::string variant = " flash_attn_vec_reduce" ;
479+
480+ defines.push_back (std::string (" HEAD_DIM_V=" ) + std::to_string (context.key .head_dim_v ));
481+ variant += std::string (" _hsv" ) + std::to_string (context.key .head_dim_v );
482+
483+ defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (context.max_wg_size ));
484+ variant += std::string (" _wg" ) + std::to_string (context.max_wg_size );
485+
486+ ggml_webgpu_processed_shader result;
487+ result.wgsl = preprocessor.preprocess (shader_src, defines);
488+ result.variant = variant;
489+ return result;
490+ }
491+
492+ struct ggml_webgpu_flash_attn_blk_pipeline_key {
493+ uint32_t q_tile;
494+ uint32_t kv_tile;
495+
496+ bool operator ==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
497+ return q_tile == other.q_tile && kv_tile == other.kv_tile ;
498+ }
499+ };
500+
501+ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
502+ size_t operator ()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
503+ size_t seed = 0 ;
504+ ggml_webgpu_hash_combine (seed, key.q_tile );
505+ ggml_webgpu_hash_combine (seed, key.kv_tile );
506+ return seed;
507+ }
508+ };
509+
510+ struct ggml_webgpu_flash_attn_blk_shader_lib_context {
511+ ggml_webgpu_flash_attn_blk_pipeline_key key;
512+ uint32_t max_wg_size;
513+ };
514+
515+ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader (
516+ pre_wgsl::Preprocessor & preprocessor,
517+ const char * shader_src,
518+ const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
519+ std::vector<std::string> defines;
520+ std::string variant = " flash_attn_vec_blk" ;
521+
522+ defines.push_back (std::string (" Q_TILE=" ) + std::to_string (context.key .q_tile ));
523+ variant += std::string (" _qt" ) + std::to_string (context.key .q_tile );
524+
525+ defines.push_back (std::string (" KV_TILE=" ) + std::to_string (context.key .kv_tile ));
526+ variant += std::string (" _kvt" ) + std::to_string (context.key .kv_tile );
527+
528+ uint32_t wg_size = 1 ;
529+ while ((wg_size << 1 ) <= context.max_wg_size ) {
530+ wg_size <<= 1 ;
531+ }
532+ defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (wg_size));
533+ variant += std::string (" _wg" ) + std::to_string (wg_size);
534+
535+ ggml_webgpu_processed_shader result;
536+ result.wgsl = preprocessor.preprocess (shader_src, defines);
537+ result.variant = variant;
538+ return result;
539+ }
540+
424541// This is exposed because it's necessary in supports_op
425542inline size_t ggml_webgpu_flash_attn_wg_mem_bytes (uint32_t q_tile,
426543 uint32_t kv_tile,
@@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib {
659776 repeat_pipelines; // type
660777 std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
661778 flash_attn_pipelines;
779+ std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
780+ webgpu_pipeline,
781+ ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
782+ flash_attn_vec_reduce_pipelines;
783+ std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
784+ webgpu_pipeline,
785+ ggml_webgpu_flash_attn_blk_pipeline_key_hash>
786+ flash_attn_blk_pipelines;
662787 std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
663788 webgpu_pipeline,
664789 ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
@@ -1673,32 +1798,16 @@ class ggml_webgpu_shader_lib {
16731798 return repeat_pipelines[key];
16741799 }
16751800
1676- webgpu_pipeline get_flash_attn_pipeline (const ggml_webgpu_shader_lib_context & context) {
1677- const bool has_mask = context.src3 != nullptr ;
1678- const bool has_sinks = context.src4 != nullptr ;
1679-
1680- bool kv_direct = (context.src1 ->type == GGML_TYPE_F16) && (context.src0 ->ne [0 ] % context.sg_mat_k == 0 ) &&
1681- (context.src1 ->ne [1 ] % context.sg_mat_n == 0 );
1682-
1683- ggml_webgpu_flash_attn_pipeline_key key = {
1684- .kv_type = context.src1 ->type ,
1685- .head_dim_qk = (uint32_t ) context.src0 ->ne [0 ],
1686- .head_dim_v = (uint32_t ) context.src2 ->ne [0 ],
1687- .kv_direct = kv_direct,
1688- .has_mask = has_mask,
1689- .has_sinks = has_sinks,
1690- .uses_logit_softcap = (*(float *) &context.dst ->op_params [2 ]) != 0 .0f ,
1691- };
1692-
1693- auto it = flash_attn_pipelines.find (key);
1801+ webgpu_pipeline get_flash_attn_pipeline (const ggml_webgpu_flash_attn_shader_lib_context & context) {
1802+ auto it = flash_attn_pipelines.find (context.key );
16941803 if (it != flash_attn_pipelines.end ()) {
16951804 return it->second ;
16961805 }
16971806
16981807 std::vector<std::string> defines;
16991808 std::string variant = " flash_attn" ;
17001809
1701- switch (key.kv_type ) {
1810+ switch (context. key .kv_type ) {
17021811 case GGML_TYPE_F32:
17031812 defines.push_back (" KV_F32" );
17041813 break ;
@@ -1714,41 +1823,52 @@ class ggml_webgpu_shader_lib {
17141823 default :
17151824 GGML_ABORT (" Unsupported KV type for flash attention shader" );
17161825 }
1717- variant += std::string (" _" ) + ggml_type_name (key.kv_type );
1826+ variant += std::string (" _" ) + ggml_type_name (context. key .kv_type );
17181827
1719- if (key.has_mask ) {
1828+ if (context. key .has_mask ) {
17201829 defines.push_back (" MASK" );
17211830 variant += " _mask" ;
17221831 }
1723- if (key.has_sinks ) {
1832+ if (context. key .has_sinks ) {
17241833 defines.push_back (" SINKS" );
17251834 variant += " _sinks" ;
17261835 }
1727- if (key.uses_logit_softcap ) {
1836+ if (context. key .uses_logit_softcap ) {
17281837 defines.push_back (" LOGIT_SOFTCAP" );
17291838 variant += " _lgsc" ;
17301839 }
1731- if (key.kv_direct ) {
1840+ if (context. key .kv_direct ) {
17321841 defines.push_back (" KV_DIRECT" );
17331842 variant += " _kvdirect" ;
17341843 }
1844+ if (context.key .has_mask && context.key .use_vec ) {
1845+ defines.push_back (" BLK" );
1846+ variant += " _blk" ;
1847+ }
17351848
1736- defines.push_back (std::string (" HEAD_DIM_QK=" ) + std::to_string (key.head_dim_qk ));
1737- variant += std::string (" _hsqk" ) + std::to_string (key.head_dim_qk );
1849+ defines.push_back (std::string (" HEAD_DIM_QK=" ) + std::to_string (context. key .head_dim_qk ));
1850+ variant += std::string (" _hsqk" ) + std::to_string (context. key .head_dim_qk );
17381851
1739- defines.push_back (std::string (" HEAD_DIM_V=" ) + std::to_string (key.head_dim_v ));
1740- variant += std::string (" _hsv" ) + std::to_string (key.head_dim_v );
1852+ defines.push_back (std::string (" HEAD_DIM_V=" ) + std::to_string (context. key .head_dim_v ));
1853+ variant += std::string (" _hsv" ) + std::to_string (context. key .head_dim_v );
17411854
17421855 defines.push_back (std::string (" SG_MAT_M=" ) + std::to_string (context.sg_mat_m ));
17431856 defines.push_back (std::string (" SG_MAT_N=" ) + std::to_string (context.sg_mat_n ));
17441857 defines.push_back (std::string (" SG_MAT_K=" ) + std::to_string (context.sg_mat_k ));
17451858
1746- uint32_t q_tile = context.sg_mat_m ;
1859+ uint32_t q_tile = context.sg_mat_m ;
17471860 uint32_t kv_tile =
1748- std::min (ggml_webgpu_flash_attn_max_kv_tile ({ key, context.sg_mat_m , context.sg_mat_n , context.sg_mat_k ,
1749- context.wg_mem_limit_bytes , context.max_subgroup_size }),
1861+ std::min (ggml_webgpu_flash_attn_max_kv_tile (context),
17501862 context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
1751- if (key.kv_direct ) {
1863+ if (context.key .use_vec ) {
1864+ q_tile = 1 ;
1865+ kv_tile = std::max (context.sg_mat_n , std::min (32u , ggml_webgpu_flash_attn_max_kv_tile (context)));
1866+ kv_tile = (kv_tile / context.sg_mat_n ) * context.sg_mat_n ;
1867+ const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne (context.key );
1868+ defines.push_back (std::string (" VEC_NE=" ) + std::to_string (vec_ne) + " u" );
1869+ }
1870+ if (context.key .kv_direct ) {
1871+ GGML_ASSERT (kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
17521872 while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0 ) {
17531873 kv_tile -= context.sg_mat_n ;
17541874 }
@@ -1757,19 +1877,51 @@ class ggml_webgpu_shader_lib {
17571877 defines.push_back (std::string (" Q_TILE=" ) + std::to_string (q_tile));
17581878 defines.push_back (std::string (" KV_TILE=" ) + std::to_string (kv_tile));
17591879
1760- uint32_t wg_size = std::max (context.max_subgroup_size , GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
1880+ uint32_t wg_size = 0 ;
1881+ if (context.key .use_vec ) {
1882+ wg_size = std::max (1u , std::min<uint32_t >(32u , context.max_subgroup_size ));
1883+ } else {
1884+ wg_size = std::max (context.max_subgroup_size , GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
1885+ }
17611886 defines.push_back (std::string (" WG_SIZE=" ) + std::to_string (wg_size));
17621887
1763- auto processed = preprocessor.preprocess (wgsl_flash_attn, defines);
1888+ const char * shader_src = context.key .use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
1889+ webgpu_pipeline pipeline =
1890+ ggml_webgpu_create_pipeline (device, preprocessor.preprocess (shader_src, defines), variant);
17641891 auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
17651892 decisions->q_tile = q_tile;
17661893 decisions->kv_tile = kv_tile;
17671894 decisions->wg_size = wg_size;
1895+ pipeline.context = decisions;
1896+ flash_attn_pipelines[context.key ] = pipeline;
1897+ return flash_attn_pipelines[context.key ];
1898+ }
1899+
1900+ webgpu_pipeline get_flash_attn_blk_pipeline (const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
1901+ auto it = flash_attn_blk_pipelines.find (context.key );
1902+ if (it != flash_attn_blk_pipelines.end ()) {
1903+ return it->second ;
1904+ }
1905+
1906+ ggml_webgpu_processed_shader processed =
1907+ ggml_webgpu_preprocess_flash_attn_blk_shader (preprocessor, wgsl_flash_attn_vec_blk, context);
1908+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed.wgsl , processed.variant );
1909+ flash_attn_blk_pipelines[context.key ] = pipeline;
1910+ return flash_attn_blk_pipelines[context.key ];
1911+ }
1912+
1913+ webgpu_pipeline get_flash_attn_vec_reduce_pipeline (
1914+ const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
1915+ auto it = flash_attn_vec_reduce_pipelines.find (context.key );
1916+ if (it != flash_attn_vec_reduce_pipelines.end ()) {
1917+ return it->second ;
1918+ }
17681919
1769- webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed, variant);
1770- pipeline.context = decisions;
1771- flash_attn_pipelines[key] = pipeline;
1772- return flash_attn_pipelines[key];
1920+ ggml_webgpu_processed_shader processed =
1921+ ggml_webgpu_preprocess_flash_attn_vec_reduce_shader (preprocessor, wgsl_flash_attn_vec_reduce, context);
1922+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline (device, processed.wgsl , processed.variant );
1923+ flash_attn_vec_reduce_pipelines[context.key ] = pipeline;
1924+ return flash_attn_vec_reduce_pipelines[context.key ];
17731925 }
17741926
17751927 webgpu_pipeline get_cpy_pipeline (const ggml_webgpu_shader_lib_context & context) {
0 commit comments