Skip to content

Commit ebd048f

Browse files
wanghqc and lhez authored
opencl: flash attention improvement (#25069)
* opencl: rework FA kernel for f16 and f32 * opencl: flash-attention prefill prepass kernels - flash_attn_kv_pad_f16 pads the tail KV tile to a BLOCK_N multiple - flash_attn_mask_pad_f16 pads the matching mask tile - flash_attn_blk_f16 classifies each KV tile per query block as fully masked / mixed / fully unmasked, so the main kernel can skip fully-masked tiles and the mask lookup for fully-unmasked ones * opencl: FA kernels for q4_0 and q8_0 * opencl: `set_rows` for f32 to q8_0/q4_0 * opencl: dequant kernels for q4_0 and q8_0 * opencl: add FA tile tuning table with override * opencl: wire host side for FA * opencl: q4_0 MoE tensors are also SOA'ed * opencl: cosmetic fix * opencl: refactor, also clarify some code paths in comments * opencl: fix infinity for `-cl-finite-math-only` --------- Co-authored-by: Li He <lih@qti.qualcomm.com>
1 parent 0ed235e commit ebd048f

11 files changed

Lines changed: 5613 additions & 413 deletions

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,10 @@ set(GGML_OPENCL_KERNELS
192192
mul_mm_f16_f32_kq_kqv
193193
conv2d
194194
conv2d_f16_f32
195+
flash_attn_pre_f16
195196
flash_attn_f32_f16
197+
flash_attn_f32_q8_0
198+
flash_attn_f32_q4_0
196199
flash_attn_f16
197200
flash_attn_f32
198201
)

ggml/src/ggml-opencl/fa_tune.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#pragma once
2+
3+
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
4+
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
5+
// edit; the FA dispatch and kernel-compile logic stay in the main file.
6+
// This header is a file section — it is #included exactly once, at the point
7+
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
8+
9+
// Per-(dk, dv) FA config; shared by dispatch and supports_op.
10+
// One flash-attention tuning record, matched by its (dk, dv) pair.
// Member order is significant: the tuning tables populate these fields by
// positional aggregate initialization.
struct ggml_opencl_fa_dim {
    int dk;                   // first match key (logged as "DK") — presumably the K head size; TODO confirm
    int dv;                   // second match key (logged as "DV") — presumably the V head size; TODO confirm
    int bm;                   // tile parameter — presumably the query-block (BLOCK_M) size; TODO confirm
    int bn;                   // tile parameter — presumably the KV-tile (BLOCK_N) size; TODO confirm
    int n_split;              // split parameter for the split variant — exact semantics not visible here
    int nkv_split_threshold;  // split variant fires when n_kv >= this value; 0 means always split
};
13+
14+
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
15+
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
16+
// Read-only default tuning table, one row per supported (dk, dv) pair.
// Columns follow the member order of ggml_opencl_fa_dim.
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
    //  dk,  dv, bm, bn, n_split, nkv_split_threshold
    {   40,  40, 64, 32,  1,  0 },
    {   64,  64, 64, 32,  2, 64 },
    {   80,  80, 64, 32,  2, 64 },
    {   96,  96, 64, 32,  2, 64 },
    {  112, 112, 64, 32,  2, 64 },
    {  128, 128, 64, 32,  2, 64 },
    {  192, 128, 16, 16,  1,  0 },
    {  192, 192, 16, 16,  1,  0 },
    {  256, 256, 16, 16, 16,  0 },
};
24+
25+
struct ggml_opencl_fa_dim_table {
26+
const ggml_opencl_fa_dim * data;
27+
size_t count;
28+
29+
const ggml_opencl_fa_dim * begin() const { return data; }
30+
const ggml_opencl_fa_dim * end() const { return data + count; }
31+
};
32+
33+
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
34+
// at backend init without touching the const source table.
35+
// Writable storage with exactly as many slots as the default table.
// ggml_cl_init_fa_dims_table() fills it from g_fa_dims_adreno_default, and
// ggml_opencl_fa_apply_env_overrides() patches entries in place.
static ggml_opencl_fa_dim g_fa_dims_runtime[sizeof(g_fa_dims_adreno_default) / sizeof(*g_fa_dims_adreno_default)];
37+
38+
// The active tuning table. It initially views the const defaults;
// ggml_cl_init_fa_dims_table() re-points it at g_fa_dims_runtime.
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
    /* data  = */ g_fa_dims_adreno_default,
    /* count = */ sizeof(g_fa_dims_adreno_default) / sizeof(*g_fa_dims_adreno_default),
};
42+
43+
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
44+
// in the active table at backend init, before the first FA kernel compiles.
45+
// Unmatched (dk,dv) pairs are warned and ignored.
46+
static void ggml_opencl_fa_apply_env_overrides() {
47+
const char * e = std::getenv("GGML_OPENCL_FA_TUNE");
48+
if (!e || !e[0]) {
49+
return;
50+
}
51+
52+
std::string s = e;
53+
size_t pos = 0;
54+
while (pos < s.size()) {
55+
size_t comma = s.find(',', pos);
56+
std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);
57+
int dk, dv, bm, bn, nsplit, thr;
58+
if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) {
59+
bool patched = false;
60+
for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
61+
ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
62+
if (d.dk == dk && d.dv == dv) {
63+
d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
64+
GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
65+
dk, dv, bm, bn, nsplit, thr);
66+
patched = true;
67+
break;
68+
}
69+
}
70+
if (!patched) {
71+
GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
72+
}
73+
} else {
74+
GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
75+
}
76+
if (comma == std::string::npos) break;
77+
pos = comma + 1;
78+
}
79+
}
80+
81+
// Copy the default table into the mutable runtime buffer and apply any
82+
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
83+
// once it has been tuned on hardware.
84+
static void ggml_cl_init_fa_dims_table() {
85+
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
86+
for (size_t i = 0; i < count; ++i) {
87+
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
88+
}
89+
g_opencl_fa_dims = { g_fa_dims_runtime, count };
90+
ggml_opencl_fa_apply_env_overrides();
91+
}

0 commit comments

Comments
 (0)