Skip to content

Commit a1cfb64

Browse files
ggml-webgpu: add vectorized flash attention (ggml-org#20709)
* naive vectorized version * add vectorized flash attention * update vec version * remove unused path and shader * remove unused helper functions * add comments * remove pad path * ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization * change back to vec4 * enable multi split * enable vec path when: - Q->ne[1] < 20 - Q->ne[0] % 32 == 0 - V->ne[0] % 4 == 0 - K->type == f16 * update flash_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select * enable vec path for q4 and q8 * flash-attn vec nwg=1 fast path (skip tmp/reduce staging) * use packed f16 K loads in flash-attn vec split * use packed f16 K loads in flash-attn vec split on host side * tune flash-attn vec f16 VEC_NE by head dim * cleanup * cleanup * keep host side clean * cleanup host side * change back to original host wait/submit behavior * formatting * reverted param-buffer pool refactor * add helper functions * ggml-webgpu: move flash-attn vec pipeline caching back into shader lib * ggml-webgpu: remove duplicate functions * ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation * ggml-webgpu: revert unrelated change * ggml-webgpu: revert deleted comment * disable uniformity check * remove unnecessary change * Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl * Update ggml/src/ggml-webgpu/ggml-webgpu.cpp --------- Co-authored-by: Reese Levine <reeselevine1@gmail.com>
1 parent 5803c8d commit a1cfb64

5 files changed

Lines changed: 1412 additions & 53 deletions

File tree

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 191 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions {
9595
uint32_t wg_size = 0;
9696
};
9797

98+
// Bundle returned by the shader preprocessing helpers: the generated WGSL
// source plus the variant label that was assembled alongside the defines.
struct ggml_webgpu_processed_shader {
    std::string           wgsl;       // preprocessed WGSL text
    std::string           variant;    // variant label matching the defines used
    std::shared_ptr<void> decisions;  // type-erased per-shader decisions; null unless a producer fills it in
};
103+
98104
struct ggml_webgpu_ssm_conv_shader_decisions {
99105
uint32_t block_size;
100106
uint32_t tokens_per_wg;
@@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
384390
bool has_mask;
385391
bool has_sinks;
386392
bool uses_logit_softcap;
393+
bool use_vec;
387394

388395
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
389396
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
390397
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
391-
uses_logit_softcap == other.uses_logit_softcap;
398+
uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
392399
}
393400
};
394401

@@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
402409
ggml_webgpu_hash_combine(seed, key.has_mask);
403410
ggml_webgpu_hash_combine(seed, key.has_sinks);
404411
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
412+
ggml_webgpu_hash_combine(seed, key.use_vec);
405413
return seed;
406414
}
407415
};
@@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions {
421429
uint32_t wg_size = 0;
422430
};
423431

432+
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
433+
// Keep conservative defaults unless this is the f16 vec-split shape family.
434+
if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
435+
return 1u;
436+
}
437+
438+
// Head-dim specializations used by the tuned vec f16 path.
439+
switch (key.head_dim_qk) {
440+
case 64: return 2u;
441+
case 96: return 4u;
442+
case 128: return 1u;
443+
case 192: return 2u;
444+
case 576: return 2u;
445+
default: return 1u;
446+
}
447+
}
448+
449+
// Cache key for the flash-attention vec reduce pipelines.
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
    uint32_t head_dim_v;  // emitted as HEAD_DIM_V when the reduce shader is preprocessed
    uint32_t wg_size;     // participates in hashing and equality of the key
};
453+
454+
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
455+
size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
456+
size_t seed = 0;
457+
ggml_webgpu_hash_combine(seed, key.head_dim_v);
458+
ggml_webgpu_hash_combine(seed, key.wg_size);
459+
return seed;
460+
}
461+
};
462+
463+
inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
464+
const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
465+
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
466+
}
467+
468+
struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
469+
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
470+
uint32_t max_wg_size;
471+
};
472+
473+
// Preprocesses the flash-attention vec reduce shader for one variant.
//
// Defines HEAD_DIM_V (from context.key.head_dim_v) and WG_SIZE (from
// context.max_wg_size) and returns the generated WGSL together with the
// variant label "flash_attn_vec_reduce_hsv<HEAD_DIM_V>_wg<WG_SIZE>".
// The `decisions` member of the result is left unset.
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
    pre_wgsl::Preprocessor &                                     preprocessor,
    const char *                                                 shader_src,
    const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
    const std::string head_dim_v_str = std::to_string(context.key.head_dim_v);
    // NOTE(review): WG_SIZE is taken from context.max_wg_size, whereas the cache key
    // carries its own wg_size; confirm that callers keep the two in sync.
    const std::string wg_size_str = std::to_string(context.max_wg_size);

    std::vector<std::string> defines;
    defines.push_back("HEAD_DIM_V=" + head_dim_v_str);
    defines.push_back("WG_SIZE=" + wg_size_str);

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = "flash_attn_vec_reduce_hsv" + head_dim_v_str + "_wg" + wg_size_str;
    return result;
}
491+
492+
// Cache key for the flash-attention blk pipelines: one pipeline per
// (q_tile, kv_tile) pair.
struct ggml_webgpu_flash_attn_blk_pipeline_key {
    uint32_t q_tile;   // emitted as Q_TILE when the blk shader is preprocessed
    uint32_t kv_tile;  // emitted as KV_TILE when the blk shader is preprocessed

