From ebd048fc5e4b43ec4e0b4abe0b9bf66e1724dad0 Mon Sep 17 00:00:00 2001 From: Hongqiang Wang <66336067+wanghqc@users.noreply.github.com> Date: Sat, 27 Jun 2026 15:36:06 -0700 Subject: [PATCH] opencl: flash attention improvement (#25069) * opencl: rework FA kernel for f16 and f32 * opencl: flash-attention prefill prepass kernels - flash_attn_kv_pad_f16 pads the tail KV tile to a BLOCK_N multiple - flash_attn_mask_pad_f16 pads the matching mask tile - flash_attn_blk_f16 classifies each KV tile per query block as fully masked / mixed / fully unmasked, so the main kernel can skip fully-masked tiles and the mask lookup for fully-unmasked ones * opencl: FA kernels for q4_0 and q8_0 * opencl: `set_rows` for f32 to q8_0/q4_0 * opencl: dequant kernels for q4_0 and q8_0 * opencl: add FA tile tuning table with override * opencl: wire host side for FA * opencl: q4_0 MoE tensors are also SOA'ed * opencl: cosmetic fix * opencl: refactor, also clarify some code paths in comments * opencl: fix inifity for `-cl-finite-math-only` --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 3 + ggml/src/ggml-opencl/fa_tune.h | 91 + ggml/src/ggml-opencl/ggml-opencl.cpp | 2043 +++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 152 ++ .../src/ggml-opencl/kernels/flash_attn_f16.cl | 115 +- .../src/ggml-opencl/kernels/flash_attn_f32.cl | 111 +- .../ggml-opencl/kernels/flash_attn_f32_f16.cl | 765 +++++- .../kernels/flash_attn_f32_q4_0.cl | 1041 +++++++++ .../kernels/flash_attn_f32_q8_0.cl | 1049 +++++++++ .../ggml-opencl/kernels/flash_attn_pre_f16.cl | 156 ++ ggml/src/ggml-opencl/kernels/set_rows.cl | 500 ++++ 11 files changed, 5613 insertions(+), 413 deletions(-) create mode 100644 ggml/src/ggml-opencl/fa_tune.h create mode 100644 ggml/src/ggml-opencl/kernels/flash_attn_f32_q4_0.cl create mode 100644 ggml/src/ggml-opencl/kernels/flash_attn_f32_q8_0.cl create mode 100644 ggml/src/ggml-opencl/kernels/flash_attn_pre_f16.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 82ce61d72c6d..09efbc566b38 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -192,7 +192,10 @@ set(GGML_OPENCL_KERNELS mul_mm_f16_f32_kq_kqv conv2d conv2d_f16_f32 + flash_attn_pre_f16 flash_attn_f32_f16 + flash_attn_f32_q8_0 + flash_attn_f32_q4_0 flash_attn_f16 flash_attn_f32 ) diff --git a/ggml/src/ggml-opencl/fa_tune.h b/ggml/src/ggml-opencl/fa_tune.h new file mode 100644 index 000000000000..1e2c6ea7eaf9 --- /dev/null +++ b/ggml/src/ggml-opencl/fa_tune.h @@ -0,0 +1,91 @@ +#pragma once + +// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend. +// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and +// edit; the FA dispatch and kernel-compile logic stay in the main file. +// This header is a file section — it is #included exactly once, at the point +// in ggml-opencl.cpp where the ggml logging macros are already in scope. + +// Per-(dk, dv) FA config; shared by dispatch and supports_op. +struct ggml_opencl_fa_dim { + int dk; int dv; int bm; int bn; int n_split; int nkv_split_threshold; +}; + +// Split variant fires when n_kv >= threshold (threshold=0 -> always split). +// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs. +static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = { + { 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64}, + { 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64}, + {112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64}, + {192, 128, 16, 16, 1, 0}, + {192, 192, 16, 16, 1, 0}, + {256, 256, 16, 16, 16, 0}, +}; + +struct ggml_opencl_fa_dim_table { + const ggml_opencl_fa_dim * data; + size_t count; + + const ggml_opencl_fa_dim * begin() const { return data; } + const ggml_opencl_fa_dim * end() const { return data + count; } +}; + +// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here +// at backend init without touching the const source table. +static ggml_opencl_fa_dim g_fa_dims_runtime[ + sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])]; + +static ggml_opencl_fa_dim_table g_opencl_fa_dims = { + g_fa_dims_adreno_default, + sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]), +}; + +// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries +// in the active table at backend init, before the first FA kernel compiles. +// Unmatched (dk,dv) pairs are warned and ignored. +static void ggml_opencl_fa_apply_env_overrides() { + const char * e = std::getenv("GGML_OPENCL_FA_TUNE"); + if (!e || !e[0]) { + return; + } + + std::string s = e; + size_t pos = 0; + while (pos < s.size()) { + size_t comma = s.find(',', pos); + std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos); + int dk, dv, bm, bn, nsplit, thr; + if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) { + bool patched = false; + for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) { + ggml_opencl_fa_dim & d = g_fa_dims_runtime[i]; + if (d.dk == dk && d.dv == dv) { + d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr; + GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n", + dk, dv, bm, bn, nsplit, thr); + patched = true; + break; + } + } + if (!patched) { + GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv); + } + } else { + GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str()); + } + if (comma == std::string::npos) break; + pos = comma + 1; + } +} + +// Copy the default table into the mutable runtime buffer and apply any +// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here +// once it has been tuned on hardware. +static void ggml_cl_init_fa_dims_table() { + const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]); + for (size_t i = 0; i < count; ++i) { + g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i]; + } + g_opencl_fa_dims = { g_fa_dims_runtime, count }; + ggml_opencl_fa_apply_env_overrides(); +} diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 00f20b09b8fd..32581901b2b4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -29,6 +29,8 @@ #include #include #include +#include +#include #undef MIN #undef MAX @@ -53,6 +55,9 @@ //------------------------------------------------------------------------------ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor); +static bool ggml_cl_is_q4_0_soa(const ggml_tensor * tensor); +static bool ggml_cl_is_q8_0_soa(const ggml_tensor * tensor); +static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. // Precompute mp (m' in the paper) and L such that division @@ -96,6 +101,7 @@ enum ADRENO_GPU_GEN { A7X, A8X, X1E, + X2E, }; enum ADRENO_CL_COMPILER_TYPE { @@ -236,6 +242,10 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { return ADRENO_GPU_GEN::X1E; } + if (strstr(device_name, "X2")) { + return ADRENO_GPU_GEN::X2E; + } + return ADRENO_GPU_GEN::ADRENO_UNKNOWN; } @@ -368,7 +378,7 @@ struct ggml_backend_opencl_device_context { cl_device_type device_type; std::string device_version; - // Initialized by ggml_cl2_init(). + // Initialized by ggml_cl_init(). ggml_backend_opencl_context * backend_ctx = nullptr; // Initialized by ggml_backend_opencl_device_get_buffer_type() @@ -384,6 +394,55 @@ struct ggml_backend_opencl_device_context { size_t global_mem_size = 0; }; +// Lazily-compiled flash-attention kernels and their per-(dk,dv) tile metadata. +// One map per (Q/KV dtype, decode/prefill, split) combination; the int maps +// hold tile dims (bm/bn), workgroup sizes and the n_kv split thresholds. +struct ggml_opencl_fa_kernels { + // f16 Q / f16 KV + std::map, cl_kernel> f16; + std::map, cl_kernel> f16_q1; + // f32 Q / f32 KV + std::map, cl_kernel> f32; + std::map, cl_kernel> f32_q1; + // f32 Q / f16 KV (mixed) + std::map, cl_kernel> f32_f16; + std::map, cl_kernel> f32_f16_split; // N_SPLIT>1 variant + std::map, cl_kernel> f32_f16_q1; + std::map, cl_kernel> f32_f16_q1_split; // flash-decoding K-split + std::map, int> f32_f16_bm; + std::map, int> f32_f16_bn; + std::map, int> f32_f16_wg_size; + std::map, int> f32_f16_split_wg_size; + std::map, int> f32_f16_split_nkv_threshold; + // f32 Q / native q8_0 KV + std::map, cl_kernel> f32_q8_0_q1; // decode + std::map, cl_kernel> f32_q8_0_q1_split; // flash-decoding pass 1 + std::map, cl_kernel> f32_q8_0; // prefill (baseline) + std::map, cl_kernel> f32_q8_0_split; // N_SPLIT>1 variant + std::map, int> f32_q8_0_split_wg_size; // wg_size = bm*n_split + std::map, int> f32_q8_0_split_nkv_threshold; // use split when n_kv >= this + std::map, int> f32_q8_0_split_bm; // per-split BLOCK_M + // f32 Q / native q4_0 KV + std::map, cl_kernel> f32_q4_0_q1; + std::map, cl_kernel> f32_q4_0_q1_split; + std::map, cl_kernel> f32_q4_0; + std::map, cl_kernel> f32_q4_0_split; + std::map, int> f32_q4_0_split_wg_size; + std::map, int> f32_q4_0_split_nkv_threshold; + std::map, int> f32_q4_0_split_bm; + // shared: flash-decoding merge + prefill prepass (kv-pad, mask-pad, blk class) + std::map, cl_kernel> f32_merge; + std::map, cl_kernel> kv_pad_f16; + std::map, cl_kernel> mask_pad_f16; + std::map, cl_kernel> blk_f16; + // generic prefill tile dims (f16 / f32 paths) + std::map, int> bm; + std::map, int> bn; + // attempted (variant, (dk, dv)) + // all attempted FA kernels appear here, but those not registered failed compilation + std::set>> variant_attempted; +}; + // backend context struct ggml_backend_opencl_context { int ref_count; @@ -397,9 +456,6 @@ struct ggml_backend_opencl_context { // argsort is loaded in supports_op because its availability depends on how // many workgroups are allowed, which requires kernel compilation. bool kernels_loaded_argsort = false; - // flash attn is loaded in supports_op because it contains multiple variants - // and takes time to compile, so we want to only compile it when needed. - bool kernels_loaded_flash_attn = false; // rest of the kernels are currently always loaded in alloc_buffer. bool kernels_loaded = false; @@ -414,13 +470,16 @@ struct ggml_backend_opencl_context { size_t max_workgroup_size; bool fp16_support; bool has_vector_subgroup_broadcast; - bool has_qcom_subgroup_shuffle = false; // cl_qcom_subgroup_shuffle + bool has_subgroup_shuffle = false; // cl_khr_subgroup_shuffle or cl_qcom_subgroup_shuffle + bool has_qcom_subgroup_shuffle = false; // specifically cl_qcom_subgroup_shuffle bool disable_fusion; bool adreno_has_large_buffer; bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; + std::string kernel_compile_opts; // cached for lazy-compiled kernels. + int adreno_wave_size; cl_bool non_uniform_workgroups; @@ -546,16 +605,13 @@ struct ggml_backend_opencl_context { cl_kernel kernel_diag_f32; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; - std::map, cl_kernel> kernels_flash_attn_f16; - std::map, cl_kernel> kernels_flash_attn_f16_q1; - std::map, cl_kernel> kernels_flash_attn_f32; - std::map, cl_kernel> kernels_flash_attn_f32_q1; - std::map, cl_kernel> kernels_flash_attn_f32_f16; - std::map, cl_kernel> kernels_flash_attn_f32_f16_q1; - std::map, int> kernels_flash_attn_bm; - std::map, int> kernels_flash_attn_bn; + ggml_opencl_fa_kernels fa; cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32; + cl_kernel kernel_set_rows_q8_0_i64, kernel_set_rows_q8_0_i32; + cl_kernel kernel_set_rows_q8_0_soa_i64, kernel_set_rows_q8_0_soa_i32; + cl_kernel kernel_set_rows_q4_0_i64, kernel_set_rows_q4_0_i32; + cl_kernel kernel_set_rows_q4_0_soa_i64, kernel_set_rows_q4_0_soa_i32; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_f32_f32_pack, kernel_cpy_i32_i32; @@ -589,6 +645,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; + cl_kernel kernel_dequant_q8_0_f16_view_aos; + cl_kernel kernel_dequant_q8_0_f32_view_aos; + cl_kernel kernel_dequant_q4_0_f16_view_aos; + cl_kernel kernel_dequant_q4_0_f32_view_aos; cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; cl_kernel kernel_convert_bf16_to_f16, kernel_convert_f16_to_bf16; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; @@ -877,7 +937,13 @@ inline std::string read_file(const std::string &path) { return text; } -static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { +// fatal=false returns NULL on compile failure instead of aborting; used for +// optional FA variants that may exhaust the Adreno compiler at large DK. +static cl_program build_program_from_source_ex(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts, bool fatal, const char *tag = nullptr) { + if (tag) { + GGML_LOG_INFO("ggml_opencl: compiling %s\n", tag); + } + cl_program p; char *program_log; size_t program_size; @@ -889,7 +955,10 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err); if(err < 0) { GGML_LOG_ERROR("OpenCL error creating program"); - exit(1); + if (fatal) { + exit(1); + } + return nullptr; } err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL); @@ -898,14 +967,22 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co program_log = (char*) malloc(log_size + 1); program_log[log_size] = '\0'; clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL); - GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log); + GGML_LOG_ERROR("ggml_opencl: kernel compile error (err=%d):\n\n%s\n", err, program_log); free(program_log); - exit(1); + clReleaseProgram(p); + if (fatal) { + exit(1); + } + return nullptr; } return p; } +static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { + return build_program_from_source_ex(ctx, dev, program_buffer, compile_opts, /*fatal=*/true); +} + static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) { // compiler options for general kernels auto opencl_c_std = @@ -932,84 +1009,6 @@ static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) { } } -static void load_cl_kernels_flash_attn(ggml_backend_opencl_context *backend_ctx) { - // compiler options for general kernels - auto opencl_c_std = - std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); - std::string compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable -cl-unsafe-math-optimizations" - " -cl-finite-math-only -cl-fast-relaxed-math"; - - // flash_attn - if (!backend_ctx->kernels_loaded_flash_attn) { - cl_int err; - - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_f16 { - #include "flash_attn_f16.cl.h" - }; - const std::string kernel_src_f32 { - #include "flash_attn_f32.cl.h" - }; - const std::string kernel_src_f32_f16 { - #include "flash_attn_f32_f16.cl.h" - }; - #else - const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); - const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); - const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); - #endif - - if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { - const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { - { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, - {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, - {192, 192, 16, 16}, {256, 256, 16, 16}, - }; - - for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { - const int dk = fa_dims[i].dk; - const int dv = fa_dims[i].dv; - const int bm = fa_dims[i].bm; - const int bn = fa_dims[i].bn; - std::string OPTS = compile_opts + - " -D DK=" + std::to_string(dk) + - " -D DV=" + std::to_string(dv) + - " -D BLOCK_M=" + std::to_string(bm) + - " -D BLOCK_N=" + std::to_string(bn); - - cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); - cl_kernel k_f16, k_f16_q1; - CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); - CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; - backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; - CL_CHECK(clReleaseProgram(prog_f16)); - - cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); - cl_kernel k_f32, k_f32_q1; - CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); - CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; - backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; - CL_CHECK(clReleaseProgram(prog_f32)); - - cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); - cl_kernel k_f32_f16, k_f32_f16_q1; - CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); - CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; - backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; - CL_CHECK(clReleaseProgram(prog_f32_f16)); - - backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; - backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; - } - backend_ctx->kernels_loaded_flash_attn = true; - } - } -} - static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { if (backend_ctx->kernels_loaded) { return; @@ -1028,6 +1027,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { compile_opts += " -qcom-enable-large-buffer "; } + backend_ctx->kernel_compile_opts = compile_opts; + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); // add @@ -1189,6 +1190,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_dequant_q8_0_f16_view_aos = clCreateKernel(backend_ctx->program_cvt, "kernel_dequant_q8_0_f16_view_aos", &err), err)); + CL_CHECK((backend_ctx->kernel_dequant_q8_0_f32_view_aos = clCreateKernel(backend_ctx->program_cvt, "kernel_dequant_q8_0_f32_view_aos", &err), err)); + CL_CHECK((backend_ctx->kernel_dequant_q4_0_f16_view_aos = clCreateKernel(backend_ctx->program_cvt, "kernel_dequant_q4_0_f16_view_aos", &err), err)); + CL_CHECK((backend_ctx->kernel_dequant_q4_0_f32_view_aos = clCreateKernel(backend_ctx->program_cvt, "kernel_dequant_q4_0_f32_view_aos", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err)); @@ -2680,6 +2685,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_set_rows_f32_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i32", &err), err)); CL_CHECK((backend_ctx->kernel_set_rows_f16_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i64", &err), err)); CL_CHECK((backend_ctx->kernel_set_rows_f16_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i32", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_q8_0_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_q8_0_i64", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_q8_0_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_q8_0_i32", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_q8_0_soa_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_q8_0_soa_i64", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_q8_0_soa_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_q8_0_soa_i32", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_q4_0_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_q4_0_i64", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_q4_0_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_q4_0_i32", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_q4_0_soa_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_q4_0_soa_i64", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_q4_0_soa_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_q4_0_soa_i32", &err), err)); GGML_LOG_CONT("."); } @@ -3704,13 +3717,470 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { backend_ctx->kernels_loaded = true; } -// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { -// XXX static bool initialized = false; -// XXX static ggml_backend_opencl_context *backend_ctx = nullptr; - static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev); static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev); +// FA per-(dk,dv) tile tuning table + GGML_OPENCL_FA_TUNE override parsing. +#include "fa_tune.h" + +// FA variant key for the per-(dk,dv,variant) lazy compile cache. +// Kernel built on first dispatch to reduce kernel loading time. +// NB - a warmup run is recommended to get all necessary FA variants compiled +// before actual runs. +enum ggml_opencl_fa_variant { + FA_VARIANT_PRE = 0, // prepass kernels (kv_pad, mask_pad, blk) + FA_VARIANT_F16 = 1, + FA_VARIANT_F32 = 2, + FA_VARIANT_F32_F16 = 3, + FA_VARIANT_Q8_0 = 4, + FA_VARIANT_Q4_0 = 5, + FA_VARIANT_F32_F16_SPLIT = 6, + FA_VARIANT_Q8_0_SPLIT = 7, + FA_VARIANT_Q4_0_SPLIT = 8, +}; + +static std::string ggml_opencl_fa_kernel_src(ggml_opencl_fa_variant v) { +#ifdef GGML_OPENCL_EMBED_KERNELS + switch (v) { + case FA_VARIANT_F16: + return std::string{ + #include "flash_attn_f16.cl.h" + }; + case FA_VARIANT_F32: + return std::string{ + #include "flash_attn_f32.cl.h" + }; + case FA_VARIANT_F32_F16: + case FA_VARIANT_F32_F16_SPLIT: + return std::string{ + #include "flash_attn_f32_f16.cl.h" + }; + case FA_VARIANT_PRE: + return std::string{ + #include "flash_attn_pre_f16.cl.h" + }; + case FA_VARIANT_Q8_0: + case FA_VARIANT_Q8_0_SPLIT: + return std::string{ + #include "flash_attn_f32_q8_0.cl.h" + }; + case FA_VARIANT_Q4_0: + case FA_VARIANT_Q4_0_SPLIT: + return std::string{ + #include "flash_attn_f32_q4_0.cl.h" + }; + } + return {}; +#else + switch (v) { + case FA_VARIANT_F16: return read_file("flash_attn_f16.cl"); + case FA_VARIANT_F32: return read_file("flash_attn_f32.cl"); + case FA_VARIANT_F32_F16: + case FA_VARIANT_F32_F16_SPLIT: return read_file("flash_attn_f32_f16.cl"); + case FA_VARIANT_PRE: return read_file("flash_attn_pre_f16.cl"); + case FA_VARIANT_Q8_0: + case FA_VARIANT_Q8_0_SPLIT: return read_file("flash_attn_f32_q8_0.cl"); + case FA_VARIANT_Q4_0: + case FA_VARIANT_Q4_0_SPLIT: return read_file("flash_attn_f32_q4_0.cl"); + } + return {}; +#endif +} + +static std::string ggml_opencl_fa_compile_opts(ggml_backend_opencl_context * backend_ctx, + const ggml_opencl_fa_dim * cfg, + ggml_opencl_fa_variant variant) { + std::string opts = backend_ctx->kernel_compile_opts + + " -D DK=" + std::to_string(cfg->dk) + + " -D DV=" + std::to_string(cfg->dv) + + " -D BLOCK_M=" + std::to_string(cfg->bm) + + " -D BLOCK_N=" + std::to_string(cfg->bn); + + const bool is_split = variant == FA_VARIANT_F32_F16_SPLIT || + variant == FA_VARIANT_Q8_0_SPLIT || + variant == FA_VARIANT_Q4_0_SPLIT; + if (is_split) { + opts += " -D N_SPLIT=" + std::to_string(cfg->n_split); + if (backend_ctx->has_subgroup_shuffle) { + opts += backend_ctx->has_qcom_subgroup_shuffle + ? " -D cl_qcom_subgroup_shuffle=1" + : " -D cl_khr_subgroup_shuffle=1"; + } + } + return opts; +} + +// Log private memory for an FA kernel. Enable via `GGML_OPENCL_FA_LOG_SPILL=1`. +// On Adreno non-zero private_mem means spilling to global memory due to resource +// constraint and usually causes performance degradation. +// (per-work-item, no cache locality) — a strong signal to pick a config +// with smaller per-thread state (e.g. larger N_SPLIT). +static void ggml_opencl_log_fa_kernel_spill(ggml_backend_opencl_context * backend_ctx, + cl_kernel kernel, const char * name, int dk, int dv) { + static const bool enabled = []{ + const char * e = std::getenv("GGML_OPENCL_FA_LOG_SPILL"); + return e && e[0] && e[0] != '0'; + }(); + + if (!enabled || kernel == nullptr) { + return; + } + + cl_ulong priv_mem = 0; + if (clGetKernelWorkGroupInfo(kernel, backend_ctx->device, CL_KERNEL_PRIVATE_MEM_SIZE, + sizeof(priv_mem), &priv_mem, NULL) == CL_SUCCESS) { + const char * tag = priv_mem > 0 ? "SPILL" : "ok"; + GGML_LOG_INFO("ggml_opencl: [%s] %s DK=%d DV=%d private_mem=%llu bytes\n", + tag, name, dk, dv, (unsigned long long) priv_mem); + } +} + +static void ggml_opencl_ensure_fa_pre_kernels(ggml_backend_opencl_context * backend_ctx, int dk, int dv) { + const std::pair dk_dv = {dk, dv}; + if (backend_ctx->fa.kv_pad_f16.count(dk_dv) > 0) { + return; + } + + const ggml_opencl_fa_dim * cfg = nullptr; + for (const auto & d : g_opencl_fa_dims) { + if (d.dk == dk && d.dv == dv) { + cfg = &d; break; + } + } + + if (cfg == nullptr) { + GGML_ABORT("ggml_opencl: no flash_attn config for DK=%d DV=%d", dk, dv); + } + + GGML_LOG_INFO("ggml_opencl: lazy-compiling flash_attn prepass for DK=%d DV=%d\n", dk, dv); + + cl_int err; + const std::string src = ggml_opencl_fa_kernel_src(FA_VARIANT_PRE); + const std::string opts = ggml_opencl_fa_compile_opts(backend_ctx, cfg, FA_VARIANT_PRE); + + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, src.c_str(), opts); + + cl_kernel k_kv_pad_f16, k_mask_pad_f16, k_blk_f16; + CL_CHECK((k_kv_pad_f16 = clCreateKernel(prog, "flash_attn_kv_pad_f16", &err), err)); + CL_CHECK((k_mask_pad_f16 = clCreateKernel(prog, "flash_attn_mask_pad_f16", &err), err)); + CL_CHECK((k_blk_f16 = clCreateKernel(prog, "flash_attn_blk_f16", &err), err)); + backend_ctx->fa.kv_pad_f16[{dk, dv}] = k_kv_pad_f16; + backend_ctx->fa.mask_pad_f16[{dk, dv}] = k_mask_pad_f16; + backend_ctx->fa.blk_f16[{dk, dv}] = k_blk_f16; + CL_CHECK(clReleaseProgram(prog)); + + backend_ctx->fa.f32_f16_bm[{dk, dv}] = cfg->bm; + backend_ctx->fa.f32_f16_bn[{dk, dv}] = cfg->bn; + backend_ctx->fa.f32_f16_wg_size[{dk, dv}] = cfg->bm; + backend_ctx->fa.bm[{dk, dv}] = cfg->bm; + backend_ctx->fa.bn[{dk, dv}] = cfg->bn; +} + +// Compile one (variant, dk, dv); memoised. false = compiler rejected. +static bool ggml_opencl_ensure_fa_variant(ggml_backend_opencl_context * backend_ctx, int dk, int dv, ggml_opencl_fa_variant variant) { + const std::pair dk_dv = {dk, dv}; + + const ggml_opencl_fa_dim * cfg = nullptr; + for (const auto & d : g_opencl_fa_dims) { + if (d.dk == dk && d.dv == dv) { + cfg = &d; break; + } + } + if (cfg == nullptr) { + return false; + } + + // if a variant has already been compiled + switch (variant) { + case FA_VARIANT_F16: { + if (backend_ctx->fa.f16.count(dk_dv)) { + return true; + } + break; + } + case FA_VARIANT_F32: { + if (backend_ctx->fa.f32.count(dk_dv)) { + return true; + } + break; + } + case FA_VARIANT_F32_F16: { + if (backend_ctx->fa.f32_f16.count(dk_dv)) { + return true; + } + break; + } + case FA_VARIANT_Q8_0: { + if (backend_ctx->fa.f32_q8_0.count(dk_dv)) { + return true; + } + break; + } + case FA_VARIANT_Q4_0: { + if (backend_ctx->fa.f32_q4_0.count(dk_dv)) { + return true; + } + break; + } + case FA_VARIANT_F32_F16_SPLIT: { + if (backend_ctx->fa.f32_f16_split.count(dk_dv)) { + return true; + } + break; + } + case FA_VARIANT_Q8_0_SPLIT: { + if (backend_ctx->fa.f32_q8_0_split.count(dk_dv)) { + return true; + } + break; + } + case FA_VARIANT_Q4_0_SPLIT: { + if (backend_ctx->fa.f32_q4_0_split.count(dk_dv)) { + return true; + } + break; + } + case FA_VARIANT_PRE: { + ggml_opencl_ensure_fa_pre_kernels(backend_ctx, dk, dv); + return true; + } + } + + // not registered but attempted - meaning these kernels failed to compile + const auto attempt_key = std::make_pair(variant, dk_dv); + if (backend_ctx->fa.variant_attempted.count(attempt_key)) { + return false; + } + backend_ctx->fa.variant_attempted.insert(attempt_key); + + const bool is_split = variant == FA_VARIANT_F32_F16_SPLIT || + variant == FA_VARIANT_Q8_0_SPLIT || + variant == FA_VARIANT_Q4_0_SPLIT; + const bool is_quant = variant == FA_VARIANT_Q8_0 || variant == FA_VARIANT_Q8_0_SPLIT || + variant == FA_VARIANT_Q4_0 || variant == FA_VARIANT_Q4_0_SPLIT; + if (is_quant && (dk % 32 != 0 || dv % 32 != 0)) { + return false; + } + if (is_split && cfg->n_split <= 1) { + return false; + } + if ((variant == FA_VARIANT_Q8_0_SPLIT || variant == FA_VARIANT_Q4_0_SPLIT) && + ((dk / 32) % cfg->n_split != 0 || (dv / 4) % cfg->n_split != 0)) { + return false; + } + + const std::string src = ggml_opencl_fa_kernel_src(variant); + + if (src.empty()) { + return false; + } + const std::string opts = ggml_opencl_fa_compile_opts(backend_ctx, cfg, variant); + + const char * tag = nullptr; + switch (variant) { + case FA_VARIANT_F16: tag = "fa f16"; break; + case FA_VARIANT_F32: tag = "fa f32"; break; + case FA_VARIANT_F32_F16: tag = "fa f32_f16"; break; + case FA_VARIANT_Q8_0: tag = "fa q8_0"; break; + case FA_VARIANT_Q4_0: tag = "fa q4_0"; break; + case FA_VARIANT_F32_F16_SPLIT: tag = "fa f32_f16 split"; break; + case FA_VARIANT_Q8_0_SPLIT: tag = "fa q8_0 split"; break; + case FA_VARIANT_Q4_0_SPLIT: tag = "fa q4_0 split"; break; + default: break; + } + cl_program prog = build_program_from_source_ex( + backend_ctx->context, backend_ctx->device, src.c_str(), opts, /*fatal=*/false, tag); + + if (!prog) { + return false; + } + + cl_int err; + switch (variant) { + case FA_VARIANT_F16: { + cl_kernel k, kq1; + CL_CHECK((k = clCreateKernel(prog, "flash_attn_f16", &err), err)); + CL_CHECK((kq1 = clCreateKernel(prog, "flash_attn_f16_q1", &err), err)); + backend_ctx->fa.f16[{dk, dv}] = k; + backend_ctx->fa.f16_q1[{dk, dv}] = kq1; + break; + } + case FA_VARIANT_F32: { + cl_kernel k, kq1; + CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32", &err), err)); + CL_CHECK((kq1 = clCreateKernel(prog, "flash_attn_f32_q1", &err), err)); + backend_ctx->fa.f32[{dk, dv}] = k; + backend_ctx->fa.f32_q1[{dk, dv}] = kq1; + break; + } + case FA_VARIANT_F32_F16: { + cl_kernel k, kq1; + CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32_f16", &err), err)); + CL_CHECK((kq1 = clCreateKernel(prog, "flash_attn_f32_f16_q1", &err), err)); + backend_ctx->fa.f32_f16[{dk, dv}] = k; + backend_ctx->fa.f32_f16_q1[{dk, dv}] = kq1; + ggml_opencl_log_fa_kernel_spill(backend_ctx, k, "flash_attn_f32_f16", dk, dv); + ggml_opencl_log_fa_kernel_spill(backend_ctx, kq1, "flash_attn_f32_f16_q1", dk, dv); + cl_kernel k_split = clCreateKernel(prog, "flash_attn_f32_f16_q1_split", &err); + if (err == CL_SUCCESS) { + backend_ctx->fa.f32_f16_q1_split[{dk, dv}] = k_split; + ggml_opencl_log_fa_kernel_spill(backend_ctx, k_split, "flash_attn_f32_f16_q1_split", dk, dv); + } + cl_kernel k_merge = clCreateKernel(prog, "flash_attn_f32_merge", &err); + if (err == CL_SUCCESS) { + backend_ctx->fa.f32_merge[{dk, dv}] = k_merge; + } + break; + } + case FA_VARIANT_Q8_0: + case FA_VARIANT_Q4_0: { + const bool is_q8 = variant == FA_VARIANT_Q8_0; + const std::string base = is_q8 ? "flash_attn_f32_q8_0" : "flash_attn_f32_q4_0"; + const std::string name_q1 = base + "_q1"; + const std::string name_q1_split = base + "_q1_split"; + auto & m_q1 = is_q8 ? backend_ctx->fa.f32_q8_0_q1 : backend_ctx->fa.f32_q4_0_q1; + auto & m_prefill = is_q8 ? backend_ctx->fa.f32_q8_0 : backend_ctx->fa.f32_q4_0; + auto & m_q1_split = is_q8 ? backend_ctx->fa.f32_q8_0_q1_split : backend_ctx->fa.f32_q4_0_q1_split; + + cl_kernel k, kq1; + CL_CHECK((kq1 = clCreateKernel(prog, name_q1.c_str(), &err), err)); + CL_CHECK((k = clCreateKernel(prog, base.c_str(), &err), err)); + m_q1[{dk, dv}] = kq1; + m_prefill[{dk, dv}] = k; + ggml_opencl_log_fa_kernel_spill(backend_ctx, kq1, name_q1.c_str(), dk, dv); + ggml_opencl_log_fa_kernel_spill(backend_ctx, k, base.c_str(), dk, dv); + cl_kernel k_split = clCreateKernel(prog, name_q1_split.c_str(), &err); + if (err == CL_SUCCESS) { + m_q1_split[{dk, dv}] = k_split; + ggml_opencl_log_fa_kernel_spill(backend_ctx, k_split, name_q1_split.c_str(), dk, dv); + } + if (!backend_ctx->fa.f32_merge.count({dk, dv})) { + cl_kernel k_merge = clCreateKernel(prog, "flash_attn_f32_merge", &err); + if (err == CL_SUCCESS) { + backend_ctx->fa.f32_merge[{dk, dv}] = k_merge; + } + } + break; + } + case FA_VARIANT_F32_F16_SPLIT: { + cl_kernel k; + CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32_f16", &err), err)); + backend_ctx->fa.f32_f16_split[{dk, dv}] = k; + backend_ctx->fa.f32_f16_split_wg_size[{dk, dv}] = cfg->bm * cfg->n_split; + backend_ctx->fa.f32_f16_split_nkv_threshold[{dk, dv}] = cfg->nkv_split_threshold; + break; + } + case FA_VARIANT_Q8_0_SPLIT: + case FA_VARIANT_Q4_0_SPLIT: { + const bool is_q8 = variant == FA_VARIANT_Q8_0_SPLIT; + cl_kernel k; + CL_CHECK((k = clCreateKernel(prog, is_q8 ? "flash_attn_f32_q8_0" : "flash_attn_f32_q4_0", &err), err)); + auto & split = is_q8 ? backend_ctx->fa.f32_q8_0_split : backend_ctx->fa.f32_q4_0_split; + auto & split_wg = is_q8 ? backend_ctx->fa.f32_q8_0_split_wg_size : backend_ctx->fa.f32_q4_0_split_wg_size; + auto & split_bm = is_q8 ? backend_ctx->fa.f32_q8_0_split_bm : backend_ctx->fa.f32_q4_0_split_bm; + auto & split_thresh = is_q8 ? backend_ctx->fa.f32_q8_0_split_nkv_threshold : backend_ctx->fa.f32_q4_0_split_nkv_threshold; + split[{dk, dv}] = k; + split_wg[{dk, dv}] = cfg->bm * cfg->n_split; + split_bm[{dk, dv}] = cfg->bm; + split_thresh[{dk, dv}] = 0; // quant prefill: always split + break; + } + default: + break; + } + CL_CHECK(clReleaseProgram(prog)); + return true; +} + +// Compile a quant FA split kernel with a hand-picked (BLOCK_M, N_SPLIT) that +// overrides the default fa_dims tuning, for the DK values where the default +// N_SPLIT is degenerate for quant prefill: +// DK=256: default N_SPLIT=16 leaves DK/32=8 blocks -> 0 blocks/split. +// Override N_SPLIT=8 (1 block/split), BLOCK_M=16. +// DK=96 : DK/32 = 3 blocks, not divisible by the default N_SPLIT=2 -> +// override N_SPLIT=3. BLOCK_M must be 16, not 32: the N_SPLIT=3 +// QK-partial reduction uses sub_group_shuffle, so all 3 split +// threads of a query must land in one subgroup — WG_SIZE = +// BLOCK_M*N_SPLIT must be <= the 64-lane Adreno subgroup (16*3=48). +static bool ggml_opencl_ensure_fa_quant_split_override( + ggml_backend_opencl_context * backend_ctx, + int dk, int dv, int quant_bm, int quant_n_split, bool is_q8_0 +) { + const std::pair dk_dv = {dk, dv}; + if (is_q8_0 && backend_ctx->fa.f32_q8_0_split.count(dk_dv)) { + return true; + } + if (!is_q8_0 && backend_ctx->fa.f32_q4_0_split.count(dk_dv)) { + return true; + } + + const ggml_opencl_fa_variant variant = is_q8_0 ? FA_VARIANT_Q8_0_SPLIT : FA_VARIANT_Q4_0_SPLIT; + const auto attempt_key = std::make_pair(variant, dk_dv); + if (backend_ctx->fa.variant_attempted.count(attempt_key)) { + return false; + } + + backend_ctx->fa.variant_attempted.insert(attempt_key); + + std::string shuffle_opts; + if (backend_ctx->has_subgroup_shuffle) { + shuffle_opts = backend_ctx->has_qcom_subgroup_shuffle + ? " -D cl_qcom_subgroup_shuffle=1" + : " -D cl_khr_subgroup_shuffle=1"; + } + const ggml_opencl_fa_dim * cfg = nullptr; + for (const auto & d : g_opencl_fa_dims) { + if (d.dk == dk && d.dv == dv) { + cfg = &d; break; + } + } + if (cfg == nullptr) { + return false; + } + + // BLK_PREPASS_BM is the prepass-kernel BLOCK_M, needed so the quant kernel + // indexes the blk[] classification buffer correctly. + std::string opts = backend_ctx->kernel_compile_opts + shuffle_opts + + " -D DK=" + std::to_string(dk) + + " -D DV=" + std::to_string(dv) + + " -D BLOCK_M=" + std::to_string(quant_bm) + + " -D BLOCK_N=" + std::to_string(cfg->bn) + + " -D N_SPLIT=" + std::to_string(quant_n_split) + + " -D BLK_PREPASS_BM=" + std::to_string(cfg->bm); + + const std::string src = ggml_opencl_fa_kernel_src(variant); + if (src.empty()) { + return false; + } + + const std::string tag = std::string("fa ") + (is_q8_0 ? "q8_0" : "q4_0") + + " split DK=" + std::to_string(dk); + cl_program prog = build_program_from_source_ex( + backend_ctx->context, backend_ctx->device, src.c_str(), opts, /*fatal=*/false, tag.c_str()); + + if (!prog) { + return false; + } + + cl_int err; + cl_kernel k; + if (is_q8_0) { + CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32_q8_0", &err), err)); + backend_ctx->fa.f32_q8_0_split[dk_dv] = k; + backend_ctx->fa.f32_q8_0_split_wg_size[dk_dv] = quant_bm * quant_n_split; + backend_ctx->fa.f32_q8_0_split_bm[dk_dv] = quant_bm; + backend_ctx->fa.f32_q8_0_split_nkv_threshold[dk_dv] = 0; + } else { + CL_CHECK((k = clCreateKernel(prog, "flash_attn_f32_q4_0", &err), err)); + backend_ctx->fa.f32_q4_0_split[dk_dv] = k; + backend_ctx->fa.f32_q4_0_split_wg_size[dk_dv] = quant_bm * quant_n_split; + backend_ctx->fa.f32_q4_0_split_bm[dk_dv] = quant_bm; + backend_ctx->fa.f32_q4_0_split_nkv_threshold[dk_dv] = 0; + } + CL_CHECK(clReleaseProgram(prog)); + return true; +} + namespace /* anonymous */ { extern struct ggml_backend_device_i ggml_backend_opencl_device_i; } @@ -3955,6 +4425,8 @@ static void ggml_opencl_print_backend_info(ggml_backend_opencl_device_context * backend_ctx->driver_version.c_str()); GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: subgroup shuffle support: %s\n", + backend_ctx->has_subgroup_shuffle ? "true" : "false"); GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", @@ -4111,6 +4583,8 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { backend_ctx->gpu_family = dev_ctx->gpu_family; backend_ctx->adreno_gen = dev_ctx->adreno_gen; if (backend_ctx->gpu_family == GPU_FAMILY::ADRENO) { + ggml_cl_init_fa_dims_table(); + // Use wave size of 64 for all Adreno GPUs. backend_ctx->adreno_wave_size = 64; } @@ -4156,6 +4630,11 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { // check Adreno large buffer support backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; + // subgroup shuffle support (N_SPLIT>1 FA kernel) + backend_ctx->has_qcom_subgroup_shuffle = strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL; + backend_ctx->has_subgroup_shuffle = + strstr(ext_buffer, "cl_khr_subgroup_shuffle") != NULL || + backend_ctx->has_qcom_subgroup_shuffle; cl_uint base_align_in_bits; CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); @@ -5100,6 +5579,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te switch (op->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: return (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); default: return false; @@ -5175,9 +5656,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_EXP: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + // Adreno F16 exp/expm1 overflow even post-half->float convert. + return op->src[0]->type == GGML_TYPE_F32; case GGML_UNARY_OP_EXPM1: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32; case GGML_UNARY_OP_SOFTPLUS: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; default: @@ -5250,7 +5732,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return true; } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; - } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + } else if (op->src[0]->type == GGML_TYPE_Q4_0) { + // Non-contig src0 routes through on-device dequant-to-f16. + return op->src[1]->type == GGML_TYPE_F32; + } else if (op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || op->src[0]->type == GGML_TYPE_Q5_1 || op->src[0]->type == GGML_TYPE_MXFP4 || op->src[0]->type == GGML_TYPE_IQ4_NL || @@ -5339,43 +5824,55 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_FLASH_ATTN_EXT: - { - load_cl_kernels_flash_attn(backend_ctx); - - const ggml_tensor * q = op->src[0]; - const ggml_tensor * k = op->src[1]; - const ggml_tensor * v = op->src[2]; - - const int dk = q->ne[0]; - const int dv = v->ne[0]; - - const struct { int dk; int dv; } supported_dims[] = { - { 40, 40}, { 64, 64}, { 80, 80}, { 96, 96}, - {112, 112}, {128, 128}, {192, 128}, - {192, 192}, {256, 256}, - }; + case GGML_OP_FLASH_ATTN_EXT: { + const ggml_tensor * q = op->src[0]; + const ggml_tensor * k = op->src[1]; + const ggml_tensor * v = op->src[2]; + + const int dk = q->ne[0]; + const int dv = v->ne[0]; + + const struct { int dk; int dv; } supported_dims[] = { + { 40, 40}, { 64, 64}, { 80, 80}, { 96, 96}, + {112, 112}, {128, 128}, {192, 128}, + {192, 192}, {256, 256}, + }; - bool dims_supported = false; - for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) { - if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) { - dims_supported = true; - break; - } - } - if (!dims_supported) { - return false; + bool dims_supported = false; + for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) { + if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) { + dims_supported = true; + break; } + } + if (!dims_supported) { + return false; + } - const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 && - v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 && - v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16; - const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && - v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32; + const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 && + v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 && + v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16; + const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && + v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32; + const bool is_f32_q8_0 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_Q8_0 && + v->type == GGML_TYPE_Q8_0 && op->type == GGML_TYPE_F32 && + dk % 32 == 0 && dv % 32 == 0; + const bool is_f32_q4_0 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_Q4_0 && + v->type == GGML_TYPE_Q4_0 && op->type == GGML_TYPE_F32 && + dk % 32 == 0 && dv % 32 == 0; + + // Asymmetric KV: host-dequants both sides to F32, uses f32 kernel. + auto is_kv_type_ok = [](ggml_type t) { + return t == GGML_TYPE_F16 || t == GGML_TYPE_F32 || + t == GGML_TYPE_Q4_0 || t == GGML_TYPE_Q8_0; + }; + const bool is_f32_asym = q->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && + k->type != v->type && + is_kv_type_ok(k->type) && is_kv_type_ok(v->type); - return is_f32_f32 || is_f16_f16 || is_f32_f16; - } + return is_f32_f32 || is_f16_f16 || is_f32_f16 || is_f32_q8_0 || is_f32_q4_0 || is_f32_asym; + } default: return false; } @@ -5737,6 +6234,9 @@ struct ggml_backend_opencl_buffer_context { temp_tensor_extras_q6_K.push_back(e); } temp_tensor_extras_q6_K_in_use.clear(); + + q8_0_soa_tensors.clear(); + q4_0_soa_tensors.clear(); } // Pools for extras. Available extras are in `temp_tensor_extras`. Extras @@ -5767,6 +6267,17 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_q6_K; std::vector temp_tensor_extras_q6_K_in_use; + // q8_0 tensors with AoS->SoA layout conversion installed by set_tensor. + // Two types of tensors get SOA'ed - normal weights and MoE weights. + // In Q8_0's case, we only have normal weights. If we ever have Q8_0 as MoE + // weights, they need to be added to this set in `set_tensors`. + std::unordered_set q8_0_soa_tensors; + + // Same for q4_0. KV-cache q4_0 tensors are allocated but never pass + // through set_tensor, so they stay AoS and aren't in this set. + // In Q4_0's case, in addition to normal weights, we have MoE weights. + std::unordered_set q4_0_soa_tensors; + // The buffer_context is initially created by ggml_backend_buft_alloc_buffer // before any tensor is initialized (at the beginning of alloc_tensor_range). // Hence, there is always a buffer object in this vector. When each tensor is @@ -5848,6 +6359,10 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // buffers for quantized bits and scales, which are then populated by the // conversion kernel. if (tensor->type == GGML_TYPE_Q4_0) { + // Views can't SoA-ify here — parent owns the layout (see q8_0 guard). + if (tensor->view_src != nullptr || !ggml_is_contiguous(tensor)) { + return; + } // Tensors should have been preallocated, therefore they should // already have ggml_tensor_extra_cl as extra. ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; @@ -5937,6 +6452,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, }; extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); tensor->extra = extra; + // MoE tensors are also SOA'ed + ctx->q4_0_soa_tensors.insert(tensor); return; } @@ -5965,6 +6482,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; + ctx->q4_0_soa_tensors.insert(tensor); // transpose the weights and scales #ifdef GGML_OPENCL_USE_ADRENO_KERNELS @@ -6516,6 +7034,11 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } if (tensor->type == GGML_TYPE_Q8_0) { + // Views share the parent's buffer; parent owns SoA conversion. + if (tensor->view_src != nullptr || !ggml_is_contiguous(tensor)) { + return; + } + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -6571,6 +7094,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; + ctx->q8_0_soa_tensors.insert(tensor); // Transpose the weights and scales #ifdef GGML_OPENCL_USE_ADRENO_KERNELS @@ -7226,7 +7750,18 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, // To properly support this, we need to restore block_q4_0 struct arrays // from the flattened buffers. if (tensor->type == GGML_TYPE_Q4_0) { - ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; + // KV-cache q4_0 stays AoS — direct readback, no SoA restore. + if (!ggml_cl_is_q4_0_soa(tensor)) { + ggml_tensor_extra_cl * extra_aos = (ggml_tensor_extra_cl *) tensor->extra; + CL_CHECK(clEnqueueReadBuffer( + queue, extra_aos->data_device, CL_TRUE, + extra_aos->offset + tensor->view_offs + offset, + size, data, 0, NULL, NULL)); + return; + } + // SoA extra lives on the parent tensor — follow view_src. + const ggml_tensor * extra_src = tensor->view_src != nullptr ? tensor->view_src : tensor; + ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)extra_src->extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, tensor)) { @@ -7697,7 +8232,18 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, return; } if (tensor->type == GGML_TYPE_Q8_0) { - ggml_tensor_extra_cl_q8_0 * extra = (ggml_tensor_extra_cl_q8_0 *)tensor->extra; + // KV-cache q8_0 stays AoS (see Q4_0 branch). + if (!ggml_cl_is_q8_0_soa(tensor)) { + ggml_tensor_extra_cl * extra_aos = (ggml_tensor_extra_cl *) tensor->extra; + CL_CHECK(clEnqueueReadBuffer( + queue, extra_aos->data_device, CL_TRUE, + extra_aos->offset + tensor->view_offs + offset, + size, data, 0, NULL, NULL)); + return; + } + // SoA extra lives on the parent — follow view_src. + const ggml_tensor * extra_src = tensor->view_src != nullptr ? tensor->view_src : tensor; + ggml_tensor_extra_cl_q8_0 * extra = (ggml_tensor_extra_cl_q8_0 *)extra_src->extra; cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, @@ -8821,6 +9367,34 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +// check if a Q8_0 tensor has been SOA'ed in set_tensor +// we store SOA'ed tensors in a map in set_tensor, check against that map +static bool ggml_cl_is_q8_0_soa(const ggml_tensor * tensor) { + if (tensor == nullptr || tensor->type != GGML_TYPE_Q8_0 || tensor->buffer == nullptr) { + return false; + } + auto * ctx = (ggml_backend_opencl_buffer_context *) tensor->buffer->context; + if (ctx == nullptr) { + return false; + } + const ggml_tensor * key = tensor->view_src != nullptr ? tensor->view_src : tensor; + return ctx->q8_0_soa_tensors.count(key) > 0; +} + +// check if a Q4_0 tensor has been SOA'ed in set_tensor +// we store SOA'ed tensors in a map in set_tensor, check against that map +static bool ggml_cl_is_q4_0_soa(const ggml_tensor * tensor) { + if (tensor == nullptr || tensor->type != GGML_TYPE_Q4_0 || tensor->buffer == nullptr) { + return false; + } + auto * ctx = (ggml_backend_opencl_buffer_context *) tensor->buffer->context; + if (ctx == nullptr) { + return false; + } + const ggml_tensor * key = tensor->view_src != nullptr ? tensor->view_src : tensor; + return ctx->q4_0_soa_tensors.count(key) > 0; +} + static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -8834,26 +9408,14 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c // ne2 = ne02 // ne3 = ne03 - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; - - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - - const cl_ulong nb10 = src1->nb[0]; - const cl_ulong nb11 = src1->nb[1]; - const cl_ulong nb12 = src1->nb[2]; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); - const int ne0 = dst->ne[0]; + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); - const cl_ulong nb1 = dst->nb[1]; - const cl_ulong nb2 = dst->nb[2]; - const cl_ulong nb3 = dst->nb[3]; + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); const int nblk0 = ne0/ggml_blck_size(dst->type); @@ -8861,31 +9423,49 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offset1 = extra1->offset + src1->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; + + const bool q8_0_soa = dst->type == GGML_TYPE_Q8_0 && ggml_cl_is_q8_0_soa(dst); + const bool q4_0_soa = dst->type == GGML_TYPE_Q4_0 && ggml_cl_is_q4_0_soa(dst); + const bool is_soa = q8_0_soa || q4_0_soa; cl_kernel kernel; - switch (dst->type) { - case GGML_TYPE_F32: - if (src1->type == GGML_TYPE_I64) { - kernel = backend_ctx->kernel_set_rows_f32_i64; - } else { - kernel = backend_ctx->kernel_set_rows_f32_i32; - } - break; - case GGML_TYPE_F16: - if (src1->type == GGML_TYPE_I64) { - kernel = backend_ctx->kernel_set_rows_f16_i64; - } else { - kernel = backend_ctx->kernel_set_rows_f16_i32; - } - break; - default: - GGML_ABORT("not implemented"); + if (q8_0_soa) { + kernel = (src1->type == GGML_TYPE_I64) + ? backend_ctx->kernel_set_rows_q8_0_soa_i64 + : backend_ctx->kernel_set_rows_q8_0_soa_i32; + } else if (q4_0_soa) { + kernel = (src1->type == GGML_TYPE_I64) + ? backend_ctx->kernel_set_rows_q4_0_soa_i64 + : backend_ctx->kernel_set_rows_q4_0_soa_i32; + } else { + switch (dst->type) { + case GGML_TYPE_F32: + kernel = (src1->type == GGML_TYPE_I64) + ? backend_ctx->kernel_set_rows_f32_i64 + : backend_ctx->kernel_set_rows_f32_i32; + break; + case GGML_TYPE_F16: + kernel = (src1->type == GGML_TYPE_I64) + ? backend_ctx->kernel_set_rows_f16_i64 + : backend_ctx->kernel_set_rows_f16_i32; + break; + case GGML_TYPE_Q8_0: + kernel = (src1->type == GGML_TYPE_I64) + ? backend_ctx->kernel_set_rows_q8_0_i64 + : backend_ctx->kernel_set_rows_q8_0_i32; + break; + case GGML_TYPE_Q4_0: + kernel = (src1->type == GGML_TYPE_I64) + ? backend_ctx->kernel_set_rows_q4_0_i64 + : backend_ctx->kernel_set_rows_q4_0_i32; + break; + default: + GGML_ABORT("not implemented"); + } } fastdiv_vals ne11_ = init_fastdiv_values(ne11); @@ -8895,21 +9475,65 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &nblk0)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb3)); + + if (is_soa) { + // The q/d subbuffers in q8_0/q4_0 extras are interchangeable here. + // For views (e.g. ggml_set_rows' `out`), follow view_src for the SoA extra. + const ggml_tensor * soa_src = dst->view_src != nullptr ? dst->view_src : dst; + cl_mem q_mem = nullptr; + cl_mem d_mem = nullptr; + if (q8_0_soa) { + ggml_tensor_extra_cl_q8_0 * e = (ggml_tensor_extra_cl_q8_0 *)soa_src->extra; + q_mem = e->q; + d_mem = e->d; + } else { + ggml_tensor_extra_cl_q4_0 * e = (ggml_tensor_extra_cl_q4_0 *)soa_src->extra; + q_mem = e->q; + d_mem = e->d; + } + cl_ulong offset_q = 0; + cl_ulong offset_d = 0; + const int ne1_dst = dst->ne[1]; + const int ne2_dst = dst->ne[2]; + const int ne3_dst = dst->ne[3]; + + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &q_mem)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_q)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &d_mem)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_d)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(fastdiv_vals), &ne11_)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(fastdiv_vals), &ne12_)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &nblk0)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne1_dst)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne2_dst)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne3_dst)); + } else { + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &nblk0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb3)); + } int nth0 = 64; if (backend_ctx->gpu_family == INTEL) { @@ -11483,14 +12107,370 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); } +struct ggml_cl_flash_attn_temp_buffer { + cl_mem data = nullptr; + + ~ggml_cl_flash_attn_temp_buffer() { + if (data != nullptr) { + CL_CHECK(clReleaseMemObject(data)); + data = nullptr; + } + } +}; + +// Resolve the source buffer + strides for an FA KV tensor: keep the +// caller-supplied AoS buffer if non-NULL, else fall back to tensor->extra. +static void ggml_cl_flash_attn_resolve_src( + const ggml_tensor * tensor, + cl_mem & buf, + cl_ulong & offset, + cl_ulong & nb1, + cl_ulong & nb2, + cl_ulong & nb3) { + if (buf != NULL) { + return; + } + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra && extra->data_device); + buf = extra->data_device; + offset = extra->offset + tensor->view_offs; + nb1 = tensor->nb[1]; + nb2 = tensor->nb[2]; + nb3 = tensor->nb[3]; +} + +// Read a (possibly strided-view) tensor from device into a tight host buffer. +// dim 0 is always tight; a strided view is gathered row-by-row. +static void ggml_cl_flash_attn_read_tensor_host( + ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * tensor, + cl_mem src_buffer, cl_ulong src_offset, + cl_ulong src_nb1, cl_ulong src_nb2, cl_ulong src_nb3, + size_t row_bytes, void * dst, size_t total_bytes +) { + const bool contiguous_layout = + src_nb1 == row_bytes && + src_nb2 == row_bytes * (cl_ulong) tensor->ne[1] && + src_nb3 == src_nb2 * (cl_ulong) tensor->ne[2]; + + if (contiguous_layout) { + CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, src_buffer, CL_TRUE, + src_offset, total_bytes, dst, 0, NULL, NULL)); + return; + } + + size_t dst_off = 0; + for (int64_t i3 = 0; i3 < tensor->ne[3]; ++i3) { + for (int64_t i2 = 0; i2 < tensor->ne[2]; ++i2) { + for (int64_t i1 = 0; i1 < tensor->ne[1]; ++i1) { + const cl_ulong row_src_off = src_offset + + (cl_ulong) i3 * src_nb3 + + (cl_ulong) i2 * src_nb2 + + (cl_ulong) i1 * src_nb1; + CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, src_buffer, CL_TRUE, + row_src_off, row_bytes, + (uint8_t *) dst + dst_off, 0, NULL, NULL)); + dst_off += row_bytes; + } + } + } + GGML_ASSERT(dst_off == total_bytes); +} + +// Rebuild AoS q8_0/q4_0 bytes from a SoA tensor into a temp buffer. +// Returns false if the tensor is not SoA-quantised (already AoS). +static bool ggml_cl_flash_attn_reconstruct_aos( + ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * tensor, + ggml_cl_flash_attn_temp_buffer & temp, + cl_mem & out_buf, + cl_ulong & out_offset, + cl_ulong & out_nb1, + cl_ulong & out_nb2, + cl_ulong & out_nb3 +) { + if (tensor == nullptr) { + return false; + } + const bool is_q8_0 = tensor->type == GGML_TYPE_Q8_0 && ggml_cl_is_q8_0_soa(tensor); + const bool is_q4_0 = tensor->type == GGML_TYPE_Q4_0 && ggml_cl_is_q4_0_soa(tensor); + if (!is_q8_0 && !is_q4_0) { + return false; + } + + // For views, SoA extra is on view_src (view->extra is pre-SoA). + // Noshuffle layout only applies to 2D weights, as determined by `use_adreno_kernels`, + // where ne2 == 1 and ne3 == 1 -- these are never FA inputs. + // Therefore, we use `restore_block_qk_0` kernels, not `restore_block_qk_0_noshuffle`. + const ggml_tensor * soa_src = tensor->view_src ? tensor->view_src : tensor; + cl_mem extra_q = NULL; + cl_mem extra_d = NULL; + if (is_q8_0) { + auto * e = (ggml_tensor_extra_cl_q8_0 *) soa_src->extra; + GGML_ASSERT(e && e->q && e->d); + extra_q = e->q; + extra_d = e->d; + } else { + auto * e = (ggml_tensor_extra_cl_q4_0 *) soa_src->extra; + GGML_ASSERT(e && e->q && e->d); + extra_q = e->q; + extra_d = e->d; + } + + // Reconstruct the whole parent; view offsets then work naturally. + const size_t parent_nbytes = ggml_nbytes(soa_src); + cl_int err; + temp.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, parent_nbytes, NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = is_q8_0 ? backend_ctx->kernel_restore_block_q8_0 + : backend_ctx->kernel_restore_block_q4_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &temp.data)); + + const size_t n_blocks = (size_t) ggml_nelements(soa_src) / ggml_blck_size(soa_src->type); + size_t global_work_size[] = { n_blocks, 1, 1 }; + size_t local_work_size[] = { 1, 1, 1 }; + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + + out_buf = temp.data; + out_offset = tensor->view_offs; + out_nb1 = tensor->nb[1]; + out_nb2 = tensor->nb[2]; + out_nb3 = tensor->nb[3]; + return true; +} + +// GPU dequant of a contiguous q4_0/q8_0 KV tensor to f16/f32. Caller supplies +// src_buf when reconstructing from SoA. Returns false for non-contig layouts +// (the kernel indexes blocks tightly within ne[0]) so the caller can fall back +// to the host path. +static bool ggml_cl_flash_attn_dequant_kv_gpu( + ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * tensor, + ggml_type target_type, + cl_mem in_src_buf, + cl_ulong in_src_offset, + cl_ulong in_src_nb1, + cl_ulong in_src_nb2, + cl_ulong in_src_nb3, + ggml_cl_flash_attn_temp_buffer & temp, + cl_mem & out_buf, + cl_ulong & out_offset, + cl_ulong & out_nb1, + cl_ulong & out_nb2, + cl_ulong & out_nb3 +) { + GGML_ASSERT(tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_0); + GGML_ASSERT(target_type == GGML_TYPE_F16 || target_type == GGML_TYPE_F32); + + const bool is_q8_0 = tensor->type == GGML_TYPE_Q8_0; + + cl_mem src_buf = in_src_buf; + cl_ulong src_offset = in_src_offset; + cl_ulong src_nb1 = in_src_nb1; + cl_ulong src_nb2 = in_src_nb2; + cl_ulong src_nb3 = in_src_nb3; + ggml_cl_flash_attn_resolve_src(tensor, src_buf, src_offset, src_nb1, src_nb2, src_nb3); + + if (tensor->nb[0] != (cl_ulong) ggml_type_size(tensor->type)) { + return false; + } + + const size_t n_blocks = (size_t) ggml_nelements(tensor) / 32; // block size is 32 + const size_t elem_size = ggml_type_size(target_type); + const size_t out_bytes = n_blocks * 32 * elem_size; + const cl_int nblk0_arg = (cl_int) (tensor->ne[0] / 32); + const cl_int ne1_arg = (cl_int) tensor->ne[1]; + const cl_int ne2_arg = (cl_int) tensor->ne[2]; + const cl_int ne3_arg = (cl_int) tensor->ne[3]; + + cl_int err; + temp.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, out_bytes, NULL, &err); + CL_CHECK(err); + + cl_kernel kernel; + if (target_type == GGML_TYPE_F16) { + kernel = is_q8_0 ? backend_ctx->kernel_dequant_q8_0_f16_view_aos + : backend_ctx->kernel_dequant_q4_0_f16_view_aos; + } else { + kernel = is_q8_0 ? backend_ctx->kernel_dequant_q8_0_f32_view_aos + : backend_ctx->kernel_dequant_q4_0_f32_view_aos; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &src_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &src_offset)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &src_nb1)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &src_nb2)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &src_nb3)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &nblk0_arg)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne1_arg)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne2_arg)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne3_arg)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_mem), &temp.data)); + + size_t global_ws[3] = { (size_t) nblk0_arg, (size_t) ne1_arg, (size_t) ne2_arg * (size_t) ne3_arg }; + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, + global_ws, NULL, 0, NULL, NULL)); + + out_buf = temp.data; + out_offset = 0; + out_nb1 = (cl_ulong) tensor->ne[0] * elem_size; + out_nb2 = out_nb1 * (cl_ulong) tensor->ne[1]; + out_nb3 = out_nb2 * (cl_ulong) tensor->ne[2]; + return true; +} + +static bool ggml_cl_flash_attn_prepare_quantized_tensor( + ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * tensor, + ggml_type target_type, + ggml_cl_flash_attn_temp_buffer & temp, + cl_mem & data_device, + cl_ulong & offset, + cl_ulong & nb1, + cl_ulong & nb2, + cl_ulong & nb3 +) { + if (!ggml_is_quantized(tensor->type)) { + return false; + } + + // Caller-supplied AoS buffer wins over tensor->extra when present. + cl_mem src_buffer = data_device; + cl_ulong src_offset = offset; + cl_ulong src_nb1 = nb1; + cl_ulong src_nb2 = nb2; + cl_ulong src_nb3 = nb3; + ggml_cl_flash_attn_resolve_src(tensor, src_buffer, src_offset, src_nb1, src_nb2, src_nb3); + + const int64_t n = ggml_nelements(tensor); + const size_t row_bytes = (size_t) (tensor->ne[0] / ggml_blck_size(tensor->type)) * ggml_type_size(tensor->type); + // tight-packed byte count (ggml_nbytes includes stride gaps). + const size_t total_bytes = (size_t) (n / ggml_blck_size(tensor->type)) * ggml_type_size(tensor->type); + std::vector host_quant(total_bytes); + + sync_with_other_backends(backend_ctx); + ggml_cl_flash_attn_read_tensor_host(backend_ctx, tensor, src_buffer, src_offset, + src_nb1, src_nb2, src_nb3, + row_bytes, host_quant.data(), total_bytes); + + std::vector host_f32(n); + ggml_get_type_traits(tensor->type)->to_float(host_quant.data(), host_f32.data(), n); + + const size_t bytes_per_elem = ggml_type_size(target_type); + const size_t buffer_size = (size_t) n * bytes_per_elem; + + std::vector host_linear(buffer_size); + if (target_type == GGML_TYPE_F32) { + memcpy(host_linear.data(), host_f32.data(), buffer_size); + } else { + GGML_ASSERT(target_type == GGML_TYPE_F16); + ggml_fp32_to_fp16_row(host_f32.data(), (ggml_fp16_t *) host_linear.data(), n); + } + + cl_int err; + temp.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, buffer_size, NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer(backend_ctx->queue, temp.data, CL_TRUE, 0, buffer_size, host_linear.data(), 0, NULL, NULL)); + + data_device = temp.data; + offset = 0; + nb1 = (cl_ulong) (tensor->ne[0] * bytes_per_elem); + nb2 = (cl_ulong) (tensor->ne[1] * nb1); + nb3 = (cl_ulong) (tensor->ne[2] * nb2); + + static bool warned = false; + if (!warned) { + GGML_LOG_WARN("ggml_opencl: OpenCL flash attention dequantizes GPU-resident quantized KV cache into temporary linear buffers; performance may be poor\n"); + warned = true; + } + + return true; +} + +// Host-side F16 -> F32 for the asymmetric-KV F32 fallback path. +static bool ggml_cl_flash_attn_convert_f16_to_f32( + ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * tensor, + ggml_cl_flash_attn_temp_buffer & temp, + cl_mem & data_device, + cl_ulong & offset, + cl_ulong & nb1, + cl_ulong & nb2, + cl_ulong & nb3 +) { + if (tensor->type != GGML_TYPE_F16) { + return false; + } + + cl_mem src_buffer = data_device; + cl_ulong src_offset = offset; + cl_ulong src_nb1 = nb1; + cl_ulong src_nb2 = nb2; + cl_ulong src_nb3 = nb3; + ggml_cl_flash_attn_resolve_src(tensor, src_buffer, src_offset, src_nb1, src_nb2, src_nb3); + + const int64_t n = ggml_nelements(tensor); + const size_t row_bytes = (size_t) tensor->ne[0] * sizeof(ggml_fp16_t); + const size_t total_bytes = (size_t) n * sizeof(ggml_fp16_t); + std::vector host_f16(total_bytes); + + sync_with_other_backends(backend_ctx); + ggml_cl_flash_attn_read_tensor_host(backend_ctx, tensor, src_buffer, src_offset, + src_nb1, src_nb2, src_nb3, + row_bytes, host_f16.data(), total_bytes); + + std::vector host_f32(n); + ggml_fp16_to_fp32_row((const ggml_fp16_t *) host_f16.data(), host_f32.data(), n); + + const size_t f32_bytes = (size_t) n * sizeof(float); + cl_int err; + temp.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, f32_bytes, NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer(backend_ctx->queue, temp.data, CL_TRUE, 0, + f32_bytes, host_f32.data(), 0, NULL, NULL)); + + data_device = temp.data; + offset = 0; + nb1 = (cl_ulong) (tensor->ne[0] * sizeof(float)); + nb2 = (cl_ulong) (tensor->ne[1] * nb1); + nb3 = (cl_ulong) (tensor->ne[2] * nb2); + + static bool warned = false; + if (!warned) { + GGML_LOG_WARN("ggml_opencl: OpenCL flash attention asymmetric KV converts an F16 cache to F32 host-side; performance may be poor\n"); + warned = true; + } + + return true; +} + +// Flash-Decoding (K-split) dispatch thresholds. FD fires for non-causal +// attention with n_kv >= FD_MIN_N_KV and d_head <= FD_MAX_DK; the KV range is +// split into ~n_kv/FD_KV_PER_SPLIT partials, clamped to [FD_MIN_SPLITS, +// FD_MAX_SPLITS]. Multi-query FD is restricted to small heads +// (d_head <= FD_MAX_DK_MULTI) and capped at FD_MAX_N_Q_MULTI queries. +static constexpr int FD_MIN_N_KV = 2048; +static constexpr int FD_KV_PER_SPLIT = 2048; +static constexpr int FD_MIN_SPLITS = 2; +static constexpr int FD_MAX_SPLITS = 16; +static constexpr int FD_MAX_DK = 128; +static constexpr int FD_MAX_DK_MULTI = 64; +static constexpr int FD_MAX_N_Q_MULTI = 8; + static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) { const ggml_tensor * v = dst->src[2]; const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; + GGML_ASSERT(q->extra); GGML_ASSERT(k->extra); GGML_ASSERT(v->extra); GGML_ASSERT(dst->extra); + if (mask) { GGML_ASSERT(mask->extra); } @@ -11508,87 +12488,463 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co const int n_head_kv = k->ne[2]; const int n_batch = q->ne[3]; + // Per-variant lazy compile for this (dk, dv). + ggml_opencl_ensure_fa_pre_kernels(backend_ctx, d_head_q, d_head_v); + cl_kernel kernel = NULL; const bool is_f16 = q->type == GGML_TYPE_F16; - const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16; - const std::pair dk_dv = {d_head_q, d_head_v}; - - if (n_q == 1) { - if (is_mixed) { - kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv); - } else if (is_f16) { - kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv); + const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16; + const bool is_q8_0 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_Q8_0 && v->type == GGML_TYPE_Q8_0; + const bool is_q4_0 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_Q4_0 && v->type == GGML_TYPE_Q4_0; + + if (is_f16) { + ggml_opencl_ensure_fa_variant(backend_ctx, d_head_q, d_head_v, FA_VARIANT_F16); + } else if (is_mixed) { + ggml_opencl_ensure_fa_variant(backend_ctx, d_head_q, d_head_v, FA_VARIANT_F32_F16); + ggml_opencl_ensure_fa_variant(backend_ctx, d_head_q, d_head_v, FA_VARIANT_F32_F16_SPLIT); + } else if (is_q8_0) { + ggml_opencl_ensure_fa_variant(backend_ctx, d_head_q, d_head_v, FA_VARIANT_Q8_0); + if (d_head_q == 96 && d_head_v == 96) { + ggml_opencl_ensure_fa_quant_split_override(backend_ctx, 96, 96, /*quant_bm=*/16, /*quant_n_split=*/3, /*is_q8_0=*/true); + } else if (d_head_q == 256 && d_head_v == 256) { + ggml_opencl_ensure_fa_quant_split_override(backend_ctx, 256, 256, /*quant_bm=*/16, /*quant_n_split=*/8, /*is_q8_0=*/true); } else { - kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv); - } - } else { - if (is_mixed) { - kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv); - } else if (is_f16) { - kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv); + ggml_opencl_ensure_fa_variant(backend_ctx, d_head_q, d_head_v, FA_VARIANT_Q8_0_SPLIT); + } + } else if (is_q4_0) { + ggml_opencl_ensure_fa_variant(backend_ctx, d_head_q, d_head_v, FA_VARIANT_Q4_0); + if (d_head_q == 96 && d_head_v == 96) { + ggml_opencl_ensure_fa_quant_split_override(backend_ctx, 96, 96, /*quant_bm=*/16, /*quant_n_split=*/3, /*is_q8_0=*/false); + } else if (d_head_q == 256 && d_head_v == 256) { + ggml_opencl_ensure_fa_quant_split_override(backend_ctx, 256, 256, /*quant_bm=*/16, /*quant_n_split=*/8, /*is_q8_0=*/false); } else { - kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv); + ggml_opencl_ensure_fa_variant(backend_ctx, d_head_q, d_head_v, FA_VARIANT_Q4_0_SPLIT); } + } else { + ggml_opencl_ensure_fa_variant(backend_ctx, d_head_q, d_head_v, FA_VARIANT_F32); } - GGML_ASSERT(kernel != NULL); + + const std::pair dk_dv = {d_head_q, d_head_v}; + const bool use_native_q8_0_q1 = is_q8_0 && n_q == 1 && + backend_ctx->fa.f32_q8_0_q1.count(dk_dv) > 0; + // Native q8_0 prefill — reads q8_0 directly, wg_size = cfg->bm. + const bool use_native_q8_0 = is_q8_0 && n_q > 1 && + backend_ctx->fa.f32_q8_0.count(dk_dv) > 0; + const bool use_native_q4_0_q1 = is_q4_0 && n_q == 1 && + backend_ctx->fa.f32_q4_0_q1.count(dk_dv) > 0; + const bool use_native_q4_0 = is_q4_0 && n_q > 1 && + backend_ctx->fa.f32_q4_0.count(dk_dv) > 0; + const int block_m = n_q > 1 + ? (is_mixed ? backend_ctx->fa.f32_f16_bm.at(dk_dv) : backend_ctx->fa.bm.at(dk_dv)) + : 0; + const int block_n = is_mixed + ? backend_ctx->fa.f32_f16_bn.at(dk_dv) + : backend_ctx->fa.bn.at(dk_dv); + // Pick split variant only when n_kv crosses the per-(dk,dv) threshold. + const bool use_split_kernel = (n_q > 1 && is_mixed && + backend_ctx->fa.f32_f16_split.count(dk_dv) > 0 && + n_kv >= backend_ctx->fa.f32_f16_split_nkv_threshold.at(dk_dv)); + const bool use_split_q8_0 = (use_native_q8_0 && + backend_ctx->fa.f32_q8_0_split.count(dk_dv) > 0 && + n_kv >= backend_ctx->fa.f32_q8_0_split_nkv_threshold.at(dk_dv)); + const bool use_split_q4_0 = (use_native_q4_0 && + backend_ctx->fa.f32_q4_0_split.count(dk_dv) > 0 && + n_kv >= backend_ctx->fa.f32_q4_0_split_nkv_threshold.at(dk_dv)); + const int wg_size_fa = (n_q > 1 && is_mixed) + ? (use_split_kernel + ? backend_ctx->fa.f32_f16_split_wg_size.at(dk_dv) + : backend_ctx->fa.f32_f16_wg_size.at(dk_dv)) + : block_m; ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra; - ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra; - ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra; ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra; ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL; ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL; + // SoA q8_0/q4_0 K/V: data_device aliases the `q` subbuffer; reconstruct + // AoS into a temp buffer below. AoS tensors use extra_k/v->data_device. + const bool k_soa = ggml_cl_is_q8_0_soa(k) || ggml_cl_is_q4_0_soa(k); + const bool v_soa = ggml_cl_is_q8_0_soa(v) || ggml_cl_is_q4_0_soa(v); + ggml_tensor_extra_cl * extra_k = k_soa ? nullptr : (ggml_tensor_extra_cl *)k->extra; + ggml_tensor_extra_cl * extra_v = v_soa ? nullptr : (ggml_tensor_extra_cl *)v->extra; + cl_ulong offset_q = extra_q->offset + q->view_offs; - cl_ulong offset_k = extra_k->offset + k->view_offs; - cl_ulong offset_v = extra_v->offset + v->view_offs; + cl_ulong offset_k = k_soa ? 0 : extra_k->offset + k->view_offs; + cl_ulong offset_v = v_soa ? 0 : extra_v->offset + v->view_offs; cl_ulong offset_o = extra_o->offset + dst->view_offs; cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL; cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0; cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL; cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0; - const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3]; - const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3]; - const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3]; - const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3]; + const cl_ulong q_nb1 = q->nb[1]; + const cl_ulong q_nb2 = q->nb[2]; + const cl_ulong q_nb3 = q->nb[3]; + + cl_ulong k_nb1 = k->nb[1]; + cl_ulong k_nb2 = k->nb[2]; + cl_ulong k_nb3 = k->nb[3]; + + cl_ulong v_nb1 = v->nb[1]; + cl_ulong v_nb2 = v->nb[2]; + cl_ulong v_nb3 = v->nb[3]; + + const cl_ulong o_nb1 = dst->nb[1]; + const cl_ulong o_nb2 = dst->nb[2]; + const cl_ulong o_nb3 = dst->nb[3]; + const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0; const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0; const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0; const int mask_ne2 = mask ? mask->ne[2] : 0; const int mask_ne3 = mask ? mask->ne[3] : 0; - float scale, max_bias, logit_softcap; + float scale; + float max_bias; + float logit_softcap; + const float * params = (const float *)dst->op_params; scale = params[0]; max_bias = params[1]; logit_softcap = params[2]; + if (n_q == 1) { + if (use_native_q8_0_q1) { + kernel = backend_ctx->fa.f32_q8_0_q1.at(dk_dv); + } else if (use_native_q4_0_q1) { + kernel = backend_ctx->fa.f32_q4_0_q1.at(dk_dv); + } else if (is_mixed) { + kernel = backend_ctx->fa.f32_f16_q1.at(dk_dv); + } else if (is_f16) { + kernel = backend_ctx->fa.f16_q1.at(dk_dv); + } else { + kernel = backend_ctx->fa.f32_q1.at(dk_dv); + } + } else { + if (use_native_q8_0) { + kernel = use_split_q8_0 + ? backend_ctx->fa.f32_q8_0_split.at(dk_dv) + : backend_ctx->fa.f32_q8_0.at(dk_dv); + } else if (use_native_q4_0) { + kernel = use_split_q4_0 + ? backend_ctx->fa.f32_q4_0_split.at(dk_dv) + : backend_ctx->fa.f32_q4_0.at(dk_dv); + } else if (is_mixed) { + kernel = use_split_kernel + ? backend_ctx->fa.f32_f16_split.at(dk_dv) + : backend_ctx->fa.f32_f16.at(dk_dv); + } else if (is_f16) { + kernel = backend_ctx->fa.f16.at(dk_dv); + } else { + kernel = backend_ctx->fa.f32.at(dk_dv); + } + } + GGML_ASSERT(kernel != NULL); + + ggml_cl_flash_attn_temp_buffer temp_k; + ggml_cl_flash_attn_temp_buffer temp_v; + ggml_cl_flash_attn_temp_buffer temp_k_pad; + ggml_cl_flash_attn_temp_buffer temp_v_pad; + ggml_cl_flash_attn_temp_buffer temp_mask_pad; + ggml_cl_flash_attn_temp_buffer temp_blk; + const ggml_type kv_target_type = is_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + cl_mem k_data_device = k_soa ? NULL : extra_k->data_device; + cl_mem v_data_device = v_soa ? NULL : extra_v->data_device; + + // SoA q8_0/q4_0 -> reconstruct AoS for downstream kernels that expect + // tight records (no-op when k/v is already AoS). + ggml_cl_flash_attn_temp_buffer temp_k_aos; + ggml_cl_flash_attn_temp_buffer temp_v_aos; + ggml_cl_flash_attn_reconstruct_aos(backend_ctx, k, temp_k_aos, + k_data_device, offset_k, k_nb1, k_nb2, k_nb3); + ggml_cl_flash_attn_reconstruct_aos(backend_ctx, v, temp_v_aos, + v_data_device, offset_v, v_nb1, v_nb2, v_nb3); + + // currently FA kernels support KV cache with f16, f32, q4_0 and q8_0. + // there two cases that these kernels cannot cover, + // 1. KV cache types are q4_0 or q8_0, but the FA kernels fail to compile + // 2. KV cache types not currently supported by an FA kernel, e.g., q4_1 + // these two cases are supported here by dequantizing to f32/f16 and this + // causes performance degradation. + // For q4_0 or q8_0 cases that fail kernel compilation, dequant happens in GPU; + // for types that do not have FA kernels, dequant happens on host. + if (!use_native_q8_0_q1 && !use_native_q8_0 && + !use_native_q4_0_q1 && !use_native_q4_0) { + // for q4_0, q8_0 FA kernels that fail to compile + bool k_done = false; + bool v_done = false; + if (k->type == GGML_TYPE_Q8_0 || k->type == GGML_TYPE_Q4_0) { + k_done = ggml_cl_flash_attn_dequant_kv_gpu( + backend_ctx, k, kv_target_type, k_data_device, offset_k, k_nb1, k_nb2, k_nb3, + temp_k, k_data_device, offset_k, k_nb1, k_nb2, k_nb3); + } + if (v->type == GGML_TYPE_Q8_0 || v->type == GGML_TYPE_Q4_0) { + v_done = ggml_cl_flash_attn_dequant_kv_gpu( + backend_ctx, v, kv_target_type, v_data_device, offset_v, v_nb1, v_nb2, v_nb3, + temp_v, v_data_device, offset_v, v_nb1, v_nb2, v_nb3); + } + if (!k_done) { + ggml_cl_flash_attn_prepare_quantized_tensor( + backend_ctx, k, kv_target_type, temp_k, k_data_device, offset_k, k_nb1, k_nb2, k_nb3); + } + if (!v_done) { + ggml_cl_flash_attn_prepare_quantized_tensor( + backend_ctx, v, kv_target_type, temp_v, v_data_device, offset_v, v_nb1, v_nb2, v_nb3); + } + // Asymmetric KV on the F32 fallback path: convert the F16 side to F32 + // too. (Symmetric F16 / mixed paths handle F16 directly.) + if (kv_target_type == GGML_TYPE_F32 && !is_mixed && !is_f16) { + ggml_cl_flash_attn_convert_f16_to_f32(backend_ctx, k, temp_k, k_data_device, offset_k, k_nb1, k_nb2, k_nb3); + ggml_cl_flash_attn_convert_f16_to_f32(backend_ctx, v, temp_v, v_data_device, offset_v, v_nb1, v_nb2, v_nb3); + } + } + + cl_mem k_pad_buffer = NULL; + cl_mem v_pad_buffer = NULL; + cl_mem mask_pad_buffer = NULL; + cl_mem blk_buffer = NULL; + cl_ulong mask_pad_nb1 = 0; + cl_ulong mask_pad_nb2 = 0; + cl_ulong mask_pad_nb3 = 0; + + // Flash-Decoding K-split decision. Resolved here, before the prefill + // prepass, because KV-pad and blk prepass are pure overhead when FD fires. const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv); + const int fd_max_n_q = (d_head_q <= FD_MAX_DK_MULTI) ? FD_MAX_N_Q_MULTI : 1; + cl_kernel fd_k_split = NULL; + if (n_q >= 1 && n_q <= fd_max_n_q && n_kv >= FD_MIN_N_KV && !is_causal && + d_head_q <= FD_MAX_DK && + backend_ctx->fa.f32_merge.count(dk_dv) > 0) { + if (is_mixed && backend_ctx->fa.f32_f16_q1_split.count(dk_dv) > 0) { + fd_k_split = backend_ctx->fa.f32_f16_q1_split.at(dk_dv); + } else if (is_q8_0 && backend_ctx->fa.f32_q8_0_q1_split.count(dk_dv) > 0) { + fd_k_split = backend_ctx->fa.f32_q8_0_q1_split.at(dk_dv); + } else if (is_q4_0 && backend_ctx->fa.f32_q4_0_q1_split.count(dk_dv) > 0) { + fd_k_split = backend_ctx->fa.f32_q4_0_q1_split.at(dk_dv); + } + } + const bool use_fd = (fd_k_split != NULL); + + const int n_q_blocks = n_q > 1 ? (n_q + block_m - 1) / block_m : 0; + const int n_kv_blocks = n_kv > 0 ? (n_kv + block_n - 1) / block_n : 0; + // KV pad + blk prepass are pure overhead when FD will fire — skip them. + const bool use_mixed_prepass = is_mixed && n_q > 1 && !use_fd; + const bool use_kv_pad = use_mixed_prepass && (n_kv % block_n != 0); + // blk prepass: per-KV-tile mask class (0=masked, 1=mixed, 2=unmasked). + // Consumed identically by f32_f16, q8_0 and q4_0 prefill kernels. + const bool use_quant_prepass = (use_native_q8_0 || use_native_q4_0) && !use_fd; + const bool use_blk_mask = (use_mixed_prepass || use_quant_prepass) && mask_buffer != NULL; + + if (use_kv_pad) { + cl_int err; + + const size_t k_pad_size = (size_t) k_nb1 * (size_t) block_n * (size_t) n_head_kv * (size_t) n_batch; + temp_k_pad.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, k_pad_size, NULL, &err); + CL_CHECK(err); + k_pad_buffer = temp_k_pad.data; + + const size_t v_pad_size = (size_t) v_nb1 * (size_t) block_n * (size_t) n_head_kv * (size_t) n_batch; + temp_v_pad.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, v_pad_size, NULL, &err); + CL_CHECK(err); + v_pad_buffer = temp_v_pad.data; + + cl_kernel kernel_kv_pad = backend_ctx->fa.kv_pad_f16.at(dk_dv); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 0, sizeof(cl_mem), &k_data_device)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 1, sizeof(cl_ulong), &offset_k)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 2, sizeof(cl_mem), &v_data_device)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 3, sizeof(cl_ulong), &offset_v)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 4, sizeof(cl_mem), &k_pad_buffer)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 5, sizeof(cl_mem), &v_pad_buffer)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 6, sizeof(int), &n_kv)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 7, sizeof(int), &n_head_kv)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 8, sizeof(int), &n_batch)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 9, sizeof(cl_ulong), &k_nb1)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 10, sizeof(cl_ulong), &k_nb2)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 11, sizeof(cl_ulong), &k_nb3)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 12, sizeof(cl_ulong), &v_nb1)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 13, sizeof(cl_ulong), &v_nb2)); + CL_CHECK(clSetKernelArg(kernel_kv_pad, 14, sizeof(cl_ulong), &v_nb3)); + + size_t global_work_size[] = { (size_t) block_n, (size_t) n_head_kv, (size_t) n_batch }; + backend_ctx->enqueue_ndrange_kernel(kernel_kv_pad, 3, global_work_size, NULL, dst); + + if (mask_buffer != NULL) { + mask_pad_nb1 = (cl_ulong) block_n * (cl_ulong) sizeof(ggml_fp16_t); + mask_pad_nb2 = (cl_ulong) n_q * mask_pad_nb1; + mask_pad_nb3 = (cl_ulong) mask_ne2 * mask_pad_nb2; + + const size_t mask_pad_size = (size_t) mask_ne3 * (size_t) mask_pad_nb3; + temp_mask_pad.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, mask_pad_size, NULL, &err); + CL_CHECK(err); + mask_pad_buffer = temp_mask_pad.data; + + cl_kernel kernel_mask_pad = backend_ctx->fa.mask_pad_f16.at(dk_dv); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 0, sizeof(cl_mem), &mask_buffer)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 1, sizeof(cl_ulong), &offset_mask)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 2, sizeof(cl_mem), &mask_pad_buffer)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 3, sizeof(int), &n_q)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 4, sizeof(int), &n_kv)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 5, sizeof(cl_ulong), &mask_nb1)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 6, sizeof(cl_ulong), &mask_nb2)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 7, sizeof(cl_ulong), &mask_nb3)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 8, sizeof(int), &mask_ne2)); + CL_CHECK(clSetKernelArg(kernel_mask_pad, 9, sizeof(int), &mask_ne3)); + + size_t global_work_size_mask[] = { (size_t) block_n, (size_t) n_q, (size_t) (mask_ne2 * mask_ne3) }; + backend_ctx->enqueue_ndrange_kernel(kernel_mask_pad, 3, global_work_size_mask, NULL, dst); + } + } + + if (use_blk_mask) { + cl_int err; + const size_t blk_size = (size_t) n_kv_blocks * (size_t) n_q_blocks * (size_t) mask_ne2 * (size_t) mask_ne3; + temp_blk.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, blk_size, NULL, &err); + if (err != CL_SUCCESS) { + // Flush before retry — reclaim deferred driver deallocations. + CL_CHECK(clFinish(backend_ctx->queue)); + temp_blk.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, blk_size, NULL, &err); + } + CL_CHECK(err); + blk_buffer = temp_blk.data; + + cl_kernel kernel_blk = backend_ctx->fa.blk_f16.at(dk_dv); + CL_CHECK(clSetKernelArg(kernel_blk, 0, sizeof(cl_mem), &mask_buffer)); + CL_CHECK(clSetKernelArg(kernel_blk, 1, sizeof(cl_ulong), &offset_mask)); + CL_CHECK(clSetKernelArg(kernel_blk, 2, sizeof(cl_mem), &blk_buffer)); + CL_CHECK(clSetKernelArg(kernel_blk, 3, sizeof(int), &n_q)); + CL_CHECK(clSetKernelArg(kernel_blk, 4, sizeof(int), &n_kv)); + CL_CHECK(clSetKernelArg(kernel_blk, 5, sizeof(cl_ulong), &mask_nb1)); + CL_CHECK(clSetKernelArg(kernel_blk, 6, sizeof(cl_ulong), &mask_nb2)); + CL_CHECK(clSetKernelArg(kernel_blk, 7, sizeof(cl_ulong), &mask_nb3)); + CL_CHECK(clSetKernelArg(kernel_blk, 8, sizeof(int), &mask_ne2)); + CL_CHECK(clSetKernelArg(kernel_blk, 9, sizeof(int), &mask_ne3)); + + size_t global_work_size_blk[] = { (size_t) n_kv_blocks, (size_t) n_q_blocks, (size_t) (mask_ne2 * mask_ne3) }; + backend_ctx->enqueue_ndrange_kernel(kernel_blk, 3, global_work_size_blk, NULL, dst); + } const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0; const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f; const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3)); + if (use_fd) { + int n_splits = (n_kv + FD_KV_PER_SPLIT - 1) / FD_KV_PER_SPLIT; + if (n_splits < FD_MIN_SPLITS) { + n_splits = FD_MIN_SPLITS; + } + if (n_splits > FD_MAX_SPLITS) { + n_splits = FD_MAX_SPLITS; + } + const int kv_per_split = (n_kv + n_splits - 1) / n_splits; + + const int fa_partial_floats = 2 + d_head_v; + const size_t partial_size_bytes = + (size_t) n_batch * n_head * n_q * n_splits * fa_partial_floats * sizeof(float); + + ggml_cl_flash_attn_temp_buffer temp_partial; + cl_int err; + temp_partial.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, + partial_size_bytes, NULL, &err); + if (err != CL_SUCCESS) { + CL_CHECK(clFinish(backend_ctx->queue)); + temp_partial.data = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, + partial_size_bytes, NULL, &err); + } + CL_CHECK(err); + + cl_kernel k_split = fd_k_split; + int argi = 0; + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &extra_q->data_device)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_q)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &k_data_device)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_k)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &v_data_device)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_v)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_q)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_kv)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb1)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb2)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &q_nb3)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb1)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb2)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &k_nb3)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb1)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb2)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &v_nb3)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head_log2_val)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(float), &logit_softcap)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_head_kv)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &mask_buffer)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &offset_mask)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb1)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb2)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_ulong), &mask_nb3)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &mask_ne2)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &mask_ne3)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(cl_mem), &temp_partial.data)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &n_splits)); + CL_CHECK(clSetKernelArg(k_split, argi++, sizeof(int), &kv_per_split)); + + const size_t fd_wg = 64; // matches Q1_WG_SIZE in the kernel + size_t fd_lws[3] = { fd_wg, 1, 1 }; + // gid(2) packs q_idx * n_splits + split_idx. + size_t fd_gws[3] = { fd_wg, (size_t)(n_head * n_batch), (size_t)(n_splits * n_q) }; + backend_ctx->enqueue_ndrange_kernel(k_split, 3, fd_gws, fd_lws, dst); + + cl_kernel k_merge = backend_ctx->fa.f32_merge.at(dk_dv); + argi = 0; + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &temp_partial.data)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &extra_o->data_device)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &offset_o)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_head)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_splits)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb1)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb2)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &o_nb3)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_mem), &sinks_buffer)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(cl_ulong), &offset_sinks)); + CL_CHECK(clSetKernelArg(k_merge, argi++, sizeof(int), &n_q)); + + const size_t merge_wg = (size_t) (d_head_v / 4); // one lane per float4 + size_t merge_lws[3] = { merge_wg, 1, 1 }; + size_t merge_gws[3] = { merge_wg, (size_t)(n_head * n_batch), (size_t) n_q }; + backend_ctx->enqueue_ndrange_kernel(k_merge, 3, merge_gws, merge_lws, dst); + return; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &k_data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &v_data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3)); CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias)); CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0)); CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1)); @@ -11604,15 +12960,45 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3)); CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer)); CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks)); + if (n_q > 1 && is_mixed) { + CL_CHECK(clSetKernelArg(kernel, 40, sizeof(cl_mem), &k_pad_buffer)); + CL_CHECK(clSetKernelArg(kernel, 41, sizeof(cl_mem), &v_pad_buffer)); + CL_CHECK(clSetKernelArg(kernel, 42, sizeof(cl_mem), &mask_pad_buffer)); + CL_CHECK(clSetKernelArg(kernel, 43, sizeof(cl_mem), &blk_buffer)); + CL_CHECK(clSetKernelArg(kernel, 44, sizeof(int), &n_kv_blocks)); + CL_CHECK(clSetKernelArg(kernel, 45, sizeof(cl_ulong), &mask_pad_nb1)); + CL_CHECK(clSetKernelArg(kernel, 46, sizeof(cl_ulong), &mask_pad_nb2)); + CL_CHECK(clSetKernelArg(kernel, 47, sizeof(cl_ulong), &mask_pad_nb3)); + } else if (use_native_q8_0 || use_native_q4_0) { + // arg 40 = blk classification buffer (NULL disables prepass opt). + CL_CHECK(clSetKernelArg(kernel, 40, sizeof(cl_mem), &blk_buffer)); + } if (n_q == 1) { const size_t wg_size = 64; size_t local_work_size[] = { wg_size, 1 }; size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) }; backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); + } else if (use_native_q8_0 || use_native_q4_0) { + // Native quant prefill. The split variant may override BLOCK_M + // (e.g. DK=96 quant uses BM=16). + const bool use_split = use_native_q8_0 ? use_split_q8_0 : use_split_q4_0; + int bm; + size_t wg_size; + if (use_split) { + bm = use_native_q8_0 ? backend_ctx->fa.f32_q8_0_split_bm.at(dk_dv) + : backend_ctx->fa.f32_q4_0_split_bm.at(dk_dv); + wg_size = use_native_q8_0 ? backend_ctx->fa.f32_q8_0_split_wg_size.at(dk_dv) + : backend_ctx->fa.f32_q4_0_split_wg_size.at(dk_dv); + } else { + bm = backend_ctx->fa.bm.at(dk_dv); + wg_size = (size_t) bm; + } + size_t local_work_size[] = { wg_size, 1 }; + size_t global_work_size[] = { (size_t)((n_q + bm - 1) / bm) * wg_size, (size_t)(n_head * n_batch) }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); } else { - const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv); - const size_t wg_size = block_m; + const size_t wg_size = (size_t) wg_size_fa; size_t local_work_size[] = { wg_size, 1 }; size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) }; backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); @@ -13004,7 +14390,9 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + // SoA extra lives on view_src (view->extra is pre-SoA). + const ggml_tensor * soa0_src = src0->view_src != nullptr ? src0->view_src : src0; + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)soa0_src->extra; cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; @@ -13756,6 +15144,122 @@ static void ggml_cl_mul_mat_q5_K_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +// Dequant a possibly-strided q4_0/q8_0 tensor to tight-packed f16. Returns a +// temp cl_mem the caller must release. SoA inputs are reconstructed into a +// temp AoS buffer reported via *extra_reconstruct (also caller-released). +// this is for quantized K cache without FA. +static cl_mem ggml_cl_mul_mat_dequant_quant_to_f16( + ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * tensor, + cl_mem * extra_reconstruct /* out, may be NULL */ +) { + GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0); + + if (extra_reconstruct) { + *extra_reconstruct = NULL; + } + + cl_mem src_buf; + cl_ulong src_offset; + cl_ulong src_nb1; + cl_ulong src_nb2; + cl_ulong src_nb3; + + const bool is_soa = tensor->type == GGML_TYPE_Q4_0 + ? ggml_cl_is_q4_0_soa(tensor) + : ggml_cl_is_q8_0_soa(tensor); + + if (is_soa) { + // Reconstruct full parent AoS; view's own nb[] then index it correctly. + const ggml_tensor * parent = tensor->view_src ? tensor->view_src : tensor; + const ggml_tensor * soa_src = parent; + const size_t block_bytes = (size_t) ggml_type_size(tensor->type); + const size_t blck_size = (size_t) ggml_blck_size(tensor->type); + const size_t parent_row_blocks = (size_t) parent->ne[0] / blck_size; + const size_t parent_row_bytes = parent_row_blocks * block_bytes; + const size_t parent_nbytes = (size_t) ggml_nelements(parent) / blck_size * block_bytes; + + cl_int err; + cl_mem aos = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, parent_nbytes, NULL, &err); + CL_CHECK(err); + + cl_kernel kernel; + if (tensor->type == GGML_TYPE_Q8_0) { + auto * extra = (ggml_tensor_extra_cl_q8_0 *) soa_src->extra; + kernel = backend_ctx->kernel_restore_block_q8_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &aos)); + } else { + auto * extra = (ggml_tensor_extra_cl_q4_0 *) soa_src->extra; + kernel = backend_ctx->kernel_restore_block_q4_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &aos)); + } + + const size_t n_blocks = parent_nbytes / block_bytes; + size_t gws_rec[] = { n_blocks, 1, 1 }; + size_t lws_rec[] = { 1, 1, 1 }; + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, gws_rec, lws_rec, 0, NULL, NULL)); + + (void) parent_row_blocks; + (void) parent_row_bytes; + src_buf = aos; + src_offset = tensor->view_offs; + src_nb1 = tensor->nb[1]; + src_nb2 = tensor->nb[2]; + src_nb3 = tensor->nb[3]; + + if (extra_reconstruct) { + *extra_reconstruct = aos; + } else { + // OpenCL retains the memobj while queued kernels reference it. + CL_CHECK(clReleaseMemObject(aos)); + } + } else { + auto * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra && extra->data_device); + src_buf = extra->data_device; + src_offset = extra->offset + tensor->view_offs; + src_nb1 = tensor->nb[1]; + src_nb2 = tensor->nb[2]; + src_nb3 = tensor->nb[3]; + } + + const cl_int nblk0 = (cl_int) (tensor->ne[0] / ggml_blck_size(tensor->type)); + const cl_int ne1_ = (cl_int) tensor->ne[1]; + const cl_int ne2_ = (cl_int) tensor->ne[2]; + const cl_int ne3_ = (cl_int) tensor->ne[3]; + + const size_t out_bytes = (size_t) ggml_nelements(tensor) * sizeof(ggml_fp16_t); + + cl_int err; + cl_mem out = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, out_bytes, NULL, &err); + CL_CHECK(err); + + cl_kernel dq_kernel = tensor->type == GGML_TYPE_Q8_0 + ? backend_ctx->kernel_dequant_q8_0_f16_view_aos + : backend_ctx->kernel_dequant_q4_0_f16_view_aos; + + CL_CHECK(clSetKernelArg(dq_kernel, 0, sizeof(cl_mem), &src_buf)); + CL_CHECK(clSetKernelArg(dq_kernel, 1, sizeof(cl_ulong), &src_offset)); + CL_CHECK(clSetKernelArg(dq_kernel, 2, sizeof(cl_ulong), &src_nb1)); + CL_CHECK(clSetKernelArg(dq_kernel, 3, sizeof(cl_ulong), &src_nb2)); + CL_CHECK(clSetKernelArg(dq_kernel, 4, sizeof(cl_ulong), &src_nb3)); + CL_CHECK(clSetKernelArg(dq_kernel, 5, sizeof(cl_int), &nblk0)); + CL_CHECK(clSetKernelArg(dq_kernel, 6, sizeof(cl_int), &ne1_)); + CL_CHECK(clSetKernelArg(dq_kernel, 7, sizeof(cl_int), &ne2_)); + CL_CHECK(clSetKernelArg(dq_kernel, 8, sizeof(cl_int), &ne3_)); + CL_CHECK(clSetKernelArg(dq_kernel, 9, sizeof(cl_mem), &out)); + + size_t gws[3] = { (size_t) nblk0, (size_t) ne1_, (size_t) (ne2_ * ne3_) }; + size_t lws[3] = { 1, 1, 1 }; + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, dq_kernel, 3, NULL, gws, lws, 0, NULL, NULL)); + + return out; +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -13770,6 +15274,31 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + // Non-contig quant src0: on-device dequant to f16 then native f16 MUL_MAT. + if ((src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q8_0) && !ggml_is_contiguous(src0)) { + cl_mem f16_buf = ggml_cl_mul_mat_dequant_quant_to_f16(backend_ctx, src0, nullptr); + + ggml_tensor fake_src0 = *src0; + ggml_tensor_extra_cl fake_extra = {}; + fake_extra.data_device = f16_buf; + fake_extra.offset = 0; + fake_src0.type = GGML_TYPE_F16; + fake_src0.extra = &fake_extra; + fake_src0.view_src = nullptr; + fake_src0.view_offs = 0; + fake_src0.nb[0] = sizeof(ggml_fp16_t); + fake_src0.nb[1] = fake_src0.nb[0] * src0->ne[0]; + fake_src0.nb[2] = fake_src0.nb[1] * src0->ne[1]; + fake_src0.nb[3] = fake_src0.nb[2] * src0->ne[2]; + + ggml_cl_mul_mat(backend, &fake_src0, src1, dst); + + // Safe to release now: OpenCL retains the memobj while queued + // kernels that reference it are still in flight. + CL_CHECK(clReleaseMemObject(f16_buf)); + return; + } + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; @@ -13779,16 +15308,19 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_ulong offsetd = extrad->offset + dst->view_offs; #ifdef GGML_OPENCL_SOA_Q - ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; - ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; - ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; - ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; - ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; - ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; - ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; - ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; - ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; + // view->extra stays pre-SoA; cast to the SoA struct would SIGSEGV. + // Follow view_src to reach the real SoA extra. + const ggml_tensor * soa0_src = src0->view_src != nullptr ? src0->view_src : src0; + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)soa0_src->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)soa0_src->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)soa0_src->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)soa0_src->extra; + ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)soa0_src->extra; + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)soa0_src->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)soa0_src->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)soa0_src->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)soa0_src->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)soa0_src->extra; #endif GGML_TENSOR_LOCALS(int, ne0, src0, ne); @@ -15543,15 +17075,18 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, GGML_UNUSED(offset0); #ifdef GGML_OPENCL_SOA_Q - ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; - ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; - ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; - ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; - ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; - ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; - ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; - ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + // SoA extra lives on view_src (view->extra is pre-SoA). + const ggml_tensor * soa0_src = src0->view_src != nullptr ? src0->view_src : src0; + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)soa0_src->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)soa0_src->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)soa0_src->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)soa0_src->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)soa0_src->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)soa0_src->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)soa0_src->extra; + ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)soa0_src->extra; + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)soa0_src->extra; + #endif // TODO: general MoE for the following types diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 226b127ab3be..82a1305592ed 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -1582,6 +1582,158 @@ kernel void kernel_restore_block_q8_0( } } +// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path). +kernel void kernel_dequant_q8_0_f32_view_aos( + global char * src, + ulong src_offset, + ulong src_nb1, + ulong src_nb2, + ulong src_nb3, + int nblk0, + int ne1, + int ne2, + int ne3, + global float * dst +) { + int blk_i0 = get_global_id(0); + int i1 = get_global_id(1); + int batch = get_global_id(2); + + if (blk_i0 >= nblk0) return; + if (i1 >= ne1) return; + + int i2 = batch % ne2; + int i3 = batch / ne2; + if (i3 >= ne3) return; + + global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0); + float d = vload_half(0, (global half *)block); + global char * qs = block + 2; + + ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0; + global float * out = dst + (dst_row_base + blk_i0) * QK8_0; + + for (int i = 0; i < QK8_0; ++i) { + out[i] = d * (float)qs[i]; + } +} + +// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped. +kernel void kernel_dequant_q8_0_f16_view_aos( + global char * src, + ulong src_offset, + ulong src_nb1, + ulong src_nb2, + ulong src_nb3, + int nblk0, + int ne1, + int ne2, + int ne3, + global half * dst +) { + int blk_i0 = get_global_id(0); + int i1 = get_global_id(1); + int batch = get_global_id(2); + + if (blk_i0 >= nblk0) return; + if (i1 >= ne1) return; + + int i2 = batch % ne2; + int i3 = batch / ne2; + if (i3 >= ne3) return; + + global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0); + float d = vload_half(0, (global half *)block); + global char * qs = block + 2; + + ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0; + global half * out = dst + (dst_row_base + blk_i0) * QK8_0; + + for (int i = 0; i < QK8_0; ++i) { + out[i] = (half)(d * (float)qs[i]); + } +} + +// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant). +kernel void kernel_dequant_q4_0_f32_view_aos( + global char * src, + ulong src_offset, + ulong src_nb1, + ulong src_nb2, + ulong src_nb3, + int nblk0, + int ne1, + int ne2, + int ne3, + global float * dst +) { + int blk_i0 = get_global_id(0); + int i1 = get_global_id(1); + int batch = get_global_id(2); + + if (blk_i0 >= nblk0) return; + if (i1 >= ne1) return; + + int i2 = batch % ne2; + int i3 = batch / ne2; + if (i3 >= ne3) return; + + global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2); + float d = vload_half(0, (global half *)block); + global uchar * qs = (global uchar *)(block + 2); + + ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0; + global float * out = dst + (dst_row_base + blk_i0) * QK4_0; + + for (int i = 0; i < QK4_0/2; ++i) { + uchar byte = qs[i]; + int q0 = (int)(byte & 0x0F) - 8; + int q1 = (int)(byte >> 4) - 8; + out[i] = d * (float)q0; + out[i + QK4_0/2] = d * (float)q1; + } +} + +// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant). +kernel void kernel_dequant_q4_0_f16_view_aos( + global char * src, + ulong src_offset, + ulong src_nb1, + ulong src_nb2, + ulong src_nb3, + int nblk0, + int ne1, + int ne2, + int ne3, + global half * dst +) { + int blk_i0 = get_global_id(0); + int i1 = get_global_id(1); + int batch = get_global_id(2); + + if (blk_i0 >= nblk0) return; + if (i1 >= ne1) return; + + int i2 = batch % ne2; + int i3 = batch / ne2; + if (i3 >= ne3) return; + + global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2); + float d = vload_half(0, (global half *)block); + global uchar * qs = (global uchar *)(block + 2); + + ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0; + global half * out = dst + (dst_row_base + blk_i0) * QK4_0; + + for (int i = 0; i < QK4_0/2; ++i) { + uchar byte = qs[i]; + int q0 = (int)(byte & 0x0F) - 8; + int q1 = (int)(byte >> 4) - 8; + out[i] = (half)(d * (float)q0); + out[i + QK4_0/2] = (half)(d * (float)q1); + } +} + kernel void kernel_restore_block_q8_0_trans( global uchar * src_q, global half * src_d, diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl index 8f43c4f27d58..ec941b5f1022 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl @@ -4,14 +4,26 @@ #define ACC_TYPE4 float4 #define DATA_TYPE half #define DATA_TYPE4 half4 -#define CONVERT_ACC4(x) convert_float4(x) -#define CONVERT_DATA4(x) convert_half4(x) +#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3)) +#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3)) #define DK_VEC (DK/4) #define DV_VEC (DV/4) #define WG_SIZE (BLOCK_M) #define Q1_WG_SIZE 64 +// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs, +// infinite operand can cause undefined behavior and miscompilation for exp. +// Therefore, a large negative value is used instead. +#define FA_M_INIT (-3.0e38f) + +// Drop full unroll at DK>=192 — Adreno compiler host-memory budget. +#if DK >= 192 +#define FA_UNROLL +#else +#define FA_UNROLL _Pragma("unroll") +#endif + inline float get_alibi_slope( const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1 ) { @@ -81,18 +93,18 @@ __kernel void flash_attn_f16( if (my_query_row < n_q) { const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); - #pragma unroll + FA_UNROLL for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_ACC4(q_ptr[i]); } } ACC_TYPE4 o_acc[DV_VEC]; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_acc[i] = (ACC_TYPE4)(0.0f); } - ACC_TYPE m_i = -INFINITY; + ACC_TYPE m_i = FA_M_INIT; ACC_TYPE l_i = 0.0f; float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); @@ -125,49 +137,72 @@ __kernel void flash_attn_f16( continue; } - for (int j = 0; j < BLOCK_N; j += 2) { + for (int j = 0; j < BLOCK_N; j += 4) { const int k_row0 = k_start + j; const int k_row1 = k_start + j + 1; + const int k_row2 = k_start + j + 2; + const int k_row3 = k_start + j + 3; ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); - #pragma unroll + ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f); + FA_UNROLL for (int k = 0; k < DK_VEC; k++) { - dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); - dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); + const ACC_TYPE4 qk = q_priv[k]; + dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1); + dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2); + dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3); } - ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; - ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; + ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; + ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; + ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale; + ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale; if (is_causal) { - if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY; - if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY; + const int causal_limit = n_kv - n_q + my_query_row; + if (k_row0 > causal_limit) s0 = FA_M_INIT; + if (k_row1 > causal_limit) s1 = FA_M_INIT; + if (k_row2 > causal_limit) s2 = FA_M_INIT; + if (k_row3 > causal_limit) s3 = FA_M_INIT; } - - if (k_row0 >= n_kv) score0 = -INFINITY; - if (k_row1 >= n_kv) score1 = -INFINITY; + if (k_row0 >= n_kv) s0 = FA_M_INIT; + if (k_row1 >= n_kv) s1 = FA_M_INIT; + if (k_row2 >= n_kv) s2 = FA_M_INIT; + if (k_row3 >= n_kv) s3 = FA_M_INIT; if (mask_base != NULL) { const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1); - if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; - if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0]; + if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2]; + if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3]; } if (logit_softcap > 0.0f) { - score0 = logit_softcap * tanh(score0 / logit_softcap); - score1 = logit_softcap * tanh(score1 / logit_softcap); + s0 = logit_softcap * tanh(s0 / logit_softcap); + s1 = logit_softcap * tanh(s1 / logit_softcap); + s2 = logit_softcap * tanh(s2 / logit_softcap); + s3 = logit_softcap * tanh(s3 / logit_softcap); } - const ACC_TYPE m_new = max(m_i, max(score0, score1)); - const ACC_TYPE p0 = exp(score0 - m_new); - const ACC_TYPE p1 = exp(score1 - m_new); - const ACC_TYPE scale_prev = exp(m_i - m_new); + const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3))); + const ACC_TYPE scale_prev = native_exp(m_i - m_new); + const ACC_TYPE p0 = native_exp(s0 - m_new); + const ACC_TYPE p1 = native_exp(s1 - m_new); + const ACC_TYPE p2 = native_exp(s2 - m_new); + const ACC_TYPE p3 = native_exp(s3 - m_new); - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { - o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]); + o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]), + mad(p2, CONVERT_ACC4(l_v[j+2][i]), + mad(p1, CONVERT_ACC4(l_v[j+1][i]), + mad(p0, CONVERT_ACC4(l_v[j][i]), + o_acc[i] * scale_prev)))); } - l_i = l_i * scale_prev + p0 + p1; + l_i = l_i * scale_prev + p0 + p1 + p2 + p3; m_i = m_new; } } @@ -179,7 +214,7 @@ __kernel void flash_attn_f16( const ACC_TYPE m_final = max(m_i, m_sink); const ACC_TYPE scale_o = exp(m_i - m_final); - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_acc[i] *= scale_o; } @@ -191,12 +226,12 @@ __kernel void flash_attn_f16( global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_i; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv); } } else { - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_row[i] = (DATA_TYPE4)(0.0f); } @@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1( ACC_TYPE4 q_priv[DK_VEC]; const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); - #pragma unroll + FA_UNROLL for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_ACC4(q_ptr[i]); } @@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1( sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); } - ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY; + ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); - #pragma unroll + FA_UNROLL for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } @@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1( __local ACC_TYPE local_m[Q1_WG_SIZE]; local_m[tid] = m_i; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); barrier(CLK_LOCAL_MEM_FENCE); @@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1( const ACC_TYPE m_final = local_m[0]; ACC_TYPE4 o_acc[DV_VEC]; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); ACC_TYPE l_i = 0.0f; @@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1( const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); - #pragma unroll + FA_UNROLL for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } @@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1( } const ACC_TYPE p = exp(score - m_final); l_i += p; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; i++) { o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); } @@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1( __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; local_l[tid] = l_i; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_l[tid] += local_l[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1( for (int i = 0; i < DV_VEC; i++) { local_o_comp[tid] = o_acc[i]; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1( } } } else if (tid == 0) { - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f); } } diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl index a6d747903751..2547731c3779 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl @@ -13,6 +13,18 @@ #define WG_SIZE (BLOCK_M) #define Q1_WG_SIZE 64 +// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs, +// infinite operand can cause undefined behavior and miscompilation for exp. +// Therefore, a large negative value is used instead. +#define FA_M_INIT (-3.0e38f) + +// Drop full unroll at DK>=192 — Adreno compiler host-memory budget. +#if DK >= 192 +#define FA_UNROLL +#else +#define FA_UNROLL _Pragma("unroll") +#endif + inline float get_alibi_slope( const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1 ) { @@ -82,18 +94,18 @@ __kernel void flash_attn_f32( if (my_query_row < n_q) { const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); - #pragma unroll + FA_UNROLL for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_ACC4(q_ptr[i]); } } ACC_TYPE4 o_acc[DV_VEC]; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_acc[i] = (ACC_TYPE4)(0.0f); } - ACC_TYPE m_i = -INFINITY; + ACC_TYPE m_i = FA_M_INIT; ACC_TYPE l_i = 0.0f; float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); @@ -126,49 +138,72 @@ __kernel void flash_attn_f32( continue; } - for (int j = 0; j < BLOCK_N; j += 2) { + for (int j = 0; j < BLOCK_N; j += 4) { const int k_row0 = k_start + j; const int k_row1 = k_start + j + 1; + const int k_row2 = k_start + j + 2; + const int k_row3 = k_start + j + 3; ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); - #pragma unroll + ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f); + FA_UNROLL for (int k = 0; k < DK_VEC; k++) { - dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); - dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); + const ACC_TYPE4 qk = q_priv[k]; + dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1); + dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2); + dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3); } - ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; - ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; + ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; + ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; + ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale; + ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale; if (is_causal) { - if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY; - if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY; + const int causal_limit = n_kv - n_q + my_query_row; + if (k_row0 > causal_limit) s0 = FA_M_INIT; + if (k_row1 > causal_limit) s1 = FA_M_INIT; + if (k_row2 > causal_limit) s2 = FA_M_INIT; + if (k_row3 > causal_limit) s3 = FA_M_INIT; } - - if (k_row0 >= n_kv) score0 = -INFINITY; - if (k_row1 >= n_kv) score1 = -INFINITY; + if (k_row0 >= n_kv) s0 = FA_M_INIT; + if (k_row1 >= n_kv) s1 = FA_M_INIT; + if (k_row2 >= n_kv) s2 = FA_M_INIT; + if (k_row3 >= n_kv) s3 = FA_M_INIT; if (mask_base != NULL) { const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); - if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; - if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0]; + if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2]; + if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3]; } if (logit_softcap > 0.0f) { - score0 = logit_softcap * tanh(score0 / logit_softcap); - score1 = logit_softcap * tanh(score1 / logit_softcap); + s0 = logit_softcap * tanh(s0 / logit_softcap); + s1 = logit_softcap * tanh(s1 / logit_softcap); + s2 = logit_softcap * tanh(s2 / logit_softcap); + s3 = logit_softcap * tanh(s3 / logit_softcap); } - const ACC_TYPE m_new = max(m_i, max(score0, score1)); - const ACC_TYPE p0 = exp(score0 - m_new); - const ACC_TYPE p1 = exp(score1 - m_new); - const ACC_TYPE scale_prev = exp(m_i - m_new); + const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3))); + const ACC_TYPE scale_prev = native_exp(m_i - m_new); + const ACC_TYPE p0 = native_exp(s0 - m_new); + const ACC_TYPE p1 = native_exp(s1 - m_new); + const ACC_TYPE p2 = native_exp(s2 - m_new); + const ACC_TYPE p3 = native_exp(s3 - m_new); - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { - o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]); + o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]), + mad(p2, CONVERT_ACC4(l_v[j+2][i]), + mad(p1, CONVERT_ACC4(l_v[j+1][i]), + mad(p0, CONVERT_ACC4(l_v[j][i]), + o_acc[i] * scale_prev)))); } - l_i = l_i * scale_prev + p0 + p1; + l_i = l_i * scale_prev + p0 + p1 + p2 + p3; m_i = m_new; } } @@ -180,7 +215,7 @@ __kernel void flash_attn_f32( const ACC_TYPE m_final = max(m_i, m_sink); const ACC_TYPE scale_o = exp(m_i - m_final); - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_acc[i] *= scale_o; } @@ -192,12 +227,12 @@ __kernel void flash_attn_f32( global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_i; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv); } } else { - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_row[i] = (DATA_TYPE4)(0.0f); } @@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1( ACC_TYPE4 q_priv[DK_VEC]; const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); - #pragma unroll + FA_UNROLL for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_ACC4(q_ptr[i]); } @@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1( sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); } - ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY; + ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); - #pragma unroll + FA_UNROLL for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } @@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1( __local ACC_TYPE local_m[Q1_WG_SIZE]; local_m[tid] = m_i; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); barrier(CLK_LOCAL_MEM_FENCE); @@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1( const ACC_TYPE m_final = local_m[0]; ACC_TYPE4 o_acc[DV_VEC]; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); ACC_TYPE l_i = 0.0f; @@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1( const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); - #pragma unroll + FA_UNROLL for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } @@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1( } const ACC_TYPE p = exp(score - m_final); l_i += p; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; i++) { o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); } @@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1( __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; local_l[tid] = l_i; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_l[tid] += local_l[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1( for (int i = 0; i < DV_VEC; i++) { local_o_comp[tid] = o_acc[i]; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1( } } } else if (tid == 0) { - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f); } } diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl index ec7361b9e370..a7f1de325c81 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl @@ -1,5 +1,13 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable +#ifdef cl_khr_subgroup_shuffle +#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable +#define HAS_SUBGROUP_SHUFFLE 1 +#elif defined(cl_qcom_subgroup_shuffle) +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +#define HAS_SUBGROUP_SHUFFLE 1 +#endif + #define ACC_TYPE float #define ACC_TYPE4 float4 #define Q_DATA_TYPE4 float4 @@ -12,9 +20,34 @@ #define DK_VEC (DK/4) #define DV_VEC (DV/4) -#define WG_SIZE (BLOCK_M) #define Q1_WG_SIZE 64 +// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs, +// infinite operand can cause undefined behavior and miscompilation for exp. +// Therefore, a large negative value is used instead. +#define FA_M_INIT (-3.0e38f) + +// Drop full unroll at DK>=192 — Adreno compiler host-memory budget. +#if DK >= 192 +#define FA_UNROLL +#else +#define FA_UNROLL _Pragma("unroll") +#endif + +// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use. +#ifndef N_SPLIT +#define N_SPLIT 1 +#endif + +#define SPLIT_DK_VEC (DK_VEC / N_SPLIT) +#define SPLIT_DV_VEC (DV_VEC / N_SPLIT) + +#if N_SPLIT > 1 +#define WG_SIZE (BLOCK_M * N_SPLIT) +#else +#define WG_SIZE (BLOCK_M) +#endif + inline float get_alibi_slope( const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1 ) { @@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16( const int mask_ne2, const int mask_ne3, const global void* sinks_void, - const ulong sinks_offset + const ulong sinks_offset, + const global void * k_pad_void, + const global void * v_pad_void, + const global void * mask_pad_void, + const global char * blk, + const int n_kv_blocks, + const ulong mask_pad_nb1, + const ulong mask_pad_nb2, + const ulong mask_pad_nb3 ) { const int tid = get_local_id(0); const int block_q_idx = get_group_id(0); const int head_batch_idx = get_global_id(1); - const int my_query_row = block_q_idx * BLOCK_M + tid; +#if N_SPLIT > 1 + const int q_lane = tid / N_SPLIT; + const int split_idx = tid % N_SPLIT; +#else + const int q_lane = tid; + const int split_idx = 0; +#endif + + const int my_query_row = block_q_idx * BLOCK_M + q_lane; + const int query_valid = my_query_row < n_q; const int batch_idx = head_batch_idx / n_head; const int head_idx = head_batch_idx % n_head; const int gqa_ratio = n_head / n_head_kv; const int head_kv_idx = head_idx / gqa_ratio; + const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0; + const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0; const global char* q_base = (const global char*)q_void + q_offset; const global char* k_base = (const global char*)k_void + k_offset; @@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16( const global char* mask_base = NULL; if (mask_void != NULL) { - const int mask_head_idx = head_idx % mask_ne2; - const int mask_batch_idx = batch_idx % mask_ne3; mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; } + const global char* mask_pad_base = NULL; + if (mask_pad_void != NULL) { + mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2; + } + const global char* blk_base = NULL; + if (blk != NULL) { + const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M; + blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks; + } - ACC_TYPE4 q_priv[DK_VEC]; - if (my_query_row < n_q) { + ACC_TYPE4 q_priv[SPLIT_DK_VEC]; + const int dk_off = split_idx * SPLIT_DK_VEC; + if (query_valid) { const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset); - #pragma unroll - for (int i = 0; i < DK_VEC; ++i) { - q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); + FA_UNROLL + for (int i = 0; i < SPLIT_DK_VEC; ++i) { + q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]); + } + } else { + FA_UNROLL + for (int i = 0; i < SPLIT_DK_VEC; ++i) { + q_priv[i] = (ACC_TYPE4)(0.0f); } } - ACC_TYPE4 o_acc[DV_VEC]; - #pragma unroll - for (int i = 0; i < DV_VEC; ++i) { + ACC_TYPE4 o_acc[SPLIT_DV_VEC]; + FA_UNROLL + for (int i = 0; i < SPLIT_DV_VEC; ++i) { o_acc[i] = (ACC_TYPE4)(0.0f); } - ACC_TYPE m_i = -INFINITY; + + ACC_TYPE m_i = FA_M_INIT; ACC_TYPE l_i = 0.0f; float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); @@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16( __local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC]; __local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC]; +#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE) + __local ACC_TYPE local_partial[BLOCK_N][WG_SIZE]; + __local ACC_TYPE local_p[BLOCK_M][BLOCK_N]; + __local ACC_TYPE local_softmax_scale[BLOCK_M]; + __local ACC_TYPE local_l_inv[BLOCK_M]; +#endif + for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) { + char blk_cur = 1; + if (blk_base != NULL) { + blk_cur = blk_base[k_start / BLOCK_N]; + if (blk_cur == 0) continue; + } + + const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv; + const int k_tile_start = use_kv_pad ? 0 : k_start; + const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2; + const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3; + const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2; + const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3; + const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base; + const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base; + for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) { const int row = i / DK_VEC; const int col = i % DK_VEC; - const int k_row_idx = k_start + row; - if (k_row_idx < n_kv) { - const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1; - l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col]; + const int k_row_idx = k_tile_start + row; + if (use_kv_pad || k_row_idx < n_kv) { + const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1; + l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col]; + } else { + l_k[row][col] = (KV_DATA_TYPE4)(0.0h); } } for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) { const int row = i / DV_VEC; const int col = i % DV_VEC; - const int v_row_idx = k_start + row; - if (v_row_idx < n_kv) { - const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1; - l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col]; + const int v_row_idx = k_tile_start + row; + if (use_kv_pad || v_row_idx < n_kv) { + const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1; + l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col]; + } else { + l_v[row][col] = (KV_DATA_TYPE4)(0.0h); } } barrier(CLK_LOCAL_MEM_FENCE); - if (my_query_row >= n_q) { - continue; +#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE) + { + const int dv_off = split_idx * SPLIT_DV_VEC; + for (int j = 0; j < BLOCK_N; j += 2) { + const int k_row0 = k_start + j; + const int k_row1 = k_start + j + 1; + + ACC_TYPE partial0 = 0.0f; + ACC_TYPE partial1 = 0.0f; + FA_UNROLL + for (int k = 0; k < SPLIT_DK_VEC; k++) { + const ACC_TYPE4 qk = q_priv[k]; + ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]); + ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]); + partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3; + partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3; + } + + FA_UNROLL + for (int step = 1; step < N_SPLIT; step <<= 1) { + partial0 += sub_group_shuffle_xor(partial0, step); + partial1 += sub_group_shuffle_xor(partial1, step); + } + + ACC_TYPE score0 = partial0 * scale; + ACC_TYPE score1 = partial1 * scale; + + if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; } + if (is_causal) { + if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT; + if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT; + } + if (k_row0 >= n_kv) score0 = FA_M_INIT; + if (k_row1 >= n_kv) score1 = FA_M_INIT; + + if (query_valid && mask_base != NULL && blk_cur != 2) { + if (use_kv_pad && mask_pad_base != NULL) { + const global MASK_DATA_TYPE* mask_ptr = + (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1); + score0 += slope * (ACC_TYPE)mask_ptr[j]; + score1 += slope * (ACC_TYPE)mask_ptr[j + 1]; + } else { + const global MASK_DATA_TYPE* mask_ptr = + (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); + if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; + if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + } + } + + if (logit_softcap > 0.0f) { + score0 = logit_softcap * tanh(score0 / logit_softcap); + score1 = logit_softcap * tanh(score1 / logit_softcap); + } + + const ACC_TYPE m_new = max(m_i, max(score0, score1)); + // Whole tile masked (m_new == FA_M_INIT): force the exp() args + // far negative so the tile contributes 0, not exp(0)=1. + const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new; + const ACC_TYPE sp = native_exp(m_i - m_exp); + const ACC_TYPE p0 = native_exp(score0 - m_exp); + const ACC_TYPE p1 = native_exp(score1 - m_exp); + + FA_UNROLL + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + o_acc[i] = o_acc[i] * sp + + p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i]) + + p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]); + } + l_i = l_i * sp + p0 + p1; + m_i = m_new; + } } - - for (int j = 0; j < BLOCK_N; j += 2) { - const int k_row0 = k_start + j; - const int k_row1 = k_start + j + 1; - - ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); - ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); - #pragma unroll - for (int k = 0; k < DK_VEC; k++) { - dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0); - dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1); +#elif N_SPLIT > 1 + // N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction. + // Phase 1 — partial dots for all BLOCK_N tokens. + for (int j = 0; j < BLOCK_N; ++j) { + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + FA_UNROLL + for (int k = 0; k < SPLIT_DK_VEC; k++) { + dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc); } - ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; - ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; - - if (is_causal) { - if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY; - if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY; + local_partial[j][tid] = + dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3; + } + barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible + + // Phase 2 — split_idx==0 reduces partial sums and computes block softmax. + if (split_idx == 0) { + if (query_valid) { + ACC_TYPE m_new = m_i; + for (int j = 0; j < BLOCK_N; ++j) { + const int k_row = k_start + j; + ACC_TYPE score = 0.0f; + FA_UNROLL + for (int s = 0; s < N_SPLIT; s++) { + score += local_partial[j][q_lane * N_SPLIT + s]; + } + score *= scale; + + if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT; + if (k_row >= n_kv) score = FA_M_INIT; + + if (mask_base != NULL && blk_cur != 2) { + if (use_kv_pad && mask_pad_base != NULL) { + const global MASK_DATA_TYPE* mask_ptr = + (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1); + score += slope * (ACC_TYPE)mask_ptr[j]; + } else { + const global MASK_DATA_TYPE* mask_ptr = + (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); + if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row]; + } + } + + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + + m_new = max(m_new, score); + local_p[q_lane][j] = score; + } + + const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new; + const ACC_TYPE sp = native_exp(m_i - m_exp); + ACC_TYPE l_new = l_i * sp; + for (int j = 0; j < BLOCK_N; ++j) { + const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp); + local_p[q_lane][j] = p; + l_new += p; + } + local_softmax_scale[q_lane] = sp; + l_i = l_new; + m_i = m_new; + } else { + local_softmax_scale[q_lane] = 1.0f; + for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f; } + } + barrier(CLK_LOCAL_MEM_FENCE); - if (k_row0 >= n_kv) score0 = -INFINITY; - if (k_row1 >= n_kv) score1 = -INFINITY; - - if (mask_base != NULL) { - const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); - if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; - if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + // Phase 3 — V accumulate using broadcast probabilities. + { + const ACC_TYPE sp_block = local_softmax_scale[q_lane]; + const int dv_off = split_idx * SPLIT_DV_VEC; + FA_UNROLL + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + o_acc[i] *= sp_block; } - - if (logit_softcap > 0.0f) { - score0 = logit_softcap * tanh(score0 / logit_softcap); - score1 = logit_softcap * tanh(score1 / logit_softcap); + for (int j = 0; j < BLOCK_N; ++j) { + const ACC_TYPE p = local_p[q_lane][j]; + FA_UNROLL + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]); + } } + } +#else + // N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0. + if (query_valid) { + for (int j = 0; j < BLOCK_N; j += 4) { + const int k_row0 = k_start + j; + const int k_row1 = k_start + j + 1; + const int k_row2 = k_start + j + 2; + const int k_row3 = k_start + j + 3; + + ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f); + FA_UNROLL + for (int k = 0; k < DK_VEC; k++) { + const ACC_TYPE4 qk = q_priv[k]; + dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1); + dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2); + dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3); + } + ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; + ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; + ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale; + ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale; + + if (is_causal) { + const int causal_limit = n_kv - n_q + my_query_row; + if (k_row0 > causal_limit) s0 = FA_M_INIT; + if (k_row1 > causal_limit) s1 = FA_M_INIT; + if (k_row2 > causal_limit) s2 = FA_M_INIT; + if (k_row3 > causal_limit) s3 = FA_M_INIT; + } + if (k_row0 >= n_kv) s0 = FA_M_INIT; + if (k_row1 >= n_kv) s1 = FA_M_INIT; + if (k_row2 >= n_kv) s2 = FA_M_INIT; + if (k_row3 >= n_kv) s3 = FA_M_INIT; + + if (mask_base != NULL && blk_cur != 2) { + if (use_kv_pad && mask_pad_base != NULL) { + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1); + s0 += slope * (ACC_TYPE)mask_ptr[j]; + s1 += slope * (ACC_TYPE)mask_ptr[j + 1]; + s2 += slope * (ACC_TYPE)mask_ptr[j + 2]; + s3 += slope * (ACC_TYPE)mask_ptr[j + 3]; + } else { + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); + if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0]; + if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2]; + if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3]; + } + } + + if (logit_softcap > 0.0f) { + s0 = logit_softcap * tanh(s0 / logit_softcap); + s1 = logit_softcap * tanh(s1 / logit_softcap); + s2 = logit_softcap * tanh(s2 / logit_softcap); + s3 = logit_softcap * tanh(s3 / logit_softcap); + } + + const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3))); + // Whole tile masked (m_new == FA_M_INIT): force the exp() args + // far negative so the tile contributes 0, not exp(0)=1. + const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new; + const ACC_TYPE scale_prev = native_exp(m_i - m_exp); + const ACC_TYPE p0 = native_exp(s0 - m_exp); + const ACC_TYPE p1 = native_exp(s1 - m_exp); + const ACC_TYPE p2 = native_exp(s2 - m_exp); + const ACC_TYPE p3 = native_exp(s3 - m_exp); + + FA_UNROLL + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]), + mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]), + mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]), + mad(p0, CONVERT_KV_ACC4(l_v[j][i]), + o_acc[i] * scale_prev)))); + } + l_i = l_i * scale_prev + p0 + p1 + p2 + p3; + m_i = m_new; + } + } +#endif + // End of tile: every thread must finish reading l_k/l_v before the + // next iteration's load overwrites them (WAR hazard on local memory). + barrier(CLK_LOCAL_MEM_FENCE); + } - const ACC_TYPE m_new = max(m_i, max(score0, score1)); - const ACC_TYPE p0 = exp(score0 - m_new); - const ACC_TYPE p1 = exp(score1 - m_new); - const ACC_TYPE scale_prev = exp(m_i - m_new); - - #pragma unroll - for (int i = 0; i < DV_VEC; ++i) { - o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]); + // Write output. +#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE) + if (query_valid) { + ACC_TYPE sinks_sp = 1.0f; + if (sinks_void != NULL) { + const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + const ACC_TYPE m_sink = sinks_ptr[head_idx]; + const ACC_TYPE m_final = max(m_i, m_sink); + sinks_sp = exp(m_i - m_final); + l_i = l_i * sinks_sp + exp(m_sink - m_final); + m_i = m_final; + } + const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f; + const int dv_off = split_idx * SPLIT_DV_VEC; + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; + global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); + if (l_inv > 0.0f) { + FA_UNROLL + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv); } - l_i = l_i * scale_prev + p0 + p1; - m_i = m_new; + } else { + FA_UNROLL + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f); + } + } + } +#elif N_SPLIT > 1 + if (split_idx == 0) { + ACC_TYPE sinks_sp = 1.0f; + if (query_valid && sinks_void != NULL) { + const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + const ACC_TYPE m_sink = sinks_ptr[head_idx]; + const ACC_TYPE m_final = max(m_i, m_sink); + sinks_sp = exp(m_i - m_final); + l_i = l_i * sinks_sp + exp(m_sink - m_final); + m_i = m_final; } + local_softmax_scale[q_lane] = sinks_sp; + local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f; } + barrier(CLK_LOCAL_MEM_FENCE); - if (my_query_row < n_q) { + if (query_valid) { + const ACC_TYPE sinks_sp = local_softmax_scale[q_lane]; + const ACC_TYPE l_inv = local_l_inv[q_lane]; + const int dv_off = split_idx * SPLIT_DV_VEC; + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; + global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); + if (l_inv > 0.0f) { + FA_UNROLL + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv); + } + } else { + FA_UNROLL + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f); + } + } + } +#else + if (query_valid) { if (sinks_void != NULL) { const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); const ACC_TYPE m_sink = sinks_ptr[head_idx]; const ACC_TYPE m_final = max(m_i, m_sink); const ACC_TYPE scale_o = exp(m_i - m_final); - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_acc[i] *= scale_o; } @@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16( global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_i; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv); } } else { - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) { o_row[i] = (O_DATA_TYPE4)(0.0f); } } } +#endif } __kernel void flash_attn_f32_f16_q1( @@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1( mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; } - ACC_TYPE4 q_priv[DK_VEC]; + // Q is uniform across WG threads (n_q=1). Share via local memory to + // avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that + // spills to DDR on Adreno. + __local ACC_TYPE4 q_shared[DK_VEC]; const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset); - #pragma unroll - for (int i = 0; i < DK_VEC; ++i) { - q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); + for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) { + q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]); } + barrier(CLK_LOCAL_MEM_FENCE); float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); @@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1( sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); } - ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY; + ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); - #pragma unroll + FA_UNROLL for (int k = 0; k < DK_VEC; k++) { - dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); + dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { @@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1( __local ACC_TYPE local_m[Q1_WG_SIZE]; local_m[tid] = m_i; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); barrier(CLK_LOCAL_MEM_FENCE); @@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1( const ACC_TYPE m_final = local_m[0]; ACC_TYPE4 o_acc[DV_VEC]; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); ACC_TYPE l_i = 0.0f; @@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1( const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); - #pragma unroll + FA_UNROLL for (int k = 0; k < DK_VEC; k++) { - dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); + dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { @@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1( } const ACC_TYPE p = exp(score - m_final); l_i += p; - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; i++) { o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]); } @@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1( __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; local_l[tid] = l_i; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_l[tid] += local_l[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1( for (int i = 0; i < DV_VEC; i++) { local_o_comp[tid] = o_acc[i]; barrier(CLK_LOCAL_MEM_FENCE); - #pragma unroll + FA_UNROLL for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1( } } } else if (tid == 0) { - #pragma unroll + FA_UNROLL for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f); } } + +// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx. +// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm. +#define FA_PARTIAL_FLOATS (2 + DV) + +__kernel void flash_attn_f32_f16_q1_split( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + const float scale, + const int n_q, + const int n_kv, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void * mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3, + global float * partial_void, + const int n_splits, + const int kv_per_split +) { + const int tid = get_local_id(0); + const int head_batch_idx = get_global_id(1); + const int split_q_idx = get_global_id(2); + const int split_idx = split_q_idx % n_splits; + const int q_idx = split_q_idx / n_splits; + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const int kv_start = split_idx * kv_per_split; + const int kv_end = min(kv_start + kv_per_split, n_kv); + + const ulong record_stride = (ulong) FA_PARTIAL_FLOATS; + const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) + * n_splits + split_idx); + global float * rec = partial_void + record_idx * record_stride; + global float4 * rec_o = (global float4 *) (rec + 2); + + if (kv_start >= kv_end) { + // Empty split: leave sentinel partial for merge. + if (tid == 0) { + rec[0] = FA_M_INIT; + rec[1] = 0.0f; + } + return; + } + + const global char * q_base = (const global char *) q_void + q_offset; + const global char * k_base = (const global char *) k_void + k_offset; + const global char * v_base = (const global char *) v_void + v_offset; + + const global char * mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char *) mask_void + mask_offset + + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 + + (ulong) q_idx * mask_nb1; + } + + // Share Q via local memory (n_q=1 per split -> uniform across WG). + __local ACC_TYPE4 q_shared[DK_VEC]; + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1; + const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset); + for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) { + q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]); + } + barrier(CLK_LOCAL_MEM_FENCE); + + const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); + + // Pass 1a — split-local max. + ACC_TYPE m_i = FA_M_INIT; + for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset); + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + #pragma unroll + for (int k = 0; k < DK_VEC; ++k) { + dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); + } + ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; + if (mask_base != NULL) { + const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base); + score += slope * (ACC_TYPE) mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + m_i = max(m_i, score); + } + + __local ACC_TYPE local_m[Q1_WG_SIZE]; + local_m[tid] = m_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE m_c = local_m[0]; + + // Pass 1b — softmax-weighted V accumulate. + ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + ACC_TYPE l_i = 0.0f; + + for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1; + const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset); + const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset); + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + #pragma unroll + for (int k = 0; k < DK_VEC; ++k) { + dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); + } + ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; + if (mask_base != NULL) { + const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base); + score += slope * (ACC_TYPE) mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + const ACC_TYPE p = exp(score - m_c); + l_i += p; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]); + } + } + + __local ACC_TYPE local_l[Q1_WG_SIZE]; + __local ACC_TYPE4 local_o[Q1_WG_SIZE]; + local_l[tid] = l_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_l[tid] += local_l[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE l_c = local_l[0]; + + if (tid == 0) { + rec[0] = (float) m_c; + rec[1] = (float) l_c; + } + for (int i = 0; i < DV_VEC; ++i) { + local_o[tid] = o_acc[i]; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_o[tid] += local_o[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + rec_o[i] = local_o[0]; + } + } +} + +// FD Pass 2: merge per-split partials into final O. Empty splits drop via exp(-INF)=0. +__kernel void flash_attn_f32_merge( + const global float * partial_void, + global void * o_void, + const ulong o_offset, + const int n_head, + const int n_splits, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const global void * sinks_void, + const ulong sinks_offset, + const int n_q +) { + const int lane = get_local_id(0); // 0..DV_VEC-1 + const int head_batch_idx = get_global_id(1); + const int q_idx = get_global_id(2); + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const ulong record_stride = (ulong) FA_PARTIAL_FLOATS; + const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits; + const global float * rec0 = partial_void + record_idx_0 * record_stride; + + __local ACC_TYPE m_final_shared; + __local ACC_TYPE l_final_shared; + if (lane == 0) { + ACC_TYPE m = FA_M_INIT; + for (int c = 0; c < n_splits; ++c) { + const ACC_TYPE m_c = rec0[c * record_stride + 0]; + m = max(m, m_c); + } + ACC_TYPE m_sink = 0.0f; + bool has_sink = false; + if (sinks_void != NULL) { + const global ACC_TYPE * sinks_ptr = + (const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset); + m_sink = sinks_ptr[head_idx]; + has_sink = true; + m = max(m, m_sink); + } + ACC_TYPE l = 0.0f; + for (int c = 0; c < n_splits; ++c) { + const ACC_TYPE m_c = rec0[c * record_stride + 0]; + const ACC_TYPE l_c = rec0[c * record_stride + 1]; + if (m_c > FA_M_INIT) { + l += l_c * exp(m_c - m); + } + } + if (has_sink) { + l += exp(m_sink - m); + } + m_final_shared = m; + l_final_shared = l; + } + barrier(CLK_LOCAL_MEM_FENCE); + const ACC_TYPE m_final = m_final_shared; + const ACC_TYPE l_final = l_final_shared; + const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f; + + ACC_TYPE4 o = (ACC_TYPE4)(0.0f); + for (int c = 0; c < n_splits; ++c) { + const global float * rec_c = rec0 + c * record_stride; + const ACC_TYPE m_c = rec_c[0]; + if (m_c <= FA_M_INIT) continue; + const global float4 * rec_oc = (const global float4 *) (rec_c + 2); + const ACC_TYPE scale_c = exp(m_c - m_final); + o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o); + } + o = o * l_inv; + + const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1; + global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset); + o_row[lane] = CONVERT_O_DATA4(o); +} diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_q4_0.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_q4_0.cl new file mode 100644 index 000000000000..36167ba543b5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_q4_0.cl @@ -0,0 +1,1041 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#ifdef cl_khr_integer_dot_product +#pragma OPENCL EXTENSION cl_khr_integer_dot_product : enable +#define FA_HAVE_INT_DOT 1 +#endif + +#ifdef cl_khr_subgroup_shuffle +#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable +#define HAS_SUBGROUP_SHUFFLE 1 +#elif defined(cl_qcom_subgroup_shuffle) +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +#define HAS_SUBGROUP_SHUFFLE 1 +#endif + +// Flash attention: Q=f32, K=q4_0, V=q4_0. +// Block = half d + uchar qs[16]; qs[j] low/high nibble -> elem j / j+16. +// Dequant: val[i] = d * (nibble_i - 8). dp4a path runs on raw 0..15 nibbles +// and applies the -8*sum(q) correction once per block (needs Q q_sum). + +#define ACC_TYPE float +#define ACC_TYPE4 float4 +#define Q_DATA_TYPE4 float4 +#define O_DATA_TYPE4 float4 +#define MASK_DATA_TYPE half +#define CONVERT_Q_ACC4(x) (x) +#define CONVERT_O_DATA4(x) (x) + +#define DK_VEC (DK/4) +#define DV_VEC (DV/4) +#define Q1_WG_SIZE 64 + +// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs, +// infinite operand can cause undefined behavior and miscompilation for exp. +// Therefore, a large negative value is used instead. +#define FA_M_INIT (-3.0e38f) + +#define QK4_0 32 +#define Q4_0_BLOCK_SIZE 18 + +#define DK_Q4_BLOCKS (DK / QK4_0) +#define DV_Q4_BLOCKS (DV / QK4_0) + +inline float dot_q4_0_f32(const global char * block_ptr, ACC_TYPE4 * q_slice) { + float d = vload_half(0, (const global half *)block_ptr); + const global uchar * qs = (const global uchar *)(block_ptr + 2); + + float sum = 0.0f; + // Low nibbles -> elems 0..15. + #pragma unroll + for (int g = 0; g < 4; ++g) { + float4 nv = (float4)((float)(int)(qs[g*4 + 0] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 1] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 2] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 3] & 0x0F) - 8.0f); + sum += dot(q_slice[g], nv); + } + // High nibbles -> elems 16..31. + #pragma unroll + for (int g = 0; g < 4; ++g) { + float4 nv = (float4)((float)(int)(qs[g*4 + 0] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 1] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 2] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 3] >> 4) - 8.0f); + sum += dot(q_slice[4 + g], nv); + } + return sum * d; +} + +#ifdef FA_HAVE_INT_DOT +inline uint pack_i8x4(char a, char b, char c, char d) { + return ((uint)(uchar)a) | + ((uint)(uchar)b) << 8 | + ((uint)(uchar)c) << 16 | + ((uint)(uchar)d) << 24; +} + +// Returns (qd, q_sum); q_sum feeds the -8*sum(q) bias correction. +typedef struct { + float qd; + int q_sum; +} q4_q_block_info; + +inline q4_q_block_info quant_q_block_int8_packed_q4(const ACC_TYPE4 * q_block, + uint * out_packed) { + float amax = 0.0f; + #pragma unroll + for (int i = 0; i < 8; ++i) { + float4 av = fabs(q_block[i]); + amax = fmax(amax, fmax(fmax(av.s0, av.s1), fmax(av.s2, av.s3))); + } + float qd = amax / 127.0f; + float qid = (amax > 0.0f) ? 127.0f / amax : 0.0f; + + int q_sum = 0; + #pragma unroll + for (int i = 0; i < 8; ++i) { + float4 v = q_block[i] * qid; + char a = (char)((int)round(v.s0)); + char b = (char)((int)round(v.s1)); + char c = (char)((int)round(v.s2)); + char d = (char)((int)round(v.s3)); + out_packed[i] = pack_i8x4(a, b, c, d); + q_sum += (int)a + (int)b + (int)c + (int)d; + } + q4_q_block_info info = { qd, q_sum }; + return info; +} + +// k_packed[0..3] = low nibbles (Q elems 0..15), k_packed[4..7] = high (16..31). +inline void pack_q4_0_nibbles(const global uchar * qs, uint * k_packed) { + #pragma unroll + for (int g = 0; g < 4; ++g) { + uchar b0 = qs[g*4 + 0]; + uchar b1 = qs[g*4 + 1]; + uchar b2 = qs[g*4 + 2]; + uchar b3 = qs[g*4 + 3]; + k_packed[g] = + ((uint)(b0 & 0x0F)) | + ((uint)(b1 & 0x0F)) << 8 | + ((uint)(b2 & 0x0F)) << 16 | + ((uint)(b3 & 0x0F)) << 24; + k_packed[4 + g] = + ((uint)(b0 >> 4)) | + ((uint)(b1 >> 4)) << 8 | + ((uint)(b2 >> 4)) << 16 | + ((uint)(b3 >> 4)) << 24; + } +} + +inline float dot_q4_0_int(const global char * k_block_ptr, + const uint * q_packed, + float q_d, + int q_sum) { + float kd = vload_half(0, (const global half *)k_block_ptr); + const global uchar * k_qs = (const global uchar *)(k_block_ptr + 2); + + uint k_packed[8]; + pack_q4_0_nibbles(k_qs, k_packed); + + int sum = 0; + #pragma unroll + for (int i = 0; i < 8; ++i) { + sum = dot_acc_sat_4x8packed_ss_int(q_packed[i], k_packed[i], sum); + } + // Correct raw-nibble sum: (nibble - 8) bias -> subtract 8 * q_sum. + return (float)(sum - 8 * q_sum) * q_d * kd; +} +#endif // FA_HAVE_INT_DOT + +inline void dequant_q4_0_f32(const global char * block_ptr, ACC_TYPE4 * out) { + float d = vload_half(0, (const global half *)block_ptr); + const global uchar * qs = (const global uchar *)(block_ptr + 2); + + #pragma unroll + for (int g = 0; g < 4; ++g) { + out[g] = d * (float4)((float)(int)(qs[g*4 + 0] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 1] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 2] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 3] & 0x0F) - 8.0f); + } + #pragma unroll + for (int g = 0; g < 4; ++g) { + out[4 + g] = d * (float4)((float)(int)(qs[g*4 + 0] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 1] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 2] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 3] >> 4) - 8.0f); + } +} + +// max_bias<=0 returns 1.0 so score += 1.0 * mask[k] stays a no-op multiplier. +inline float get_alibi_slope(float max_bias, int head_idx, int n_head_log2, float m0, float m1) { + if (max_bias <= 0.0f) return 1.0f; + float base = (head_idx < n_head_log2) ? m0 : m1; + int exph = (head_idx < n_head_log2) ? (head_idx + 1) : (2*(head_idx - n_head_log2) + 1); + return pow(base, (float)exph); +} + +// q1 decode: one query row per WG, threads sweep KV positions. +__kernel void flash_attn_f32_q4_0_q1( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset +) { + const int tid = get_local_id(0); + const int head_batch_idx = get_global_id(1); + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const global char* q_base = (const global char*)q_void + q_offset; + const global char* k_base = (const global char*)k_void + k_offset; + const global char* v_base = (const global char*)v_void + v_offset; + global char* o_base = (global char*)o_void + o_offset; + + const global char* mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + ACC_TYPE4 q_priv[DK_VEC]; + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; + const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset); + #pragma unroll + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); + } + +#ifdef FA_HAVE_INT_DOT + // Quantise Q once per thread: 8 uints + qd + q_sum per block. + uint q_packed[DK_Q4_BLOCKS * 8]; + float q_d_scale[DK_Q4_BLOCKS]; + int q_sum_arr[DK_Q4_BLOCKS]; + #pragma unroll + for (int b = 0; b < DK_Q4_BLOCKS; ++b) { + q4_q_block_info info = quant_q_block_int8_packed_q4(&q_priv[b * 8], &q_packed[b * 8]); + q_d_scale[b] = info.qd; + q_sum_arr[b] = info.q_sum; + } +#endif + + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); + + const global ACC_TYPE* sinks_ptr = NULL; + if (sinks_void != NULL) { + sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + } + + // One-pass online softmax (FA-2): single sweep over kv positions, + // updating per-thread (m_i, l_i, o_acc) per K. Eliminates the second + // K read of the original two-pass implementation. + ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT; + ACC_TYPE l_i = 0.0f; + ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + + for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { + const global char* k_row = k_base + batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const global char* v_row = v_base + batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1; + + ACC_TYPE score = 0.0f; + #pragma unroll + for (int b = 0; b < DK_Q4_BLOCKS; b++) { +#ifdef FA_HAVE_INT_DOT + score += dot_q4_0_int(k_row + b * Q4_0_BLOCK_SIZE, + &q_packed[b * 8], q_d_scale[b], q_sum_arr[b]); +#else + score += dot_q4_0_f32(k_row + b * Q4_0_BLOCK_SIZE, &q_priv[b * 8]); +#endif + } + score *= scale; + + if (mask_base != NULL) { + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); + score += slope * (ACC_TYPE)mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + + // Online softmax step. + const ACC_TYPE m_new = max(m_i, score); + const ACC_TYPE alpha = exp(m_i - m_new); + const ACC_TYPE p = exp(score - m_new); + + l_i = alpha * l_i + p; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] *= alpha; + + #pragma unroll + for (int b = 0; b < DV_Q4_BLOCKS; b++) { + ACC_TYPE4 v_dequant[8]; + dequant_q4_0_f32(v_row + b * Q4_0_BLOCK_SIZE, v_dequant); + #pragma unroll + for (int i = 0; i < 8; i++) { + o_acc[b * 8 + i] = mad(p, v_dequant[i], o_acc[b * 8 + i]); + } + } + + m_i = m_new; + } + + // Cross-thread reduce: max(m_i) -> m_final, rescale per-thread l_i and + // o_acc by alpha = exp(m_i_thread - m_final) before sum-reduce. + __local ACC_TYPE local_m[Q1_WG_SIZE]; + local_m[tid] = m_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE m_final = local_m[0]; + + const ACC_TYPE alpha_final = exp(m_i - m_final); + l_i *= alpha_final; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] *= alpha_final; + + __local ACC_TYPE local_l[Q1_WG_SIZE]; + __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; + local_l[tid] = l_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_l[tid] += local_l[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + + const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; + global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); + ACC_TYPE l_final = local_l[0]; + + if (sinks_ptr != NULL) { + l_final += exp(sinks_ptr[head_idx] - m_final); + } + + if (l_final > 0.0f) { + const ACC_TYPE l_inv = 1.0f / l_final; + for (int i = 0; i < DV_VEC; i++) { + local_o_comp[tid] = o_acc[i]; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv); + } + } + } else if (tid == 0) { + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f); + } +} + +// Flash-decoding split pass for q4_0 KV. Merge kernel is type-agnostic and +// shared with the f16/q8_0 FA kernels. +#define FA_PARTIAL_FLOATS (2 + DV) + +__kernel void flash_attn_f32_q4_0_q1_split( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + const float scale, + const int n_q, + const int n_kv, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void * mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3, + global float * partial_void, + const int n_splits, + const int kv_per_split +) { + const int tid = get_local_id(0); + const int head_batch_idx = get_global_id(1); + const int split_q_idx = get_global_id(2); + const int split_idx = split_q_idx % n_splits; + const int q_idx = split_q_idx / n_splits; + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const int kv_start = split_idx * kv_per_split; + const int kv_end = min(kv_start + kv_per_split, n_kv); + + const ulong record_stride = (ulong) FA_PARTIAL_FLOATS; + const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) + * n_splits + split_idx); + global float * rec = partial_void + record_idx * record_stride; + global float4 * rec_o = (global float4 *) (rec + 2); + + if (kv_start >= kv_end) { + if (tid == 0) { + rec[0] = FA_M_INIT; + rec[1] = 0.0f; + } + return; + } + + const global char * q_base = (const global char *) q_void + q_offset; + const global char * k_base = (const global char *) k_void + k_offset; + const global char * v_base = (const global char *) v_void + v_offset; + + const global char * mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char *) mask_void + mask_offset + + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 + + (ulong) q_idx * mask_nb1; + } + + ACC_TYPE4 q_priv[DK_VEC]; + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1; + const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset); + #pragma unroll + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); + } + +#ifdef FA_HAVE_INT_DOT + uint q_packed[DK_Q4_BLOCKS * 8]; + float q_d_scale[DK_Q4_BLOCKS]; + int q_sum_arr[DK_Q4_BLOCKS]; + #pragma unroll + for (int b = 0; b < DK_Q4_BLOCKS; ++b) { + q4_q_block_info info = quant_q_block_int8_packed_q4(&q_priv[b * 8], &q_packed[b * 8]); + q_d_scale[b] = info.qd; + q_sum_arr[b] = info.q_sum; + } +#endif + + const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); + + // One-pass online softmax (FA-2): single sweep over the split's K range. + ACC_TYPE m_i = FA_M_INIT; + ACC_TYPE l_i = 0.0f; + ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + + for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) { + const global char * k_row = k_base + batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const global char * v_row = v_base + batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1; + ACC_TYPE score = 0.0f; + #pragma unroll + for (int b = 0; b < DK_Q4_BLOCKS; ++b) { +#ifdef FA_HAVE_INT_DOT + score += dot_q4_0_int(k_row + b * Q4_0_BLOCK_SIZE, + &q_packed[b * 8], q_d_scale[b], q_sum_arr[b]); +#else + score += dot_q4_0_f32(k_row + b * Q4_0_BLOCK_SIZE, &q_priv[b * 8]); +#endif + } + score *= scale; + if (mask_base != NULL) { + const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base); + score += slope * (ACC_TYPE) mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + + // Online softmax step. + const ACC_TYPE m_new = max(m_i, score); + const ACC_TYPE alpha = exp(m_i - m_new); + const ACC_TYPE p = exp(score - m_new); + + l_i = alpha * l_i + p; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] *= alpha; + + #pragma unroll + for (int b = 0; b < DV_Q4_BLOCKS; ++b) { + ACC_TYPE4 v_dequant[8]; + dequant_q4_0_f32(v_row + b * Q4_0_BLOCK_SIZE, v_dequant); + #pragma unroll + for (int i = 0; i < 8; ++i) { + o_acc[b * 8 + i] = mad(p, v_dequant[i], o_acc[b * 8 + i]); + } + } + + m_i = m_new; + } + + // Cross-thread reduce: max(m_i) -> m_c, rescale per-thread l_i and o_acc + // by alpha = exp(m_i_thread - m_c) before sum-reduce. + __local ACC_TYPE local_m[Q1_WG_SIZE]; + local_m[tid] = m_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE m_c = local_m[0]; + + const ACC_TYPE alpha_final = exp(m_i - m_c); + l_i *= alpha_final; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] *= alpha_final; + + __local ACC_TYPE local_l[Q1_WG_SIZE]; + __local ACC_TYPE4 local_o[Q1_WG_SIZE]; + local_l[tid] = l_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_l[tid] += local_l[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE l_c = local_l[0]; + + if (tid == 0) { + rec[0] = (float) m_c; + rec[1] = (float) l_c; + } + for (int i = 0; i < DV_VEC; ++i) { + local_o[tid] = o_acc[i]; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_o[tid] += local_o[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + rec_o[i] = local_o[0]; + } + } +} + +// Prefill: q4_0 K/V, n_q > 1. BLOCK_M × BLOCK_N tiling. +// K in local as packed nibbles + per-block scale; V dequant -> half in local. +// Requires DK % QK4_0 == 0 and DV % QK4_0 == 0. +#define KV_DATA_TYPE4 half4 +#define CONVERT_KV_ACC4(x) convert_float4(x) + +#define DK_Q4_BLOCKS_PREFILL (DK / QK4_0) +#define DV_Q4_BLOCKS_PREFILL (DV / QK4_0) + +// N_SPLIT>1 splits DK/DV across N_SPLIT threads per query row; needs +// sub_group_shuffle_xor and DK_Q4_BLOCKS_PREFILL % N_SPLIT == 0. +#ifndef N_SPLIT +#define N_SPLIT 1 +#endif + +#if N_SPLIT > 1 +#define SPLIT_DK_VEC (DK_VEC / N_SPLIT) +#define SPLIT_DV_VEC (DV_VEC / N_SPLIT) +#define SPLIT_DK_Q4_BLOCKS (DK_Q4_BLOCKS_PREFILL / N_SPLIT) +#define WG_SIZE (BLOCK_M * N_SPLIT) +#else +#define SPLIT_DK_VEC DK_VEC +#define SPLIT_DV_VEC DV_VEC +#define SPLIT_DK_Q4_BLOCKS DK_Q4_BLOCKS_PREFILL +#define WG_SIZE BLOCK_M +#endif + +__kernel void flash_attn_f32_q4_0( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset, + // blk: per-(qblock,kvblock) class from flash_attn_blk_f16 + // (0=masked, 1=mixed, 2=unmasked). NULL disables the prepass opt. + const global void * blk_void +) { + const int tid = get_local_id(0); + const int block_q_idx = get_group_id(0); + const int head_batch_idx = get_global_id(1); + +#if N_SPLIT > 1 + const int q_lane = tid / N_SPLIT; + const int split_idx = tid % N_SPLIT; +#else + const int q_lane = tid; + const int split_idx = 0; +#endif + const int my_query_row = block_q_idx * BLOCK_M + q_lane; + const int query_valid = my_query_row < n_q; + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0; + const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0; + + const global char * q_base = (const global char *) q_void + q_offset; + const global char * k_base = (const global char *) k_void + k_offset; + const global char * v_base = (const global char *) v_void + v_offset; + global char * o_base = (global char *) o_void + o_offset; + + const global char * mask_base = NULL; + if (mask_void != NULL) { + mask_base = (const global char *) mask_void + mask_offset + + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + // BLK_PREPASS_BM may differ from this kernel's BLOCK_M; scale q-block idx. + #ifndef BLK_PREPASS_BM + #define BLK_PREPASS_BM BLOCK_M + #endif + const global char * blk_base = NULL; + int n_kv_blocks = 0; + if (blk_void != NULL) { + n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N; + const int n_q_blocks_prepass = (n_q + BLK_PREPASS_BM - 1) / BLK_PREPASS_BM; + const int prepass_q_block = (block_q_idx * BLOCK_M) / BLK_PREPASS_BM; + blk_base = (const global char *) blk_void + + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks_prepass + prepass_q_block) * n_kv_blocks; + } + + const int dk_off_vec = split_idx * SPLIT_DK_VEC; + ACC_TYPE4 q_priv[SPLIT_DK_VEC]; + if (query_valid) { + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; + const global float4 * q_ptr = (const global float4 *) (q_base + q_row_offset); + #pragma unroll + for (int i = 0; i < SPLIT_DK_VEC; ++i) { + q_priv[i] = q_ptr[dk_off_vec + i]; + } + } else { + #pragma unroll + for (int i = 0; i < SPLIT_DK_VEC; ++i) q_priv[i] = (ACC_TYPE4)(0.0f); + } + +#ifdef FA_HAVE_INT_DOT + uint q_packed_pf[SPLIT_DK_Q4_BLOCKS * 8]; + float q_d_pf[SPLIT_DK_Q4_BLOCKS]; + int q_sum_pf[SPLIT_DK_Q4_BLOCKS]; + #pragma unroll + for (int b = 0; b < SPLIT_DK_Q4_BLOCKS; ++b) { + q4_q_block_info info = quant_q_block_int8_packed_q4(&q_priv[b * 8], &q_packed_pf[b * 8]); + q_d_pf[b] = info.qd; + q_sum_pf[b] = info.q_sum; + } +#endif + + const int dv_off_vec = split_idx * SPLIT_DV_VEC; + ACC_TYPE4 o_acc[SPLIT_DV_VEC]; + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + + ACC_TYPE m_i = FA_M_INIT; + ACC_TYPE l_i = 0.0f; + + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); + +#ifdef FA_HAVE_INT_DOT + __local uint l_k_packed[BLOCK_N][DK_Q4_BLOCKS_PREFILL * 8]; + __local float l_k_scale [BLOCK_N][DK_Q4_BLOCKS_PREFILL]; +#else + __local half4 l_k[BLOCK_N][DK_VEC]; +#endif + + __local half4 l_v[BLOCK_N][DV_VEC]; + + for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) { + // Skip fully-masked KV tiles (uniform branch across WG). + char blk_cur = 1; + if (blk_base != NULL) { + blk_cur = blk_base[k_start / BLOCK_N]; + if (blk_cur == 0) continue; + } + + { +#ifdef FA_HAVE_INT_DOT + const int k_blocks_per_row = DK_Q4_BLOCKS_PREFILL; + const int n_blocks_total = BLOCK_N * k_blocks_per_row; + for (int i = tid; i < n_blocks_total; i += WG_SIZE) { + const int row = i / k_blocks_per_row; + const int blk = i % k_blocks_per_row; + const int k_row_idx = k_start + row; + if (k_row_idx < n_kv) { + const ulong k_row_off = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1; + const global char * blk_ptr = k_base + k_row_off + blk * Q4_0_BLOCK_SIZE; + const float df = (float) vload_half(0, (const global half *) blk_ptr); + const global uchar * qs = (const global uchar *)(blk_ptr + 2); + l_k_scale[row][blk] = df; + uint k_packed[8]; + pack_q4_0_nibbles(qs, k_packed); + #pragma unroll + for (int j = 0; j < 8; ++j) { + l_k_packed[row][blk * 8 + j] = k_packed[j]; + } + } else { + l_k_scale[row][blk] = 0.0f; + #pragma unroll + for (int j = 0; j < 8; ++j) l_k_packed[row][blk * 8 + j] = 0u; + } + } +#else + // Fallback: dequant q4_0 -> half in local memory. + const int k_blocks_per_row = DK_Q4_BLOCKS_PREFILL; + const int n_blocks_total = BLOCK_N * k_blocks_per_row; + for (int i = tid; i < n_blocks_total; i += WG_SIZE) { + const int row = i / k_blocks_per_row; + const int blk = i % k_blocks_per_row; + const int k_row_idx = k_start + row; + if (k_row_idx < n_kv) { + const ulong k_row_off = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1; + const global char * blk_ptr = k_base + k_row_off + blk * Q4_0_BLOCK_SIZE; + const float df = (float) vload_half(0, (const global half *) blk_ptr); + const global uchar * qs = (const global uchar *)(blk_ptr + 2); + #pragma unroll + for (int g = 0; g < 4; ++g) { + float4 vlo = df * (float4)((float)(int)(qs[g*4 + 0] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 1] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 2] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 3] & 0x0F) - 8.0f); + float4 vhi = df * (float4)((float)(int)(qs[g*4 + 0] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 1] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 2] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 3] >> 4) - 8.0f); + l_k[row][blk * 8 + g ] = (half4)((half)vlo.s0, (half)vlo.s1, (half)vlo.s2, (half)vlo.s3); + l_k[row][blk * 8 + 4 + g] = (half4)((half)vhi.s0, (half)vhi.s1, (half)vhi.s2, (half)vhi.s3); + } + } else { + #pragma unroll + for (int j = 0; j < 8; ++j) l_k[row][blk * 8 + j] = (half4)(0.0h); + } + } +#endif + } + // V tile load — dequant V -> half in local memory. + { + const int v_blocks_per_row = DV_Q4_BLOCKS_PREFILL; + const int n_blocks_total = BLOCK_N * v_blocks_per_row; + for (int i = tid; i < n_blocks_total; i += WG_SIZE) { + const int row = i / v_blocks_per_row; + const int blk = i % v_blocks_per_row; + const int v_row_idx = k_start + row; + if (v_row_idx < n_kv) { + const ulong v_row_off = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1; + const global char * blk_ptr = v_base + v_row_off + blk * Q4_0_BLOCK_SIZE; + const float df = (float) vload_half(0, (const global half *) blk_ptr); + const global uchar * qs = (const global uchar *)(blk_ptr + 2); + #pragma unroll + for (int g = 0; g < 4; ++g) { + float4 vlo = df * (float4)((float)(int)(qs[g*4 + 0] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 1] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 2] & 0x0F) - 8.0f, + (float)(int)(qs[g*4 + 3] & 0x0F) - 8.0f); + float4 vhi = df * (float4)((float)(int)(qs[g*4 + 0] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 1] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 2] >> 4) - 8.0f, + (float)(int)(qs[g*4 + 3] >> 4) - 8.0f); + l_v[row][blk * 8 + g ] = (half4)((half)vlo.s0, (half)vlo.s1, (half)vlo.s2, (half)vlo.s3); + l_v[row][blk * 8 + 4 + g] = (half4)((half)vhi.s0, (half)vhi.s1, (half)vhi.s2, (half)vhi.s3); + } + } else { + #pragma unroll + for (int j = 0; j < 8; ++j) l_v[row][blk * 8 + j] = (half4)(0.0h); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + // QK dot + online softmax. N_SPLIT>1 reduces per-thread partials via shuffle_xor. +#if N_SPLIT > 1 + { +#else + if (query_valid) { +#endif + const int k_blk_base = split_idx * SPLIT_DK_Q4_BLOCKS; + for (int j = 0; j < BLOCK_N; j += 4) { + const int k_row0 = k_start + j; + const int k_row1 = k_start + j + 1; + const int k_row2 = k_start + j + 2; + const int k_row3 = k_start + j + 3; + + ACC_TYPE s0, s1, s2, s3; +#ifdef FA_HAVE_INT_DOT + s0 = 0.0f; s1 = 0.0f; s2 = 0.0f; s3 = 0.0f; + #pragma unroll + for (int b_local = 0; b_local < SPLIT_DK_Q4_BLOCKS; ++b_local) { + const int b = k_blk_base + b_local; + int sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0; + #pragma unroll + for (int g = 0; g < 8; ++g) { + const uint qp = q_packed_pf[b_local * 8 + g]; + sum0 = dot_acc_sat_4x8packed_ss_int(qp, l_k_packed[j ][b * 8 + g], sum0); + sum1 = dot_acc_sat_4x8packed_ss_int(qp, l_k_packed[j+1][b * 8 + g], sum1); + sum2 = dot_acc_sat_4x8packed_ss_int(qp, l_k_packed[j+2][b * 8 + g], sum2); + sum3 = dot_acc_sat_4x8packed_ss_int(qp, l_k_packed[j+3][b * 8 + g], sum3); + } + const float qd = q_d_pf[b_local]; + const int q_sum = q_sum_pf[b_local]; + s0 += (float)(sum0 - 8 * q_sum) * qd * l_k_scale[j ][b]; + s1 += (float)(sum1 - 8 * q_sum) * qd * l_k_scale[j+1][b]; + s2 += (float)(sum2 - 8 * q_sum) * qd * l_k_scale[j+2][b]; + s3 += (float)(sum3 - 8 * q_sum) * qd * l_k_scale[j+3][b]; + } +#else + ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f); + #pragma unroll + for (int k = 0; k < SPLIT_DK_VEC; ++k) { + const ACC_TYPE4 qk = q_priv[k]; + const int k_abs = dk_off_vec + k; + dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j ][k_abs]), dot_acc0); + dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k_abs]), dot_acc1); + dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k_abs]), dot_acc2); + dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k_abs]), dot_acc3); + } + s0 = dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3; + s1 = dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3; + s2 = dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3; + s3 = dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3; +#endif + +#if N_SPLIT > 1 + // Power-of-2 N_SPLIT: shuffle_xor butterfly. N_SPLIT=3 (DK=96): + // explicit 3-lane shuffle. + #if (N_SPLIT & (N_SPLIT - 1)) == 0 + #pragma unroll + for (int step = 1; step < N_SPLIT; step <<= 1) { + s0 += sub_group_shuffle_xor(s0, step); + s1 += sub_group_shuffle_xor(s1, step); + s2 += sub_group_shuffle_xor(s2, step); + s3 += sub_group_shuffle_xor(s3, step); + } + #else + const uint tri_base = (get_sub_group_local_id() / N_SPLIT) * N_SPLIT; + s0 = sub_group_shuffle(s0, tri_base + 0) + sub_group_shuffle(s0, tri_base + 1) + sub_group_shuffle(s0, tri_base + 2); + s1 = sub_group_shuffle(s1, tri_base + 0) + sub_group_shuffle(s1, tri_base + 1) + sub_group_shuffle(s1, tri_base + 2); + s2 = sub_group_shuffle(s2, tri_base + 0) + sub_group_shuffle(s2, tri_base + 1) + sub_group_shuffle(s2, tri_base + 2); + s3 = sub_group_shuffle(s3, tri_base + 0) + sub_group_shuffle(s3, tri_base + 1) + sub_group_shuffle(s3, tri_base + 2); + #endif + if (!query_valid) { s0 = FA_M_INIT; s1 = FA_M_INIT; s2 = FA_M_INIT; s3 = FA_M_INIT; } +#endif + s0 *= scale; s1 *= scale; s2 *= scale; s3 *= scale; + + if (is_causal) { + const int causal_limit = n_kv - n_q + my_query_row; + if (k_row0 > causal_limit) s0 = FA_M_INIT; + if (k_row1 > causal_limit) s1 = FA_M_INIT; + if (k_row2 > causal_limit) s2 = FA_M_INIT; + if (k_row3 > causal_limit) s3 = FA_M_INIT; + } + if (k_row0 >= n_kv) s0 = FA_M_INIT; + if (k_row1 >= n_kv) s1 = FA_M_INIT; + if (k_row2 >= n_kv) s2 = FA_M_INIT; + if (k_row3 >= n_kv) s3 = FA_M_INIT; + + if (query_valid && mask_base != NULL && blk_cur != 2) { + const global MASK_DATA_TYPE * mask_ptr = + (const global MASK_DATA_TYPE *) (mask_base + my_query_row * mask_nb1); + if (k_row0 < n_kv) s0 += slope * (ACC_TYPE) mask_ptr[k_row0]; + if (k_row1 < n_kv) s1 += slope * (ACC_TYPE) mask_ptr[k_row1]; + if (k_row2 < n_kv) s2 += slope * (ACC_TYPE) mask_ptr[k_row2]; + if (k_row3 < n_kv) s3 += slope * (ACC_TYPE) mask_ptr[k_row3]; + } + if (logit_softcap > 0.0f) { + s0 = logit_softcap * tanh(s0 / logit_softcap); + s1 = logit_softcap * tanh(s1 / logit_softcap); + s2 = logit_softcap * tanh(s2 / logit_softcap); + s3 = logit_softcap * tanh(s3 / logit_softcap); + } + + const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3))); + // Whole tile masked (m_new == FA_M_INIT): force the exp() args + // far negative so the tile contributes 0, not exp(0)=1. + const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new; + const ACC_TYPE scale_prev = native_exp(m_i - m_exp); + const ACC_TYPE p0 = native_exp(s0 - m_exp); + const ACC_TYPE p1 = native_exp(s1 - m_exp); + const ACC_TYPE p2 = native_exp(s2 - m_exp); + const ACC_TYPE p3 = native_exp(s3 - m_exp); + + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + const int i_abs = dv_off_vec + i; + o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i_abs]), + mad(p2, CONVERT_KV_ACC4(l_v[j+2][i_abs]), + mad(p1, CONVERT_KV_ACC4(l_v[j+1][i_abs]), + mad(p0, CONVERT_KV_ACC4(l_v[j ][i_abs]), + o_acc[i] * scale_prev)))); + } + l_i = l_i * scale_prev + p0 + p1 + p2 + p3; + m_i = m_new; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + // Write output. + if (query_valid) { + if (sinks_void != NULL) { + const global ACC_TYPE * sinks_ptr = + (const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset); + const ACC_TYPE m_sink = sinks_ptr[head_idx]; + const ACC_TYPE m_final = max(m_i, m_sink); + const ACC_TYPE scale_o = exp(m_i - m_final); + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) o_acc[i] *= scale_o; + l_i = l_i * scale_o + exp(m_sink - m_final); + m_i = m_final; + } + const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f; + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; + global float4 * o_row = (global float4 *) (o_base + o_row_offset); + if (l_inv > 0.0f) { + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) o_row[dv_off_vec + i] = o_acc[i] * l_inv; + } else { + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) o_row[dv_off_vec + i] = (float4)(0.0f); + } + } +} + +// FD Pass 2: merge split partials. Identical across q4_0/q8_0/f16; each FA +// source owns a copy since kernels compile per-source-program. +__kernel void flash_attn_f32_merge( + const global float * partial_void, + global void * o_void, + const ulong o_offset, + const int n_head, + const int n_splits, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const global void * sinks_void, + const ulong sinks_offset, + const int n_q +) { + const int lane = get_local_id(0); + const int head_batch_idx = get_global_id(1); + const int q_idx = get_global_id(2); + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const ulong record_stride = (ulong) FA_PARTIAL_FLOATS; + const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits; + const global float * rec0 = partial_void + record_idx_0 * record_stride; + + __local ACC_TYPE m_final_shared; + __local ACC_TYPE l_final_shared; + if (lane == 0) { + ACC_TYPE m = FA_M_INIT; + for (int c = 0; c < n_splits; ++c) { + const ACC_TYPE m_c = rec0[c * record_stride + 0]; + m = max(m, m_c); + } + ACC_TYPE m_sink = 0.0f; + bool has_sink = false; + if (sinks_void != NULL) { + const global ACC_TYPE * sinks_ptr = + (const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset); + m_sink = sinks_ptr[head_idx]; + has_sink = true; + m = max(m, m_sink); + } + ACC_TYPE l = 0.0f; + for (int c = 0; c < n_splits; ++c) { + const ACC_TYPE m_c = rec0[c * record_stride + 0]; + const ACC_TYPE l_c = rec0[c * record_stride + 1]; + if (m_c > FA_M_INIT) { + l += l_c * exp(m_c - m); + } + } + if (has_sink) { + l += exp(m_sink - m); + } + m_final_shared = m; + l_final_shared = l; + } + barrier(CLK_LOCAL_MEM_FENCE); + const ACC_TYPE m_final = m_final_shared; + const ACC_TYPE l_final = l_final_shared; + const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f; + + ACC_TYPE4 o = (ACC_TYPE4)(0.0f); + for (int c = 0; c < n_splits; ++c) { + const global float * rec_c = rec0 + c * record_stride; + const ACC_TYPE m_c = rec_c[0]; + if (m_c <= FA_M_INIT) continue; + const global float4 * rec_oc = (const global float4 *) (rec_c + 2); + const ACC_TYPE scale_c = exp(m_c - m_final); + o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o); + } + o = o * l_inv; + + const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1; + global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset); + o_row[lane] = CONVERT_O_DATA4(o); +} diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_q8_0.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_q8_0.cl new file mode 100644 index 000000000000..a25823f0074e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_q8_0.cl @@ -0,0 +1,1049 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#ifdef cl_khr_integer_dot_product +#pragma OPENCL EXTENSION cl_khr_integer_dot_product : enable +#define FA_HAVE_INT_DOT 1 +#endif + +#ifdef cl_khr_subgroup_shuffle +#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable +#define HAS_SUBGROUP_SHUFFLE 1 +#elif defined(cl_qcom_subgroup_shuffle) +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +#define HAS_SUBGROUP_SHUFFLE 1 +#endif + +// Flash attention: Q=f32, K=q8_0, V=q8_0. + +#define ACC_TYPE float +#define ACC_TYPE4 float4 +#define Q_DATA_TYPE4 float4 +#define O_DATA_TYPE4 float4 +#define MASK_DATA_TYPE half +#define CONVERT_Q_ACC4(x) (x) +#define CONVERT_O_DATA4(x) (x) + +#define DK_VEC (DK/4) +#define DV_VEC (DV/4) +#define Q1_WG_SIZE 64 + +// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs, +// infinite operand can cause undefined behavior and miscompilation for exp. +// Therefore, a large negative value is used instead. +#define FA_M_INIT (-3.0e38f) + +// q8_0 block: 2B scale (half) + 32B int8 quants. +#define QK8_0 32 +#define Q8_0_BLOCK_SIZE 34 + +#define DK_Q8_BLOCKS (DK / QK8_0) +#define DV_Q8_BLOCKS (DV / QK8_0) + +inline float dot_q8_0_f32(const global char * block_ptr, ACC_TYPE4 * q_slice) { + float d = vload_half(0, (const global half *)block_ptr); + const global char * qs = block_ptr + 2; + + float sum = 0.0f; + #pragma unroll + for (int i = 0; i < 8; i++) { + float4 qv = (float4)((float)qs[i*4], (float)qs[i*4+1], (float)qs[i*4+2], (float)qs[i*4+3]); + sum += dot(q_slice[i], qv); + } + return sum * d; +} + +#ifdef FA_HAVE_INT_DOT +inline uint pack_i8x4(char a, char b, char c, char d) { + return ((uint)(uchar)a) | + ((uint)(uchar)b) << 8 | + ((uint)(uchar)c) << 16 | + ((uint)(uchar)d) << 24; +} + +inline float quant_q_block_int8_packed(const ACC_TYPE4 * q_block, + uint * out_packed) { + float amax = 0.0f; + #pragma unroll + for (int i = 0; i < 8; ++i) { + float4 av = fabs(q_block[i]); + amax = fmax(amax, fmax(fmax(av.s0, av.s1), fmax(av.s2, av.s3))); + } + float qd = amax / 127.0f; + float qid = (amax > 0.0f) ? 127.0f / amax : 0.0f; + + #pragma unroll + for (int i = 0; i < 8; ++i) { + float4 v = q_block[i] * qid; + char a = (char)((int)round(v.s0)); + char b = (char)((int)round(v.s1)); + char c = (char)((int)round(v.s2)); + char d = (char)((int)round(v.s3)); + out_packed[i] = pack_i8x4(a, b, c, d); + } + return qd; +} + +inline float dot_q8_0_int(const global char * k_block_ptr, + const uint * q_packed, + float q_d) { + float kd = vload_half(0, (const global half *)k_block_ptr); + const global uchar * k_qs = (const global uchar *)(k_block_ptr + 2); + + // k_qs is 2-byte aligned; pack chars per iteration rather than cast to uint*. + int sum = 0; + #pragma unroll + for (int i = 0; i < 8; ++i) { + uint k_packed = + (uint)k_qs[i*4 + 0] | + ((uint)k_qs[i*4 + 1]) << 8 | + ((uint)k_qs[i*4 + 2]) << 16 | + ((uint)k_qs[i*4 + 3]) << 24; + sum = dot_acc_sat_4x8packed_ss_int(q_packed[i], k_packed, sum); + } + return (float)sum * q_d * kd; +} +#endif // FA_HAVE_INT_DOT + +inline void dequant_q8_0_f32(const global char * block_ptr, ACC_TYPE4 * out) { + float d = vload_half(0, (const global half *)block_ptr); + const global char * qs = block_ptr + 2; + + #pragma unroll + for (int i = 0; i < 8; i++) { + out[i] = d * (float4)((float)qs[i*4], (float)qs[i*4+1], (float)qs[i*4+2], (float)qs[i*4+3]); + } +} + +// max_bias<=0 returns 1.0 so score += 1.0 * mask[k] stays a no-op multiplier. +inline float get_alibi_slope(float max_bias, int head_idx, int n_head_log2, float m0, float m1) { + if (max_bias <= 0.0f) return 1.0f; + float base = (head_idx < n_head_log2) ? m0 : m1; + int exph = (head_idx < n_head_log2) ? (head_idx + 1) : (2*(head_idx - n_head_log2) + 1); + return pow(base, (float)exph); +} + +// q1 decode: one query row per WG, threads sweep KV positions. +__kernel void flash_attn_f32_q8_0_q1( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset +) { + const int tid = get_local_id(0); + const int head_batch_idx = get_global_id(1); + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const global char* q_base = (const global char*)q_void + q_offset; + const global char* k_base = (const global char*)k_void + k_offset; + const global char* v_base = (const global char*)v_void + v_offset; + global char* o_base = (global char*)o_void + o_offset; + + const global char* mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + ACC_TYPE4 q_priv[DK_VEC]; + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; + const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset); + #pragma unroll + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); + } + +#ifdef FA_HAVE_INT_DOT + // Quantise Q once per thread; q_priv stays as fp for the V accumulate. + uint q_packed[DK_Q8_BLOCKS * 8]; + float q_d_scale[DK_Q8_BLOCKS]; + #pragma unroll + for (int b = 0; b < DK_Q8_BLOCKS; ++b) { + q_d_scale[b] = quant_q_block_int8_packed(&q_priv[b * 8], &q_packed[b * 8]); + } +#endif + + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); + + const global ACC_TYPE* sinks_ptr = NULL; + if (sinks_void != NULL) { + sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset); + } + + // One-pass online softmax: per-thread maintains running (m_i, l_i, o_acc), + // updating each as new K positions are processed. Eliminates the second + // K read of the original two-pass implementation. After the loop, threads + // are merged via the standard FA-2 cross-thread reduction (rescale each + // thread's l_i and o_acc by alpha=exp(m_i_thread - m_final), then sum). + ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT; + ACC_TYPE l_i = 0.0f; + ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + + for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { + const global char* k_row = k_base + batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const global char* v_row = v_base + batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1; + + ACC_TYPE score = 0.0f; + #pragma unroll + for (int b = 0; b < DK_Q8_BLOCKS; b++) { +#ifdef FA_HAVE_INT_DOT + score += dot_q8_0_int(k_row + b * Q8_0_BLOCK_SIZE, + &q_packed[b * 8], q_d_scale[b]); +#else + score += dot_q8_0_f32(k_row + b * Q8_0_BLOCK_SIZE, &q_priv[b * 8]); +#endif + } + score *= scale; + + if (mask_base != NULL) { + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); + score += slope * (ACC_TYPE)mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + + // Online softmax step. + const ACC_TYPE m_new = max(m_i, score); + const ACC_TYPE alpha = exp(m_i - m_new); + const ACC_TYPE p = exp(score - m_new); + + l_i = alpha * l_i + p; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] *= alpha; + + #pragma unroll + for (int b = 0; b < DV_Q8_BLOCKS; b++) { + ACC_TYPE4 v_dequant[8]; + dequant_q8_0_f32(v_row + b * Q8_0_BLOCK_SIZE, v_dequant); + #pragma unroll + for (int i = 0; i < 8; i++) { + o_acc[b * 8 + i] = mad(p, v_dequant[i], o_acc[b * 8 + i]); + } + } + + m_i = m_new; + } + + // Cross-thread reduce: max(m_i) -> m_final, then rescale per-thread l_i + // and o_acc by alpha = exp(m_i_thread - m_final) before sum-reduce. + __local ACC_TYPE local_m[Q1_WG_SIZE]; + local_m[tid] = m_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE m_final = local_m[0]; + + const ACC_TYPE alpha_final = exp(m_i - m_final); + l_i *= alpha_final; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] *= alpha_final; + + __local ACC_TYPE local_l[Q1_WG_SIZE]; + __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; + local_l[tid] = l_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_l[tid] += local_l[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + + const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; + global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); + ACC_TYPE l_final = local_l[0]; + + if (sinks_ptr != NULL) { + l_final += exp(sinks_ptr[head_idx] - m_final); + } + + if (l_final > 0.0f) { + const ACC_TYPE l_inv = 1.0f / l_final; + for (int i = 0; i < DV_VEC; i++) { + local_o_comp[tid] = o_acc[i]; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv); + } + } + } else if (tid == 0) { + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f); + } +} + +// Flash-decoding split pass for q8_0 KV. Partial record: [m, l, O[DV]]. +// Merge kernel from flash_attn_f32_f16.cl is type-agnostic and reused. +#define FA_PARTIAL_FLOATS (2 + DV) + +__kernel void flash_attn_f32_q8_0_q1_split( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + const float scale, + const int n_q, + const int n_kv, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void * mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3, + global float * partial_void, + const int n_splits, + const int kv_per_split +) { + const int tid = get_local_id(0); + const int head_batch_idx = get_global_id(1); + const int split_q_idx = get_global_id(2); + const int split_idx = split_q_idx % n_splits; + const int q_idx = split_q_idx / n_splits; + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const int kv_start = split_idx * kv_per_split; + const int kv_end = min(kv_start + kv_per_split, n_kv); + + const ulong record_stride = (ulong) FA_PARTIAL_FLOATS; + const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) + * n_splits + split_idx); + global float * rec = partial_void + record_idx * record_stride; + global float4 * rec_o = (global float4 *) (rec + 2); + + if (kv_start >= kv_end) { + // Empty split: leave sentinel partial for merge. + if (tid == 0) { + rec[0] = FA_M_INIT; + rec[1] = 0.0f; + } + return; + } + + const global char * q_base = (const global char *) q_void + q_offset; + const global char * k_base = (const global char *) k_void + k_offset; + const global char * v_base = (const global char *) v_void + v_offset; + + const global char * mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char *) mask_void + mask_offset + + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 + + (ulong) q_idx * mask_nb1; + } + + ACC_TYPE4 q_priv[DK_VEC]; + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1; + const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset); + #pragma unroll + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); + } + +#ifdef FA_HAVE_INT_DOT + uint q_packed[DK_Q8_BLOCKS * 8]; + float q_d_scale[DK_Q8_BLOCKS]; + #pragma unroll + for (int b = 0; b < DK_Q8_BLOCKS; ++b) { + q_d_scale[b] = quant_q_block_int8_packed(&q_priv[b * 8], &q_packed[b * 8]); + } +#endif + + const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); + + // One-pass online softmax (FA-2): single sweep over the split's K range, + // updating per-thread (m_i, l_i, o_acc) per position. Eliminates the + // second K read of the original two-pass implementation. + ACC_TYPE m_i = FA_M_INIT; + ACC_TYPE l_i = 0.0f; + ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + + for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) { + const global char * k_row = k_base + batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const global char * v_row = v_base + batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1; + ACC_TYPE score = 0.0f; + #pragma unroll + for (int b = 0; b < DK_Q8_BLOCKS; ++b) { +#ifdef FA_HAVE_INT_DOT + score += dot_q8_0_int(k_row + b * Q8_0_BLOCK_SIZE, &q_packed[b * 8], q_d_scale[b]); +#else + score += dot_q8_0_f32(k_row + b * Q8_0_BLOCK_SIZE, &q_priv[b * 8]); +#endif + } + score *= scale; + if (mask_base != NULL) { + const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base); + score += slope * (ACC_TYPE) mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + + // Online softmax step. + const ACC_TYPE m_new = max(m_i, score); + const ACC_TYPE alpha = exp(m_i - m_new); + const ACC_TYPE p = exp(score - m_new); + + l_i = alpha * l_i + p; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] *= alpha; + + #pragma unroll + for (int b = 0; b < DV_Q8_BLOCKS; ++b) { + ACC_TYPE4 v_dequant[8]; + dequant_q8_0_f32(v_row + b * Q8_0_BLOCK_SIZE, v_dequant); + #pragma unroll + for (int i = 0; i < 8; ++i) { + o_acc[b * 8 + i] = mad(p, v_dequant[i], o_acc[b * 8 + i]); + } + } + + m_i = m_new; + } + + // Cross-thread reduce: max(m_i) -> m_c, then rescale per-thread l_i and + // o_acc by alpha = exp(m_i_thread - m_c) before sum-reduce. + __local ACC_TYPE local_m[Q1_WG_SIZE]; + local_m[tid] = m_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE m_c = local_m[0]; + + const ACC_TYPE alpha_final = exp(m_i - m_c); + l_i *= alpha_final; + #pragma unroll + for (int i = 0; i < DV_VEC; ++i) o_acc[i] *= alpha_final; + + __local ACC_TYPE local_l[Q1_WG_SIZE]; + __local ACC_TYPE4 local_o[Q1_WG_SIZE]; + local_l[tid] = l_i; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_l[tid] += local_l[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE l_c = local_l[0]; + + if (tid == 0) { + rec[0] = (float) m_c; + rec[1] = (float) l_c; + } + for (int i = 0; i < DV_VEC; ++i) { + local_o[tid] = o_acc[i]; + barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_o[tid] += local_o[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + rec_o[i] = local_o[0]; + } + } +} + +// Prefill: q8_0 K/V, n_q > 1. BLOCK_M × BLOCK_N tiling. +// K path keeps packed int8 in local for dp4a QK dot; V path dequant -> half in local. +// Requires DK % QK8_0 == 0 and DV % QK8_0 == 0 (gated in supports_op). +#define KV_DATA_TYPE4 half4 +#define CONVERT_KV_ACC4(x) convert_float4(x) + +#define DK_Q8_BLOCKS_PREFILL (DK / QK8_0) +#define DV_Q8_BLOCKS_PREFILL (DV / QK8_0) + +// N_SPLIT>1 splits DK/DV across N_SPLIT threads per query row; needs +// sub_group_shuffle_xor and DK_Q8_BLOCKS_PREFILL % N_SPLIT == 0. +#ifndef N_SPLIT +#define N_SPLIT 1 +#endif + +#if N_SPLIT > 1 +#define SPLIT_DK_VEC (DK_VEC / N_SPLIT) +#define SPLIT_DV_VEC (DV_VEC / N_SPLIT) +#define SPLIT_DK_Q8_BLOCKS (DK_Q8_BLOCKS_PREFILL / N_SPLIT) +#define WG_SIZE (BLOCK_M * N_SPLIT) +#else +#define SPLIT_DK_VEC DK_VEC +#define SPLIT_DV_VEC DV_VEC +#define SPLIT_DK_Q8_BLOCKS DK_Q8_BLOCKS_PREFILL +#define WG_SIZE BLOCK_M +#endif + +// FA_V_STRATEGY: 0 = dequant V to half in local (default); 2 = keep packed +// int8 in local, dequant in the accumulate loop (smaller local, slightly slower). +#ifndef FA_V_STRATEGY +#define FA_V_STRATEGY 0 +#endif + +__kernel void flash_attn_f32_q8_0( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3, + const global void* sinks_void, + const ulong sinks_offset, + // blk: per-(qblock,kvblock) class from flash_attn_blk_f16 + // (0=masked, 1=mixed, 2=unmasked). NULL disables the prepass opt. + const global void * blk_void +) { + const int tid = get_local_id(0); + const int block_q_idx = get_group_id(0); + const int head_batch_idx = get_global_id(1); + +#if N_SPLIT > 1 + const int q_lane = tid / N_SPLIT; + const int split_idx = tid % N_SPLIT; +#else + const int q_lane = tid; + const int split_idx = 0; +#endif + const int my_query_row = block_q_idx * BLOCK_M + q_lane; + const int query_valid = my_query_row < n_q; + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0; + const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0; + + const global char * q_base = (const global char *) q_void + q_offset; + const global char * k_base = (const global char *) k_void + k_offset; + const global char * v_base = (const global char *) v_void + v_offset; + global char * o_base = (global char *) o_void + o_offset; + + const global char * mask_base = NULL; + if (mask_void != NULL) { + mask_base = (const global char *) mask_void + mask_offset + + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + // BLK_PREPASS_BM may differ from this kernel's BLOCK_M; scale q-block idx. + #ifndef BLK_PREPASS_BM + #define BLK_PREPASS_BM BLOCK_M + #endif + const global char * blk_base = NULL; + int n_kv_blocks = 0; + if (blk_void != NULL) { + n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N; + const int n_q_blocks_prepass = (n_q + BLK_PREPASS_BM - 1) / BLK_PREPASS_BM; + const int prepass_q_block = (block_q_idx * BLOCK_M) / BLK_PREPASS_BM; + blk_base = (const global char *) blk_void + + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks_prepass + prepass_q_block) * n_kv_blocks; + } + + const int dk_off_vec = split_idx * SPLIT_DK_VEC; + ACC_TYPE4 q_priv[SPLIT_DK_VEC]; + if (query_valid) { + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; + const global float4 * q_ptr = (const global float4 *) (q_base + q_row_offset); + #pragma unroll + for (int i = 0; i < SPLIT_DK_VEC; ++i) { + q_priv[i] = q_ptr[dk_off_vec + i]; + } + } else { + #pragma unroll + for (int i = 0; i < SPLIT_DK_VEC; ++i) q_priv[i] = (ACC_TYPE4)(0.0f); + } + +#ifdef FA_HAVE_INT_DOT + uint q_packed_pf[SPLIT_DK_Q8_BLOCKS * 8]; + float q_d_pf[SPLIT_DK_Q8_BLOCKS]; + #pragma unroll + for (int b = 0; b < SPLIT_DK_Q8_BLOCKS; ++b) { + q_d_pf[b] = quant_q_block_int8_packed(&q_priv[b * 8], &q_packed_pf[b * 8]); + } +#endif + + const int dv_off_vec = split_idx * SPLIT_DV_VEC; + ACC_TYPE4 o_acc[SPLIT_DV_VEC]; + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + + ACC_TYPE m_i = FA_M_INIT; + ACC_TYPE l_i = 0.0f; + + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); + +#ifdef FA_HAVE_INT_DOT + __local uint l_k_packed[BLOCK_N][DK_Q8_BLOCKS_PREFILL * 8]; + __local float l_k_scale [BLOCK_N][DK_Q8_BLOCKS_PREFILL]; +#else + __local half4 l_k[BLOCK_N][DK_VEC]; +#endif + +#if FA_V_STRATEGY == 2 + __local uint l_v_packed[BLOCK_N][DV_Q8_BLOCKS_PREFILL * 8]; + __local float l_v_scale [BLOCK_N][DV_Q8_BLOCKS_PREFILL]; +#else + __local half4 l_v[BLOCK_N][DV_VEC]; +#endif + + for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) { + // Skip fully-masked KV tiles (uniform branch across WG). + char blk_cur = 1; + if (blk_base != NULL) { + blk_cur = blk_base[k_start / BLOCK_N]; + if (blk_cur == 0) continue; + } + + { +#ifdef FA_HAVE_INT_DOT + const int k_blocks_per_row = DK_Q8_BLOCKS_PREFILL; + const int n_blocks_total = BLOCK_N * k_blocks_per_row; + for (int i = tid; i < n_blocks_total; i += WG_SIZE) { + const int row = i / k_blocks_per_row; + const int blk = i % k_blocks_per_row; + const int k_row_idx = k_start + row; + if (k_row_idx < n_kv) { + const ulong k_row_off = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1; + const global char * blk_ptr = k_base + k_row_off + blk * Q8_0_BLOCK_SIZE; + const float df = (float) vload_half(0, (const global half *) blk_ptr); + const global uchar * qs = (const global uchar *)(blk_ptr + 2); + l_k_scale[row][blk] = df; + #pragma unroll + for (int j = 0; j < 8; ++j) { + uint k_packed = + (uint) qs[j*4 + 0] | + ((uint) qs[j*4 + 1]) << 8 | + ((uint) qs[j*4 + 2]) << 16 | + ((uint) qs[j*4 + 3]) << 24; + l_k_packed[row][blk * 8 + j] = k_packed; + } + } else { + l_k_scale[row][blk] = 0.0f; + #pragma unroll + for (int j = 0; j < 8; ++j) l_k_packed[row][blk * 8 + j] = 0u; + } + } +#else + // Fallback: dequant q8_0 -> half in local memory. + const int k_blocks_per_row = DK / QK8_0; + const int n_blocks_total = BLOCK_N * k_blocks_per_row; + for (int i = tid; i < n_blocks_total; i += WG_SIZE) { + const int row = i / k_blocks_per_row; + const int blk = i % k_blocks_per_row; + const int k_row_idx = k_start + row; + if (k_row_idx < n_kv) { + const ulong k_row_off = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1; + const global char * blk_ptr = k_base + k_row_off + blk * Q8_0_BLOCK_SIZE; + const float df = (float) vload_half(0, (const global half *) blk_ptr); + const global char * qs = blk_ptr + 2; + #pragma unroll + for (int j = 0; j < 8; ++j) { + const float4 v = df * (float4)((float) qs[j*4 + 0], + (float) qs[j*4 + 1], + (float) qs[j*4 + 2], + (float) qs[j*4 + 3]); + l_k[row][blk * 8 + j] = (half4)((half) v.s0, (half) v.s1, (half) v.s2, (half) v.s3); + } + } else { + #pragma unroll + for (int j = 0; j < 8; ++j) l_k[row][blk * 8 + j] = (half4)(0.0h); + } + } +#endif + } + // V tile load — strategy-dependent. +#if FA_V_STRATEGY == 2 + { + // Int8 packed V in local memory + per-block scale. Accumulate + // step unpacks inline. + const int v_blocks_per_row = DV_Q8_BLOCKS_PREFILL; + const int n_blocks_total = BLOCK_N * v_blocks_per_row; + for (int i = tid; i < n_blocks_total; i += WG_SIZE) { + const int row = i / v_blocks_per_row; + const int blk = i % v_blocks_per_row; + const int v_row_idx = k_start + row; + if (v_row_idx < n_kv) { + const ulong v_row_off = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1; + const global char * blk_ptr = v_base + v_row_off + blk * Q8_0_BLOCK_SIZE; + const float df = (float) vload_half(0, (const global half *) blk_ptr); + const global uchar * qs = (const global uchar *)(blk_ptr + 2); + l_v_scale[row][blk] = df; + #pragma unroll + for (int j = 0; j < 8; ++j) { + uint v_packed = + (uint) qs[j*4 + 0] | + ((uint) qs[j*4 + 1]) << 8 | + ((uint) qs[j*4 + 2]) << 16 | + ((uint) qs[j*4 + 3]) << 24; + l_v_packed[row][blk * 8 + j] = v_packed; + } + } else { + l_v_scale[row][blk] = 0.0f; + #pragma unroll + for (int j = 0; j < 8; ++j) l_v_packed[row][blk * 8 + j] = 0u; + } + } + } +#else + { + // Default: dequant V -> half in local memory. + const int v_blocks_per_row = DV / QK8_0; + const int n_blocks_total = BLOCK_N * v_blocks_per_row; + for (int i = tid; i < n_blocks_total; i += WG_SIZE) { + const int row = i / v_blocks_per_row; + const int blk = i % v_blocks_per_row; + const int v_row_idx = k_start + row; + if (v_row_idx < n_kv) { + const ulong v_row_off = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1; + const global char * blk_ptr = v_base + v_row_off + blk * Q8_0_BLOCK_SIZE; + const float df = (float) vload_half(0, (const global half *) blk_ptr); + const global char * qs = blk_ptr + 2; + #pragma unroll + for (int j = 0; j < 8; ++j) { + const float4 v = df * (float4)((float) qs[j*4 + 0], + (float) qs[j*4 + 1], + (float) qs[j*4 + 2], + (float) qs[j*4 + 3]); + l_v[row][blk * 8 + j] = (half4)((half) v.s0, (half) v.s1, (half) v.s2, (half) v.s3); + } + } else { + #pragma unroll + for (int j = 0; j < 8; ++j) l_v[row][blk * 8 + j] = (half4)(0.0h); + } + } + } +#endif + barrier(CLK_LOCAL_MEM_FENCE); + + // QK dot + online softmax. N_SPLIT>1 reduces per-thread partials via shuffle_xor. +#if N_SPLIT > 1 + { +#else + if (query_valid) { +#endif + const int k_blk_base = split_idx * SPLIT_DK_Q8_BLOCKS; + for (int j = 0; j < BLOCK_N; j += 4) { + const int k_row0 = k_start + j; + const int k_row1 = k_start + j + 1; + const int k_row2 = k_start + j + 2; + const int k_row3 = k_start + j + 3; + + ACC_TYPE s0, s1, s2, s3; +#ifdef FA_HAVE_INT_DOT + // dp4a-accelerated QK dot over owned blocks. + s0 = 0.0f; s1 = 0.0f; s2 = 0.0f; s3 = 0.0f; + #pragma unroll + for (int b_local = 0; b_local < SPLIT_DK_Q8_BLOCKS; ++b_local) { + const int b = k_blk_base + b_local; + int sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0; + #pragma unroll + for (int g = 0; g < 8; ++g) { + const uint qp = q_packed_pf[b_local * 8 + g]; + sum0 = dot_acc_sat_4x8packed_ss_int(qp, l_k_packed[j ][b * 8 + g], sum0); + sum1 = dot_acc_sat_4x8packed_ss_int(qp, l_k_packed[j+1][b * 8 + g], sum1); + sum2 = dot_acc_sat_4x8packed_ss_int(qp, l_k_packed[j+2][b * 8 + g], sum2); + sum3 = dot_acc_sat_4x8packed_ss_int(qp, l_k_packed[j+3][b * 8 + g], sum3); + } + const float qd = q_d_pf[b_local]; + s0 += (float)sum0 * qd * l_k_scale[j ][b]; + s1 += (float)sum1 * qd * l_k_scale[j+1][b]; + s2 += (float)sum2 * qd * l_k_scale[j+2][b]; + s3 += (float)sum3 * qd * l_k_scale[j+3][b]; + } +#else + ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f); + #pragma unroll + for (int k = 0; k < SPLIT_DK_VEC; ++k) { + const ACC_TYPE4 qk = q_priv[k]; + const int k_abs = dk_off_vec + k; + dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j ][k_abs]), dot_acc0); + dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k_abs]), dot_acc1); + dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k_abs]), dot_acc2); + dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k_abs]), dot_acc3); + } + s0 = dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3; + s1 = dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3; + s2 = dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3; + s3 = dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3; +#endif + +#if N_SPLIT > 1 + // Power-of-2 N_SPLIT: shuffle_xor butterfly. N_SPLIT=3 (DK=96): 3-way shuffle. + #if (N_SPLIT & (N_SPLIT - 1)) == 0 + #pragma unroll + for (int step = 1; step < N_SPLIT; step <<= 1) { + s0 += sub_group_shuffle_xor(s0, step); + s1 += sub_group_shuffle_xor(s1, step); + s2 += sub_group_shuffle_xor(s2, step); + s3 += sub_group_shuffle_xor(s3, step); + } + #else + const uint tri_base = (get_sub_group_local_id() / N_SPLIT) * N_SPLIT; + s0 = sub_group_shuffle(s0, tri_base + 0) + sub_group_shuffle(s0, tri_base + 1) + sub_group_shuffle(s0, tri_base + 2); + s1 = sub_group_shuffle(s1, tri_base + 0) + sub_group_shuffle(s1, tri_base + 1) + sub_group_shuffle(s1, tri_base + 2); + s2 = sub_group_shuffle(s2, tri_base + 0) + sub_group_shuffle(s2, tri_base + 1) + sub_group_shuffle(s2, tri_base + 2); + s3 = sub_group_shuffle(s3, tri_base + 0) + sub_group_shuffle(s3, tri_base + 1) + sub_group_shuffle(s3, tri_base + 2); + #endif + if (!query_valid) { s0 = FA_M_INIT; s1 = FA_M_INIT; s2 = FA_M_INIT; s3 = FA_M_INIT; } +#endif + s0 *= scale; s1 *= scale; s2 *= scale; s3 *= scale; + + if (is_causal) { + const int causal_limit = n_kv - n_q + my_query_row; + if (k_row0 > causal_limit) s0 = FA_M_INIT; + if (k_row1 > causal_limit) s1 = FA_M_INIT; + if (k_row2 > causal_limit) s2 = FA_M_INIT; + if (k_row3 > causal_limit) s3 = FA_M_INIT; + } + if (k_row0 >= n_kv) s0 = FA_M_INIT; + if (k_row1 >= n_kv) s1 = FA_M_INIT; + if (k_row2 >= n_kv) s2 = FA_M_INIT; + if (k_row3 >= n_kv) s3 = FA_M_INIT; + + if (query_valid && mask_base != NULL && blk_cur != 2) { + const global MASK_DATA_TYPE * mask_ptr = + (const global MASK_DATA_TYPE *) (mask_base + my_query_row * mask_nb1); + if (k_row0 < n_kv) s0 += slope * (ACC_TYPE) mask_ptr[k_row0]; + if (k_row1 < n_kv) s1 += slope * (ACC_TYPE) mask_ptr[k_row1]; + if (k_row2 < n_kv) s2 += slope * (ACC_TYPE) mask_ptr[k_row2]; + if (k_row3 < n_kv) s3 += slope * (ACC_TYPE) mask_ptr[k_row3]; + } + if (logit_softcap > 0.0f) { + s0 = logit_softcap * tanh(s0 / logit_softcap); + s1 = logit_softcap * tanh(s1 / logit_softcap); + s2 = logit_softcap * tanh(s2 / logit_softcap); + s3 = logit_softcap * tanh(s3 / logit_softcap); + } + + const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3))); + // Whole tile masked (m_new == FA_M_INIT): force the exp() args + // far negative so the tile contributes 0, not exp(0)=1. + const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new; + const ACC_TYPE scale_prev = native_exp(m_i - m_exp); + const ACC_TYPE p0 = native_exp(s0 - m_exp); + const ACC_TYPE p1 = native_exp(s1 - m_exp); + const ACC_TYPE p2 = native_exp(s2 - m_exp); + const ACC_TYPE p3 = native_exp(s3 - m_exp); + +#if FA_V_STRATEGY == 2 + #pragma unroll + for (int b_local = 0; b_local < DV_Q8_BLOCKS_PREFILL / N_SPLIT; ++b_local) { + const int b_abs = split_idx * (DV_Q8_BLOCKS_PREFILL / N_SPLIT) + b_local; + const float d0 = l_v_scale[j ][b_abs]; + const float d1 = l_v_scale[j+1][b_abs]; + const float d2 = l_v_scale[j+2][b_abs]; + const float d3 = l_v_scale[j+3][b_abs]; + #pragma unroll + for (int g = 0; g < 8; ++g) { + const int lane_abs = b_abs * 8 + g; + const int lane_local = b_local * 8 + g; + uint pk0 = l_v_packed[j ][lane_abs]; + uint pk1 = l_v_packed[j+1][lane_abs]; + uint pk2 = l_v_packed[j+2][lane_abs]; + uint pk3 = l_v_packed[j+3][lane_abs]; + float4 v0 = d0 * (float4)((float)(char)(pk0 & 0xff), (float)(char)((pk0>>8)&0xff), (float)(char)((pk0>>16)&0xff), (float)(char)((pk0>>24)&0xff)); + float4 v1 = d1 * (float4)((float)(char)(pk1 & 0xff), (float)(char)((pk1>>8)&0xff), (float)(char)((pk1>>16)&0xff), (float)(char)((pk1>>24)&0xff)); + float4 v2 = d2 * (float4)((float)(char)(pk2 & 0xff), (float)(char)((pk2>>8)&0xff), (float)(char)((pk2>>16)&0xff), (float)(char)((pk2>>24)&0xff)); + float4 v3 = d3 * (float4)((float)(char)(pk3 & 0xff), (float)(char)((pk3>>8)&0xff), (float)(char)((pk3>>16)&0xff), (float)(char)((pk3>>24)&0xff)); + o_acc[lane_local] = mad(p3, v3, + mad(p2, v2, + mad(p1, v1, + mad(p0, v0, + o_acc[lane_local] * scale_prev)))); + } + } +#else // FA_V_STRATEGY == 0 + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) { + const int i_abs = dv_off_vec + i; + o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i_abs]), + mad(p2, CONVERT_KV_ACC4(l_v[j+2][i_abs]), + mad(p1, CONVERT_KV_ACC4(l_v[j+1][i_abs]), + mad(p0, CONVERT_KV_ACC4(l_v[j ][i_abs]), + o_acc[i] * scale_prev)))); + } +#endif + l_i = l_i * scale_prev + p0 + p1 + p2 + p3; + m_i = m_new; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + // Write output. With N_SPLIT>1 each thread writes its SPLIT_DV_VEC slice. + if (query_valid) { + if (sinks_void != NULL) { + const global ACC_TYPE * sinks_ptr = + (const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset); + const ACC_TYPE m_sink = sinks_ptr[head_idx]; + const ACC_TYPE m_final = max(m_i, m_sink); + const ACC_TYPE scale_o = exp(m_i - m_final); + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) o_acc[i] *= scale_o; + l_i = l_i * scale_o + exp(m_sink - m_final); + m_i = m_final; + } + const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f; + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; + global float4 * o_row = (global float4 *) (o_base + o_row_offset); + if (l_inv > 0.0f) { + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) o_row[dv_off_vec + i] = o_acc[i] * l_inv; + } else { + #pragma unroll + for (int i = 0; i < SPLIT_DV_VEC; ++i) o_row[dv_off_vec + i] = (float4)(0.0f); + } + } +} + +// FD Pass 2: merge split partials. Identical across q4_0/q8_0/f16; each FA +// source owns a copy since kernels compile per-source-program. +__kernel void flash_attn_f32_merge( + const global float * partial_void, + global void * o_void, + const ulong o_offset, + const int n_head, + const int n_splits, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const global void * sinks_void, + const ulong sinks_offset, + const int n_q +) { + const int lane = get_local_id(0); + const int head_batch_idx = get_global_id(1); + const int q_idx = get_global_id(2); + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const ulong record_stride = (ulong) FA_PARTIAL_FLOATS; + const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits; + const global float * rec0 = partial_void + record_idx_0 * record_stride; + + __local ACC_TYPE m_final_shared; + __local ACC_TYPE l_final_shared; + if (lane == 0) { + ACC_TYPE m = FA_M_INIT; + for (int c = 0; c < n_splits; ++c) { + const ACC_TYPE m_c = rec0[c * record_stride + 0]; + m = max(m, m_c); + } + ACC_TYPE m_sink = 0.0f; + bool has_sink = false; + if (sinks_void != NULL) { + const global ACC_TYPE * sinks_ptr = + (const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset); + m_sink = sinks_ptr[head_idx]; + has_sink = true; + m = max(m, m_sink); + } + ACC_TYPE l = 0.0f; + for (int c = 0; c < n_splits; ++c) { + const ACC_TYPE m_c = rec0[c * record_stride + 0]; + const ACC_TYPE l_c = rec0[c * record_stride + 1]; + if (m_c > FA_M_INIT) { + l += l_c * exp(m_c - m); + } + } + if (has_sink) { + l += exp(m_sink - m); + } + m_final_shared = m; + l_final_shared = l; + } + barrier(CLK_LOCAL_MEM_FENCE); + const ACC_TYPE m_final = m_final_shared; + const ACC_TYPE l_final = l_final_shared; + const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f; + + ACC_TYPE4 o = (ACC_TYPE4)(0.0f); + for (int c = 0; c < n_splits; ++c) { + const global float * rec_c = rec0 + c * record_stride; + const ACC_TYPE m_c = rec_c[0]; + if (m_c <= FA_M_INIT) continue; + const global float4 * rec_oc = (const global float4 *) (rec_c + 2); + const ACC_TYPE scale_c = exp(m_c - m_final); + o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o); + } + o = o * l_inv; + + const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1; + global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset); + o_row[lane] = CONVERT_O_DATA4(o); +} diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_pre_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_pre_f16.cl new file mode 100644 index 000000000000..88ead4bcb513 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/flash_attn_pre_f16.cl @@ -0,0 +1,156 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void flash_attn_kv_pad_f16( + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * k_pad_void, + global void * v_pad_void, + const int n_kv, + const int n_head_kv, + const int n_batch, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3 +) { + const int row_idx = get_global_id(0); + const int head_kv_idx = get_global_id(1); + const int batch_idx = get_global_id(2); + + if (row_idx >= BLOCK_N || head_kv_idx >= n_head_kv || batch_idx >= n_batch) { + return; + } + + const int tail_start = n_kv - (n_kv % BLOCK_N); + const int src_row_idx = tail_start + row_idx; + + const global char * k_src = (const global char *) k_void + k_offset; + const global char * v_src = (const global char *) v_void + v_offset; + global char * k_pad = (global char *) k_pad_void; + global char * v_pad = (global char *) v_pad_void; + + const ulong k_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * k_nb1) + (ulong) row_idx * k_nb1; + const ulong v_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * v_nb1) + (ulong) row_idx * v_nb1; + + if (src_row_idx < n_kv) { + const ulong k_src_offset = (ulong) batch_idx * k_nb3 + (ulong) head_kv_idx * k_nb2 + (ulong) src_row_idx * k_nb1; + const ulong v_src_offset = (ulong) batch_idx * v_nb3 + (ulong) head_kv_idx * v_nb2 + (ulong) src_row_idx * v_nb1; + + for (ulong i = 0; i < k_nb1; ++i) { + k_pad[k_dst_offset + i] = k_src[k_src_offset + i]; + } + for (ulong i = 0; i < v_nb1; ++i) { + v_pad[v_dst_offset + i] = v_src[v_src_offset + i]; + } + } else { + for (ulong i = 0; i < k_nb1; ++i) { + k_pad[k_dst_offset + i] = 0; + } + for (ulong i = 0; i < v_nb1; ++i) { + v_pad[v_dst_offset + i] = 0; + } + } +} + +__kernel void flash_attn_mask_pad_f16( + const global void * mask_void, ulong mask_offset, + global void * mask_pad_void, + const int n_q, + const int n_kv, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3 +) { + const int col_idx = get_global_id(0); + const int q_row = get_global_id(1); + const int mask_slice = get_global_id(2); + + if (col_idx >= BLOCK_N || q_row >= n_q || mask_slice >= mask_ne2 * mask_ne3) { + return; + } + + const int tail_start = n_kv - (n_kv % BLOCK_N); + const int src_col_idx = tail_start + col_idx; + const int mask_head_idx = mask_slice % mask_ne2; + const int mask_batch_idx = mask_slice / mask_ne2; + + const global char * mask_src_base = (const global char *) mask_void + mask_offset + + (ulong) mask_batch_idx * mask_nb3 + + (ulong) mask_head_idx * mask_nb2 + + (ulong) q_row * mask_nb1; + const global half * mask_src = (const global half *) mask_src_base; + + global half * mask_pad = (global half *) mask_pad_void; + const ulong dst_idx = + (((ulong) mask_batch_idx * (ulong) mask_ne2 + (ulong) mask_head_idx) * (ulong) n_q + (ulong) q_row) * (ulong) BLOCK_N + + (ulong) col_idx; + + mask_pad[dst_idx] = src_col_idx < n_kv ? mask_src[src_col_idx] : (half) (-INFINITY); +} + +// Per-KV-tile mask class. 0=all -inf (skip tile), 1=mixed (apply mask), +// 2=all zero, no -inf (skip mask lookup). Causal diagonal tiles are class 1. +__kernel void flash_attn_blk_f16( + const global void * mask_void, ulong mask_offset, + global char * blk, + const int n_q, + const int n_kv, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3 +) { + const int kv_block_idx = get_global_id(0); + const int q_block_idx = get_global_id(1); + const int mask_slice = get_global_id(2); + + const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M; + const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N; + if (kv_block_idx >= n_kv_blocks || q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) { + return; + } + + const int mask_head_idx = mask_slice % mask_ne2; + const int mask_batch_idx = mask_slice / mask_ne2; + const int q_start = q_block_idx * BLOCK_M; + const int k_start = kv_block_idx * BLOCK_N; + const int q_count = min(BLOCK_M, n_q - q_start); + const int k_count = min(BLOCK_N, n_kv - k_start); + + const half neg_max_half = (half) (-65504.0f); + char has_unmasked = 0; + char has_masked = 0; + char has_nonzero = 0; + + const global char * mask_base = (const global char *) mask_void + mask_offset + + (ulong) mask_batch_idx * mask_nb3 + + (ulong) mask_head_idx * mask_nb2; + + for (int qi = 0; qi < q_count; ++qi) { + const global half * mask_row = (const global half *) (mask_base + (ulong) (q_start + qi) * mask_nb1) + k_start; + for (int ki = 0; ki < k_count; ++ki) { + const half v = mask_row[ki]; + if (v <= neg_max_half) { + has_masked = 1; + } else { + has_unmasked = 1; + if (v != (half) 0.0f) { + has_nonzero = 1; + } + } + } + if (has_masked && has_unmasked) break; // mixed tile — short-circuit. + } + + char res; + if (has_unmasked == 0) { + res = 0; + } else if (has_masked || has_nonzero) { + res = 1; + } else { + res = 2; + } + + blk[((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx] = res; +} diff --git a/ggml/src/ggml-opencl/kernels/set_rows.cl b/ggml/src/ggml-opencl/kernels/set_rows.cl index fc3ff7aa1e72..4ad5af13f138 100644 --- a/ggml/src/ggml-opencl/kernels/set_rows.cl +++ b/ggml/src/ggml-opencl/kernels/set_rows.cl @@ -158,6 +158,239 @@ kernel void kernel_set_rows_f32_i32( } } +// f32 -> q8_0 quantize set_rows. Block = half d + char qs[32]. +#define QK8_0 32 + +inline void quantize_q8_0_block(global float * x, global char * qs, global half * d_out) { + float amax = 0.0f; + for (int j = 0; j < QK8_0; j++) { + amax = fmax(amax, fabs(x[j])); + } + + float d = amax / 127.0f; + float id = (d != 0.0f) ? 127.0f / amax : 0.0f; + + vstore_half(d, 0, d_out); + + for (int j = 0; j < QK8_0; j++) { + qs[j] = (char)((int)round(x[j] * id)); + } +} + +kernel void kernel_set_rows_q8_0_i64( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + uint4 ne11, + uint4 ne12, + ulong nb10, + ulong nb11, + ulong nb12, + int nblk0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1); + + if (i01 >= ne01) { + return; + } + + int i12 = fastmod(i03, ne12); + int i11 = fastmod(i02, ne11); + + int i10 = i01; + long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; + + global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3); + global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03); + + for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) { + global float * x = src_row + blk * QK8_0; + global char * y = dst_row + blk * (2 + QK8_0); + + quantize_q8_0_block(x, y + 2, (global half *)y); + } +} + +kernel void kernel_set_rows_q8_0_i32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + uint4 ne11, + uint4 ne12, + ulong nb10, + ulong nb11, + ulong nb12, + int nblk0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1); + + if (i01 >= ne01) { + return; + } + + int i12 = fastmod(i03, ne12); + int i11 = fastmod(i02, ne11); + + int i10 = i01; + int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; + + global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3); + global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03); + + for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) { + global float * x = src_row + blk * QK8_0; + global char * y = dst_row + blk * (2 + QK8_0); + + quantize_q8_0_block(x, y + 2, (global half *)y); + } +} + +// SoA q8_0 variants. dst_q: int8[QK8_0] per block; dst_d: fp16 scale per block. +// Layout matches kernel_convert_block_q8_0; block index follows dst element order. +kernel void kernel_set_rows_q8_0_soa_i64( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst_q, + ulong offset_q, + global char * dst_d, + ulong offset_d, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + uint4 ne11, + uint4 ne12, + ulong nb10, + ulong nb11, + ulong nb12, + int nblk0, + int ne1_dst, + int ne2_dst, + int ne3_dst +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst_q = dst_q + offset_q; + dst_d = dst_d + offset_d; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1); + + if (i01 >= ne01) { + return; + } + + int i12 = fastmod(i03, ne12); + int i11 = fastmod(i02, ne11); + + int i10 = i01; + long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; + + long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0; + + global half * d_row = (global half *)(dst_d) + row_blk_base; + global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0; + global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + + for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) { + global float * x = src_row + blk * QK8_0; + global char * q = q_row + blk * QK8_0; + + quantize_q8_0_block(x, q, d_row + blk); + } +} + +kernel void kernel_set_rows_q8_0_soa_i32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst_q, + ulong offset_q, + global char * dst_d, + ulong offset_d, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + uint4 ne11, + uint4 ne12, + ulong nb10, + ulong nb11, + ulong nb12, + int nblk0, + int ne1_dst, + int ne2_dst, + int ne3_dst +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst_q = dst_q + offset_q; + dst_d = dst_d + offset_d; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1); + + if (i01 >= ne01) { + return; + } + + int i12 = fastmod(i03, ne12); + int i11 = fastmod(i02, ne11); + + int i10 = i01; + int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; + + long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0; + + global half * d_row = (global half *)(dst_d) + row_blk_base; + global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0; + global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + + for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) { + global float * x = src_row + blk * QK8_0; + global char * q = q_row + blk * QK8_0; + + quantize_q8_0_block(x, q, d_row + blk); + } +} + kernel void kernel_set_rows_f16_i32( global char * src0, ulong offset0, @@ -206,3 +439,270 @@ kernel void kernel_set_rows_f16_i32( dst_row[ind] = src_row[ind]; } } + +// f32 -> q4_0 quantize set_rows. Block = half d + uchar qs[16] (shuffled +// nibbles: qs[j] low/high = elem j / j+16). +// Dequant: val[i] = d * (nibble_i - 8) +// nblk0 = number of q4_0 blocks per row = ne00 / 32. +#define QK4_0 32 +#define Q4_0_BLOCK_SIZE 18 + +inline void quantize_q4_0_block(global float * x, global uchar * qs, global half * d_out) { + // Find the signed value with the largest absolute magnitude (matches ggml ref). + float max = 0.0f; + float amax = 0.0f; + for (int j = 0; j < QK4_0; j++) { + float v = x[j]; + float a = fabs(v); + if (a > amax) { + amax = a; + max = v; + } + } + + float d = max / -8.0f; + float id = (d != 0.0f) ? 1.0f / d : 0.0f; + + vstore_half(d, 0, d_out); + + for (int j = 0; j < QK4_0/2; j++) { + float x0 = x[j] * id; + float x1 = x[j + QK4_0/2] * id; + + int i0 = (int)(x0 + 8.5f); + int i1 = (int)(x1 + 8.5f); + if (i0 < 0) i0 = 0; + if (i0 > 15) i0 = 15; + if (i1 < 0) i1 = 0; + if (i1 > 15) i1 = 15; + + qs[j] = (uchar)i0 | ((uchar)i1 << 4); + } +} + +kernel void kernel_set_rows_q4_0_i64( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + uint4 ne11, + uint4 ne12, + ulong nb10, + ulong nb11, + ulong nb12, + int nblk0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1); + + if (i01 >= ne01) { + return; + } + + int i12 = fastmod(i03, ne12); + int i11 = fastmod(i02, ne11); + + int i10 = i01; + long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; + + global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3); + global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03); + + for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) { + global float * x = src_row + blk * QK4_0; + global char * y = dst_row + blk * Q4_0_BLOCK_SIZE; + global half * yd = (global half *)(y); + global uchar * yqs = (global uchar *)(y + 2); + + quantize_q4_0_block(x, yqs, yd); + } +} + +kernel void kernel_set_rows_q4_0_i32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + uint4 ne11, + uint4 ne12, + ulong nb10, + ulong nb11, + ulong nb12, + int nblk0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1); + + if (i01 >= ne01) { + return; + } + + int i12 = fastmod(i03, ne12); + int i11 = fastmod(i02, ne11); + + int i10 = i01; + int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; + + global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3); + global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03); + + for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) { + global float * x = src_row + blk * QK4_0; + global char * y = dst_row + blk * Q4_0_BLOCK_SIZE; + global half * yd = (global half *)(y); + global uchar * yqs = (global uchar *)(y + 2); + + quantize_q4_0_block(x, yqs, yd); + } +} + +// SoA variants for q4_0 dst. Used when the backend has split block_q4_0 records +// into separate quant (dst_q) and scale (dst_d) sub-buffers — same pattern as +// the q8_0 SoA variants above. +// +// Layout (matches kernel_convert_block_q4_0, the "shuffled" variant): +// dst_q: contiguous 16 packed nibbles per block, block i at offset i * 16 bytes. +// dst_d: contiguous fp16 scales, block i at offset i * 2 bytes. +// Nibble layout inside each byte is unchanged from AoS: qs[j] low nibble = element j, +// qs[j] high nibble = element j+16. kernel_restore_block_q4_0 copies bytes as-is. +kernel void kernel_set_rows_q4_0_soa_i64( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst_q, + ulong offset_q, + global char * dst_d, + ulong offset_d, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + uint4 ne11, + uint4 ne12, + ulong nb10, + ulong nb11, + ulong nb12, + int nblk0, + int ne1_dst, + int ne2_dst, + int ne3_dst +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst_q = dst_q + offset_q; + dst_d = dst_d + offset_d; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1); + + if (i01 >= ne01) { + return; + } + + int i12 = fastmod(i03, ne12); + int i11 = fastmod(i02, ne11); + + int i10 = i01; + long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; + + long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0; + + global half * d_row = (global half *)(dst_d) + row_blk_base; + global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2); + global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + + for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) { + global float * x = src_row + blk * QK4_0; + global uchar * qs = q_row + blk * (QK4_0/2); + global half * d_bk = d_row + blk; + + quantize_q4_0_block(x, qs, d_bk); + } +} + +kernel void kernel_set_rows_q4_0_soa_i32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst_q, + ulong offset_q, + global char * dst_d, + ulong offset_d, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + uint4 ne11, + uint4 ne12, + ulong nb10, + ulong nb11, + ulong nb12, + int nblk0, + int ne1_dst, + int ne2_dst, + int ne3_dst +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst_q = dst_q + offset_q; + dst_d = dst_d + offset_d; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1); + + if (i01 >= ne01) { + return; + } + + int i12 = fastmod(i03, ne12); + int i11 = fastmod(i02, ne11); + + int i10 = i01; + int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; + + long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0; + + global half * d_row = (global half *)(dst_d) + row_blk_base; + global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2); + global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + + for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) { + global float * x = src_row + blk * QK4_0; + global uchar * qs = q_row + blk * (QK4_0/2); + global half * d_bk = d_row + blk; + + quantize_q4_0_block(x, qs, d_bk); + } +}