    // Keys compare equal only when both tile sizes match.
    bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
        const bool any_field_differs = q_tile != other.q_tile || kv_tile != other.kv_tile;
        return !any_field_differs;
    }
};
500+
501+
struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
502+
size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
503+
size_t seed = 0;
504+
ggml_webgpu_hash_combine(seed, key.q_tile);
505+
ggml_webgpu_hash_combine(seed, key.kv_tile);
506+
return seed;
507+
}
508+
};
509+
510+
struct ggml_webgpu_flash_attn_blk_shader_lib_context {
511+
ggml_webgpu_flash_attn_blk_pipeline_key key;
512+
uint32_t max_wg_size;
513+
};
514+
515+
// Preprocesses the flash-attention blk shader for one variant.
//
// Defines Q_TILE and KV_TILE from context.key, and WG_SIZE as the largest
// power of two that does not exceed context.max_wg_size (1 when the limit
// is 0). Returns the generated WGSL together with the variant label
// "flash_attn_vec_blk_qt<Q_TILE>_kvt<KV_TILE>_wg<WG_SIZE>".
// The `decisions` member of the result is left unset.
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
    pre_wgsl::Preprocessor &                              preprocessor,
    const char *                                          shader_src,
    const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string              variant = "flash_attn_vec_blk";

    defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
    variant += std::string("_qt") + std::to_string(context.key.q_tile);

    defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
    variant += std::string("_kvt") + std::to_string(context.key.kv_tile);

    // Round the limit down to a power of two. The candidate is compared against
    // half the limit instead of testing (wg_size << 1): that shift wraps to 0 in
    // 32 bits once wg_size reaches 2^31, so the loop would never terminate for
    // limits >= 2^31. For every smaller limit the result is unchanged.
    uint32_t wg_size = 1;
    while (wg_size <= context.max_wg_size / 2) {
        wg_size <<= 1;
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    variant += std::string("_wg") + std::to_string(wg_size);

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    return result;
}
540+
424541
// This is exposed because it's necessary in supports_op
425542
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
426543
uint32_t kv_tile,
@@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib {
659776
repeat_pipelines; // type
660777
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
661778
flash_attn_pipelines;
779+
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
780+
webgpu_pipeline,
781+
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
782+
flash_attn_vec_reduce_pipelines;
783+
std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
784+
webgpu_pipeline,
785+
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
786+
flash_attn_blk_pipelines;
662787
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
663788
webgpu_pipeline,
664789
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
@@ -1673,32 +1798,16 @@ class ggml_webgpu_shader_lib {
16731798
return repeat_pipelines[key];
16741799
}
16751800

1676-
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
1677-
const bool has_mask = context.src3 != nullptr;
1678-
const bool has_sinks = context.src4 != nullptr;
1679-
1680-
bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
1681-
(context.src1->ne[1] % context.sg_mat_n == 0);
1682-
1683-
ggml_webgpu_flash_attn_pipeline_key key = {
1684-
.kv_type = context.src1->type,
1685-
.head_dim_qk = (uint32_t) context.src0->ne[0],
1686-
.head_dim_v = (uint32_t) context.src2->ne[0],
1687-
.kv_direct = kv_direct,
1688-
.has_mask = has_mask,
1689-
.has_sinks = has_sinks,
1690-
.uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
1691-
};
1692-
1693-
auto it = flash_attn_pipelines.find(key);
1801+
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
1802+
auto it = flash_attn_pipelines.find(context.key);
16941803
if (it != flash_attn_pipelines.end()) {
16951804
return it->second;
16961805
}
16971806

16981807
std::vector<std::string> defines;
16991808
std::string variant = "flash_attn";
17001809

1701-
switch (key.kv_type) {
1810+
switch (context.key.kv_type) {
17021811
case GGML_TYPE_F32:
17031812
defines.push_back("KV_F32");
17041813
break;
@@ -1714,41 +1823,52 @@ class ggml_webgpu_shader_lib {
17141823
default:
17151824
GGML_ABORT("Unsupported KV type for flash attention shader");
17161825
}
1717-
variant += std::string("_") + ggml_type_name(key.kv_type);
1826+
variant += std::string("_") + ggml_type_name(context.key.kv_type);
17181827

1719-
if (key.has_mask) {
1828+
if (context.key.has_mask) {
17201829
defines.push_back("MASK");
17211830
variant += "_mask";
17221831
}
1723-
if (key.has_sinks) {
1832+
if (context.key.has_sinks) {
17241833
defines.push_back("SINKS");
17251834
variant += "_sinks";
17261835
}
1727-
if (key.uses_logit_softcap) {
1836+
if (context.key.uses_logit_softcap) {
17281837
defines.push_back("LOGIT_SOFTCAP");
17291838
variant += "_lgsc";
17301839
}
1731-
if (key.kv_direct) {
1840+
if (context.key.kv_direct) {
17321841
defines.push_back("KV_DIRECT");
17331842
variant += "_kvdirect";
17341843
}
1844+
if (context.key.has_mask && context.key.use_vec) {
1845+
defines.push_back("BLK");
1846+
variant += "_blk";
1847+
}
17351848

1736-
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
1737-
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
1849+
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
1850+
variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
17381851

1739-
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
1740-
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
1852+
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
1853+
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
17411854

17421855
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
17431856
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
17441857
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
17451858

1746-
uint32_t q_tile = context.sg_mat_m;
1859+
uint32_t q_tile = context.sg_mat_m;
17471860
uint32_t kv_tile =
1748-
std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
1749-
context.wg_mem_limit_bytes, context.max_subgroup_size }),
1861+
std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
17501862
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
1751-
if (key.kv_direct) {
1863+
if (context.key.use_vec) {
1864+
q_tile = 1;
1865+
kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
1866+
kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
1867+
const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
1868+
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
1869+
}
1870+
if (context.key.kv_direct) {
1871+
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
17521872
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
17531873
kv_tile -= context.sg_mat_n;
17541874
}
@@ -1757,19 +1877,51 @@ class ggml_webgpu_shader_lib {
17571877
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
17581878
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
17591879

1760-
uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
1880+
uint32_t wg_size = 0;
1881+
if (context.key.use_vec) {
1882+
wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
1883+
} else {
1884+
wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
1885+
}
17611886
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
17621887

1763-
auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
1888+
const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
1889+
webgpu_pipeline pipeline =
1890+
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
17641891
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
17651892
decisions->q_tile = q_tile;
17661893
decisions->kv_tile = kv_tile;
17671894
decisions->wg_size = wg_size;
1895+
pipeline.context = decisions;
1896+
flash_attn_pipelines[context.key] = pipeline;
1897+
return flash_attn_pipelines[context.key];
1898+
}
1899+
1900+
// Returns the blk pipeline cached under context.key, building it from the
// preprocessed wgsl_flash_attn_vec_blk source and caching it on first request.
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
    auto cached = flash_attn_blk_pipelines.find(context.key);
    if (cached == flash_attn_blk_pipelines.end()) {
        ggml_webgpu_processed_shader shader =
            ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
        webgpu_pipeline built = ggml_webgpu_create_pipeline(device, shader.wgsl, shader.variant);
        // The key is known to be absent here, so emplace always inserts.
        cached = flash_attn_blk_pipelines.emplace(context.key, built).first;
    }
    return cached->second;
}
1912+
1913+
// Returns the vec reduce pipeline cached under context.key, creating and
// caching it on first request.
// NOTE: in this diff view, lines whose gutter number ends in '-' are removed
// lines from the previous revision (they reference flash_attn_pipelines[key]);
// only the '+' and unmarked lines belong to this function.
webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
1914+
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
1915+
// Fast path: reuse a pipeline already stored for this key.
auto it = flash_attn_vec_reduce_pipelines.find(context.key);
1916+
if (it != flash_attn_vec_reduce_pipelines.end()) {
1917+
return it->second;
1918+
}
17681919

1769-
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
1770-
pipeline.context = decisions;
1771-
flash_attn_pipelines[key] = pipeline;
1772-
return flash_attn_pipelines[key];
1920+
// Cache miss: preprocess the reduce shader source for this context.
ggml_webgpu_processed_shader processed =
1921+
ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
1922+
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
1923+
// Store the new pipeline under the key and return the cached entry.
flash_attn_vec_reduce_pipelines[context.key] = pipeline;
1924+
return flash_attn_vec_reduce_pipelines[context.key];
17731925
}
17741926

17751927
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {

0 commit comments

Comments
 (0)