From 9e3f81e16a1ca1fb4f377ac1ab1f2d78a135b4be Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Fri, 19 Jun 2026 22:28:38 +0200 Subject: [PATCH 01/30] mtmd, arg: fix utf8 handling on windows (llama/24779) * mtmd, arg: fix utf8 handling on windows * also fix ggml_fopen * fix build fail * also fix CLI --- ggml/src/ggml.c | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b43016c87d2..0f682fd1856 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -600,18 +600,15 @@ FILE * ggml_fopen(const char * fname, const char * mode) { // convert fname (UTF-8) wchar_t * wfname = ggml_mbstowcs(fname); if (wfname) { - // convert mode (ANSI) - wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t)); - wchar_t * wmode_p = wmode; - do { - *wmode_p++ = (wchar_t)*mode; - } while (*mode++); - - // open file - file = _wfopen(wfname, wmode); + // convert mode (UTF-8) + wchar_t * wmode = ggml_mbstowcs(mode); + if (wmode) { + // open file + file = _wfopen(wfname, wmode); + GGML_FREE(wmode); + } GGML_FREE(wfname); - GGML_FREE(wmode); } return file; From 7163f689d92dcb0264b574f48d46fad61a1ae471 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Sat, 20 Jun 2026 08:12:32 +0900 Subject: [PATCH 02/30] ggml-webgpu: add adapter toggles for F16 on Vulkan + NVIDIA --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 42 +++++++++++----------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 0b605fa86ba..f71d1aee73a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3788,7 +3788,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { +static void 
ggml_backend_webgpu_request_adapter(wgpu::Instance & instance, wgpu::Adapter & adapter) { wgpu::RequestAdapterOptions options = {}; #ifndef __EMSCRIPTEN__ @@ -3800,17 +3800,20 @@ static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { options.nextInChain = &adapterTogglesDesc; #endif - ctx->webgpu_global_ctx->instance.WaitAny( - ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - ctx->webgpu_global_ctx->adapter = std::move(adapter); - }), - UINT64_MAX); + instance.WaitAny(instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + adapter = std::move(_adapter); + }), + UINT64_MAX); +} + +static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { + ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, ctx->webgpu_global_ctx->adapter); GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr); ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits); @@ -4543,20 +4546,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { // Probe for adapter support wgpu::Adapter adapter; if (ctx->webgpu_global_ctx->instance != nullptr) { - wgpu::RequestAdapterOptions options = {}; - - // probe for adapter support - ctx->webgpu_global_ctx->instance.WaitAny( - ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { - if (status != 
wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - adapter = std::move(_adapter); - }), - UINT64_MAX); + ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, adapter); } // WebGPU backend requires f16 support and, on native, implicit device synchronization. From a32ed1c15d62082f3cd34fb0c6f232bf9e32e946 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Sat, 20 Jun 2026 12:43:06 +0200 Subject: [PATCH 03/30] ggml : optimize AMX (llama/24806) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flatten the partition over n_batch * M so every thread participates in the quantization | CPU | Model | Test | t/s OLD | t/s NEW | Speedup | |:--------------------------------|:------------------------------|:-------|----------:|----------:|----------:| | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_NL - 4.5 bpw | pp512 | 730.71 | 779.86 | 1.07 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_NL - 4.5 bpw | tg128 | 87.88 | 86.79 | 0.99 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_XS - 4.25 bpw | pp512 | 725.09 | 1023.31 | 1.41 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_XS - 4.25 bpw | tg128 | 83.64 | 83.62 | 1.00 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_0 | pp512 | 820.51 | 924.05 | 1.13 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_0 | tg128 | 90.59 | 92.46 | 1.02 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_1 | pp512 | 776.88 | 872.79 | 1.12 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_1 | tg128 | 89.39 | 90.94 | 1.02 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_M | pp512 | 719.28 | 1009.27 | 1.40 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_M | tg128 | 80.62 | 80.86 | 1.00 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_S | pp512 | 732.29 | 1077.29 | 1.47 | | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_S | tg128 | 86.42 | 83.53 
| 0.97 | Signed-off-by: Adrien Gallouët --- ggml/src/ggml-cpu/amx/mmq.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index d9383a04be8..9f3a744b5de 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); - parallel_for_ggml(params, n_batch, [&](int begin, int end) { - for (int batch_idx = begin; batch_idx < end; ++batch_idx) { + parallel_for_ggml(params, n_batch * M, [&](int begin, int end) { + for (int idx = begin; idx < end; ++idx) { + int batch_idx = idx / M; + int m = idx % M; int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); const float * A_data = (const float *)((const char *)src1->data + src1_offset); char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A; - - for (int m = 0; m < M; ++m) { - from_float(A_data + m * K, wdata_batch + m * row_size_A, K); - } + from_float(A_data + m * K, wdata_batch + m * row_size_A, K); } }); }); From 44aff5f606858244855997da494c25c654807d84 Mon Sep 17 00:00:00 2001 From: Guanhuai Zhang <67999475+BiReRa@users.noreply.github.com> Date: Sun, 21 Jun 2026 05:58:49 +0800 Subject: [PATCH 04/30] fix(hexagon): use padded stride for ssm-conv weights (llama/24470) --- ggml/src/ggml-hexagon/htp/ssm-conv.c | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index d574da2e2bc..a48bc9ed86b 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { // transposed into VTCM. 
// // VTCM layouts (per thread): -// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small). -// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile. +// src1_T : {d_inner_stride, d_conv} - staged once per launch (small). +// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile. // // d_inner_tile is chosen so that per-thread VTCM stays under the budget. // Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. #define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread -// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM) static inline void transpose_src1(const float * src1_data, uint32_t src1_stride_inner, uint32_t i1_off, uint32_t d_inner_per_thread, + uint32_t d_inner_stride, uint32_t d_conv, float * src1_T) { for (uint32_t i = 0; i < d_inner_per_thread; ++i) { const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; for (uint32_t j = 0; j < d_conv; ++j) { - src1_T[j * d_inner_per_thread + i] = src_row[j]; + src1_T[j * d_inner_stride + i] = src_row[j]; } } } @@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void } const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_stride = scctx->nrows_per_thread; const uint32_t d_inner_tile = scctx->d_inner_tile; const float * src0_data = (const float *) src0->data; @@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); - // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. - transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); + // Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout. 
+ transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T); const uint32_t C_TILE = VLEN_FP32; @@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc = hvx_vec_splat_f32(0.0f); for (uint32_t j = 0; j < d_conv; ++j) { HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); - HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb); acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); } HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); @@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { use_hvx = 1; } - scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; - scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); + scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32); const uint32_t d_inner_per_thread = scctx.nrows_per_thread; const uint32_t ncs = src0->ne[0]; From bf0da8c019ad30486692af7e5bc5fa8fceb5847f Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Mon, 22 Jun 2026 19:09:02 +0800 Subject: [PATCH 05/30] support bf16 on bin_bcast OP and unary OPs (llama/24838) * support bf16 on bin_bcast OP and unary OPs * support the older Intel compiler than 2026.0 --- ggml/src/ggml-sycl/binbcast.cpp | 5 + ggml/src/ggml-sycl/element_wise.cpp | 208 +++++++++++++++++++++------- 2 files changed, 160 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index ad2e6ca35e5..306eeddc0c0 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, 
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) { + op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data, + (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, + ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), + ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); #endif } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index aca68e58ee1..0c82ceb969f 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) { return x > static_cast(0.f) ? static_cast(1.f) : ((x < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); } + template static __dpct_inline__ T op_abs(T x) { - return sycl::fabs(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed + } else { + return sycl::fabs(x); + } +} + +template +static __dpct_inline__ T op_expm1(T x) { + if constexpr (std::is_same_v) { + return static_cast( + sycl::expm1(static_cast(x)) + ); + } else { + return sycl::expm1(x); + } } template static __dpct_inline__ T op_elu(T x) { - return (x > static_cast(0.f)) ? x : sycl::expm1(x); + return (x > static_cast(0.f)) ? 
x : op_expm1(x); +} + +template +static __dpct_inline__ T op_tanh(T x) { + if constexpr (std::is_same_v) { + constexpr int ver = __INTEL_LLVM_COMPILER; +#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000) + return sycl::ext::oneapi::experimental::tanh(x); +#else + return static_cast(sycl::tanh(static_cast(x))); +#endif + } else { + return sycl::tanh(x); + } } template @@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) { const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); return static_cast(0.5f) * x * (static_cast(1.0f) + - sycl::tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); + op_tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); +} + +template +static __dpct_inline__ T op_exp(T x) { + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::exp(x); + } else { + return sycl::exp(x); + } } template static __dpct_inline__ T op_silu(T x) { - return x / (static_cast(1.0f) + sycl::native::exp(-x)); + return x / (static_cast(1.0f) + op_exp(-x)); } template -static __dpct_inline__ T op_gelu_quick(T x) { - const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); - return x * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x))); +static __dpct_inline__ T op_erf(T x) { + if constexpr (std::is_same_v) { + return static_cast( + sycl::erf(static_cast(x)) + ); + } else { + return sycl::erf(x); + } } template static __dpct_inline__ T op_gelu_erf(T x) { const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); - return static_cast(0.5f) * x * (static_cast(1.0f) + sycl::erf(x * SQRT_2_INV)); + return static_cast(0.5f) * x * (static_cast(1.0f) + op_erf(x * SQRT_2_INV)); } template -static __dpct_inline__ T op_tanh(T x) { - return sycl::tanh(x); +static __dpct_inline__ T op_gelu_quick(T x) { + const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + return x * (static_cast(1.0f) / (static_cast(1.0f) + 
op_exp(GELU_QUICK_COEF_LOCAL * x))); } template static __dpct_inline__ T op_relu(T x) { - return sycl::fmax(x, static_cast(0)); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmax(x, static_cast(0)); + } else { + return sycl::fmax(x, static_cast(0)); + } } template static __dpct_inline__ T op_sigmoid(T x) { - return static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(-x)); + return static_cast(1.0f) / (static_cast(1.0f) + op_exp(-x)); } template static __dpct_inline__ T op_sqrt(T x) { - return sycl::sqrt(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::sqrt(x); + } else { + return sycl::sqrt(x); + } } template static __dpct_inline__ T op_sin(T x) { - return sycl::sin(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::sin(x); + } else { + return sycl::sin(x); + } } template static __dpct_inline__ T op_cos(T x) { - return sycl::cos(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::cos(x); + } else { + return sycl::cos(x); + } } template static __dpct_inline__ T op_hardsigmoid(T x) { - return sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmin( + static_cast(1.0f), sycl::ext::oneapi::experimental::fmax( + static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } else { + return sycl::fmin(static_cast(1.0f), + sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } } template static __dpct_inline__ T op_hardswish(T x) { - return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); -} - -template -static __dpct_inline__ T op_exp(T x) { - return sycl::exp(x); -} - -template -static __dpct_inline__ T op_expm1(T x) { - return sycl::expm1(x); + if constexpr (std::is_same_v) { + return x * 
sycl::ext::oneapi::experimental::fmin(static_cast(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } else { + return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } } template @@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) { if (x <= static_cast(0)) { return neg_infinity(); } - return sycl::log(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::log(x); + } else { + return sycl::log(x); + } } template static __dpct_inline__ T op_softplus(T x) { const float xf = (float) x; - const float ax = sycl::fabs(xf); + const float ax = op_abs(xf); const float m = sycl::fmax(xf, 0.0f); const float y = m + sycl::log1p(sycl::exp(-ax)); return (T) y; @@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) { template static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) { T neg_slope_T = static_cast(negative_slope); - return sycl::fmax(x, static_cast(0)) + + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmax(x, static_cast(0)) + + sycl::ext::oneapi::experimental::fmin(x, static_cast(0.0f)) * neg_slope_T; + + } else { + return sycl::fmax(x, static_cast(0)) + sycl::fmin(x, static_cast(0.0f)) * neg_slope_T; + } } template @@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) { template static __dpct_inline__ T op_floor(T x) { - return sycl::floor(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::floor(x); + } else { + return sycl::floor(x); + } } template static __dpct_inline__ T op_ceil(T x) { - return sycl::ceil(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::ceil(x); + } else { + return sycl::ceil(x); + } } template static __dpct_inline__ T op_round(T x) { - return sycl::round(x); + if constexpr (std::is_same_v) { + return static_cast( + sycl::round(static_cast(x)) + ); + } 
else { + return sycl::round(x); + } } template static __dpct_inline__ T op_trunc(T x) { - return sycl::trunc(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::trunc(x); + } else { + return sycl::trunc(x); + } } template @@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> /*item_ct1*/) { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); }); } @@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step, template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... 
args) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16); GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); @@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } +#ifdef GGML_SYCL_HAS_BF16 + case GGML_TYPE_BF16: + { + auto data_pts = cast_data(dst); + kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); + break; + } +#endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); @@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary( stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_generic_kernel( src, dst_ptr, k_elements, ne0, ne1, ne2, ne3, @@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE), sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { arange_kernel(dst_ptr, k, start, step, item_ct1); }); } @@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), 
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE), sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1); }); }, negative_slope); @@ -694,7 +792,7 @@ static inline void 
ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE), sycl::range<1>(SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE), sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1); }); }, min_val, max_val); @@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * 
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), + sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp return out_glu; } - template static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, @@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x, const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1); }); } @@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & 
ctx, ggml_ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); From 2c3dd7b25552a1625ac6a255b6a5b3e6e3e3149a Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Mon, 22 Jun 2026 22:25:21 -0700 Subject: [PATCH 06/30] opencl: q8_0 gemv precision improvement (llama/24923) --- ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl index 9703b693e56..f5c6fb3e843 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl @@ -174,7 
+174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32( regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; - dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB); + dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB); } // reduction in local memory, assumes #wave=4 From 44d75b606d298c8341479b6d290a02c59166a860 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 23 Jun 2026 17:13:55 +0900 Subject: [PATCH 07/30] ggml-webgpu: improve MTP inference by using mat-vec path for small batches (llama/24811) * ggml-webgpu: improve small batches decoding * Add barrier to the NUM_COLS loop in mul-mat-vec --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 13 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 20 +- .../wgsl-shaders/mul_mat_id_vec.wgsl | 4 +- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 84 +- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 1005 +++++++++-------- .../wgsl-shaders/mul_mat_vec_q_acc.tmpl | 132 ++- .../ggml-webgpu/wgsl-shaders/quantize_q8.wgsl | 23 +- 7 files changed, 685 insertions(+), 596 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6f877f15ce9..c00a2e9ee9b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + uint32_t num_cols; bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && - use_mmvq == other.use_mmvq; + num_cols == other.num_cols && use_mmvq == other.use_mmvq; } }; @@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); 
ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } @@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key { ggml_type src0_type; ggml_type src1_type; uint32_t n_experts; + uint32_t num_cols; int vectorized; bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && - vectorized == other.vectorized; + num_cols == other.num_cols && vectorized == other.vectorized; } }; @@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.vectorized); return seed; } @@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, const ggml_tensor * src1, bool supports_dot_product, const std::string & vendor) { - if (src1->ne[1] == 1) { + if (src1->ne[1] <= 4) { bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; if (supports_dp4a && supports_dot_product) { switch (src1->type) { @@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 
1 : 0; + key.num_cols = context.dst->ne[1]; key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); @@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols)); auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); @@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=1")); defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f71d1aee73a..e8eafd185a4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & const size_t dst_offset = ggml_webgpu_tensor_offset(dst); const size_t q8_src1_align_offset = ROUNDUP_POW2( dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t q8_src1_binding_size = - ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), - WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t q8_src1_binding_size = ROUNDUP_POW2( + src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); std::vector q8_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], }; @@ -1442,7 +1444,7 @@ static void 
ggml_webgpu_quantize_q8_dispatch(webgpu_context & uint32_t q8_wg_x = 1; uint32_t q8_wg_y = 1; const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; - const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); @@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst) { // Determine if this is a mat-vec operation - bool is_vec = (dst->ne[1] == 1); + bool use_mat_vec = (dst->ne[1] <= 4); // use MMVQ path for mat-vec bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, @@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, webgpu_pipeline pipeline; std::vector dispatches; - if (is_vec) { + if (use_mat_vec) { if (use_mmvq) { ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); } @@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (is_vec) { + if (use_mat_vec) { auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; @@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, ctx->webgpu_global_ctx->vendor); if (use_mmvq) { - const size_t q8_src1_size = - src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] * + (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 
32)); res = ROUNDUP_POW2(res + q8_src1_size + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, WEBGPU_STORAGE_BUF_BINDING_MULT); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl index 6ff9bcf2df0..78ae955e6ba 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -103,7 +103,7 @@ fn main( #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); + let subgroup_total = subgroupAdd(acc[0][row]); if (subgroup_invocation_id == 0u) { partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } @@ -126,7 +126,7 @@ fn main( #ifdef USE_WORKGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; + partial_sums[partial_index(row, thread_id)] = acc[0][row]; } workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index f0a7fbd059a..ebdf09513e2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -91,61 +91,67 @@ fn main( let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; #ifdef MMVQ - let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u); let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); #else let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); #endif -#ifdef USE_SUBGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let 
subgroup_total = subgroupAdd(acc[row]); - if (subgroup_invocation_id == 0u) { - partial_sums[partial_index(row, subgroup_id)] = subgroup_total; - } - } + for (var col = 0u;col < NUM_COLS;col += 1) { - workgroupBarrier(); +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[col][row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } - for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { - let output_row = row_base + row; - var row_acc = 0.0f; - for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { - row_acc += partial_sums[partial_index(row, k)]; - } - let row_total = subgroupAdd(row_acc); - if (subgroup_invocation_id == 0) { - dst[dst_idx_base + row] = row_total; - } - } + workgroupBarrier(); + + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + col * params.m + row] = row_total; + } + } #endif #ifdef USE_WORKGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[col][row]; + } - workgroupBarrier(); + workgroupBarrier(); - var stride = WG_SIZE / 2u; + var stride = WG_SIZE / 2u; - while (stride > 0) { - if (thread_id < stride) { - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + while (stride > 0) { + if (thread_id < stride) { + for (var 
row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; } - } - workgroupBarrier(); - stride = stride / 2; - } + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif + + workgroupBarrier(); - if (thread_id < OUTPUTS_PER_WG) { - let output_row = row_base + thread_id; - if (output_row < params.m) { - dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; - } } -#endif } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 08753b9d643..b0703fe9062 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -32,8 +32,8 @@ fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { #endif #ifdef MUL_ACC_FLOAT -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let k_vec = params.k / VEC_SIZE; let src1_idx_base_vec = src1_idx_base / VEC_SIZE; @@ -41,12 +41,18 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src // Each thread walks K, loads from the vector, and updates // a small block of output rows held in registers. 
for (var k = thread_id; k < k_vec; k += WG_SIZE) { - let x = src1[src1_idx_base_vec + k]; + var x_vals: array; + for (var col = 0u;col < NUM_COLS;col += 1) { + x_vals[col] = src1[src1_idx_base_vec + col * (params.stride_11 / VEC_SIZE) + k]; + } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; - acc[row] += inner_dot(src0[src0_idx], x); + let w = src0[src0_idx]; + for (var col = 0u;col < NUM_COLS;col += 1) { + acc[col][row] += inner_dot(w, x_vals[col]); + } } } } @@ -60,30 +66,33 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 16 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let q_byte = 
load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; - var row_sum = 0.0; - for (var bit = 0u; bit < 8u; bit++) { - let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); - row_sum += w * x_block[bit]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[col][bit]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -97,35 +106,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = 
load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; - let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -139,36 +150,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 20 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row 
= 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(q_byte & 0xFu) * d + m; - let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -182,19 +195,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 22 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = 
f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -203,18 +217,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 2u); let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; - let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -228,19 +243,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 24 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, 
row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -250,18 +266,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 4u); let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; - let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed 
>> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -275,33 +292,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 34 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 
4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -315,34 +337,39 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 36 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = 
f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d + m; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -355,8 +382,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 84 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -379,14 +406,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 64u + i]); - 
x_block[i + 12u] = f32(src1[x_base + 96u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 64u + i]); + x_block[col][i + 12u] = f32(src1[x_base + col * params.stride_11 + 96u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -404,30 +432,32 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qs0 = q_u32 & 0xFFFFu; let qs1 = q_u32 >> 16u; - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - - sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; - sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; - sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; - sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; - - acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); - acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); - acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); - acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); - acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); - acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); - acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); - acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); - - acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) 
* f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[col][0] + x_block[col][1] + x_block[col][2] + x_block[col][3]; + sumy[1] = x_block[col][4] + x_block[col][5] + x_block[col][6] + x_block[col][7]; + sumy[2] = x_block[col][8] + x_block[col][9] + x_block[col][10] + x_block[col][11]; + sumy[3] = x_block[col][12] + x_block[col][13] + x_block[col][14] + x_block[col][15]; + + acc1[0] = x_block[col][0] * f32(qs0 & 0x0003u) + x_block[col][2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[col][1] * f32(qs0 & 0x0300u) + x_block[col][3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[col][4] * f32(qs0 & 0x000Cu) + x_block[col][6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[col][5] * f32(qs0 & 0x0C00u) + x_block[col][7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[col][8] * f32(qs0 & 0x0030u) + x_block[col][10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[col][9] * f32(qs0 & 0x3000u) + x_block[col][11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[col][12] * f32(qs0 & 0x00C0u) + x_block[col][14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[col][13] * f32(qs0 & 0xC000u) + x_block[col][15] * f32(qs1 & 0xC000u); + + acc[col][row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } } } } @@ -440,8 +470,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn 
accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -485,12 +515,13 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 8u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 8u] = f32(src1[x_base + 32u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 8u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -516,28 +547,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); - var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; - var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; - - for (var l = 0u; l < 8u; l += 2u) { - let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); - let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); - let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); - let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); - - s1 += x_block[l + 0u] * f32(qs & qm0); - s2 += x_block[l + 1u] * f32(qs & qm1); - s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + - select(0.0, x_block[l + 1u], (hv & hm1) == 0u); - s4 += x_block[l + 8u] * f32(qs & qm2); - s5 += x_block[l + 9u] * f32(qs & qm3); - s6 += 
select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + - select(0.0, x_block[l + 9u], (hv & hm3) == 0u); - } + for (var col = 0u;col < NUM_COLS;col += 1) { + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[col][l + 0u] * f32(qs & qm0); + s2 += x_block[col][l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[col][l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[col][l + 1u], (hv & hm1) == 0u); + s4 += x_block[col][l + 8u] * f32(qs & qm2); + s5 += x_block[col][l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[col][l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[col][l + 9u], (hv & hm3) == 0u); + } - let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); - let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); - acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[col][row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } } } } @@ -550,8 +583,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 144 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -573,12 +606,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += 
num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { @@ -613,23 +649,25 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); - var dot = vec4(0.0, 0.0, 0.0, 0.0); - var sumx = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - dot[0] += x_block[i] * f32(q1b & 0x0Fu); - dot[1] += x_block[i + 4u] * f32(q1b >> 4u); - dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); - dot[3] += x_block[i + 12u] * f32(q2b >> 4u); - sumx[0] += x_block[i]; - sumx[1] += x_block[i + 4u]; - sumx[2] += x_block[i + 8u]; - sumx[3] += x_block[i + 12u]; - } + for (var col = 0u;col < NUM_COLS;col += 1) { + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[col][i] * f32(q1b & 0x0Fu); + dot[1] += x_block[col][i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[col][i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[col][i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[col][i]; + sumx[1] += x_block[col][i + 4u]; + sumx[2] += 
x_block[col][i + 8u]; + sumx[3] += x_block[col][i + 12u]; + } - acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) - - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + acc[col][row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } } } } @@ -642,8 +680,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 176 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -671,14 +709,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -712,37 +752,39 @@ fn 
accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); - var vals = vec4(0.0, 0.0, 0.0, 0.0); - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - let qhb = byte_of(qh_u32, i); - - let yl0 = x_block[i]; - let yl8 = x_block[i + 4u]; - let yh0 = x_block[i + 8u]; - let yh8 = x_block[i + 12u]; - - sumy[0] += yl0; - sumy[1] += yl8; - sumy[2] += yh0; - sumy[3] += yh8; - - let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); - let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); - let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); - let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); - - vals[0] += yl0 * q0; - vals[1] += yl8 * q1; - vals[2] += yh0 * q2; - vals[3] += yh8 * q3; - } + for (var col = 0u;col < NUM_COLS;col += 1) { + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[col][i]; + let yl8 = x_block[col][i + 4u]; + let yh0 = x_block[col][i + 8u]; + let yh8 = x_block[col][i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } - acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) - - dmin * (sumy[0] * m0 + sumy[1] * m1 + - sumy[2] * m4 + sumy[3] * m5); + acc[col][row] 
+= d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } } } } @@ -755,8 +797,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 210 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -777,14 +819,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var l = 0u; l < 4u; l++) { - x_block[l] = f32(src1[x_base + l]); - x_block[l + 4u] = f32(src1[x_base + 32u + l]); - x_block[l + 8u] = f32(src1[x_base + 64u + l]); - x_block[l + 12u] = f32(src1[x_base + 96u + l]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var l = 0u; l < 4u; l++) { + x_block[col][l] = f32(src1[col_base + l]); + x_block[col][l + 4u] = f32(src1[col_base + 32u + l]); + x_block[col][l + 8u] = f32(src1[col_base + 64u + l]); + x_block[col][l + 12u] = f32(src1[col_base + 96u + l]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -802,26 +846,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - var sums = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sums = vec4(0.0, 0.0, 0.0, 
0.0); - for (var l = 0u; l < 4u; l++) { - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - sums[0] += x_block[l] * dq0; - sums[1] += x_block[l + 4u] * dq1; - sums[2] += x_block[l + 8u] * dq2; - sums[3] += x_block[l + 12u] * dq3; - } + sums[0] += x_block[col][l] * dq0; + sums[1] += x_block[col][l + 4u] * dq1; + sums[2] += x_block[col][l + 8u] * dq2; + sums[3] += x_block[col][l + 12u] * dq3; + } - acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); + acc[col][row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } } } } @@ -834,8 +880,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 50 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -850,11 +896,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < 
num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -866,20 +913,22 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -892,8 +941,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 56 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> 
array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -908,11 +957,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -936,26 +986,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_lo = qh & 0xFFu; let qh_hi = (qh >> 8u) & 0xFFu; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); - let sub_scale = (sc_u16 >> bit_off) & 0x7u; - let dl = d * f32(2u * sub_scale + 1u); - let qh_byte = select(qh_lo, qh_hi, l >= 2u); - let ll2 = l % 2u; - let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); - let ig = grid_idx * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = 
(sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -968,8 +1020,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 66 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -984,11 +1036,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -999,22 +1052,24 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux_hi >> 28u; let db = d * (0.5 + f32(ls)) 
* 0.25; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; - let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xxs_grid[grid_idx * 2u]; - let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1027,8 +1082,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 74 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1043,11 +1098,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) 
{ let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1058,27 +1114,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(scales_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let half2 = (l % 2u) * 16u; - let qs_val = (qs_word >> half2) & 0xFFFFu; - let grid_idx = qs_val & 0x1FFu; - let signs_idx = (qs_val >> 9u) & 0x7Fu; - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xs_grid[grid_idx * 2u]; - let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let 
gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1091,8 +1149,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 82 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1107,11 +1165,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1124,24 +1183,26 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(sc_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let sign_byte = get_byte(sg_w, l); - let grid_idx = qs_byte | (((qh_byte >> (2u 
* l)) & 3u) << 8u); - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let gw_lo = iq2s_grid[grid_idx * 2u]; - let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1154,8 +1215,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 98 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1170,11 +1231,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var 
x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1186,27 +1248,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux >> 28u; let db = d * (0.5 + f32(ls)) * 0.5; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let signs_idx = (aux >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let grid1 = iq3xxs_grid[grid_idx_0]; - let grid2 = iq3xxs_grid[grid_idx_1]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + 
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1219,8 +1283,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1235,11 +1299,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1255,28 +1320,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; let db = d * (1.0 + 2.0 * f32(sub_scale)); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let qs1 = (qs_word >> ((byte_pos + 1u) * 
8u)) & 0xFFu; - let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); - let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); - let sign_byte = get_byte(sg_w, l); - let grid1 = iq3s_grid[grid_idx_1]; - let grid2 = iq3s_grid[grid_idx_2]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1290,35 +1357,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, 
src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + i + 16u]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + i + 16u]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; - let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1331,8 +1400,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, 
src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 136 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1346,11 +1415,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1370,17 +1440,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); - var row_sum = 0.0; - for (var i = 0u; i < 16u; i++) { - let q_word = select( - select(q_w0, q_w1, i >= 4u), - select(q_w2, q_w3, i >= 12u), - i >= 8u); - let q_byte = get_byte(q_word, i % 4u); - let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); - row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += 
f32(kvalues_iq4nl[nib]) * dl * x_block[col][i]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1394,35 +1466,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 17 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); let e = ldexp(1.0, i32(eu8) - 128); - var row_sum = 0.0; let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * 
x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl index 3ef2f77ebe0..6ccaf61a6a0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_0 #ifdef MUL_ACC_Q4_1 #define BLOCK_SIZE_BYTES 20 @@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2 { f32(load_f16_at_src0(block_byte_base + 2u)) ); } -fn mul_q8_1(row_sum: i32, dma: vec2, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_1 #ifdef MUL_ACC_Q8_0 #define BLOCK_SIZE_BYTES 34 @@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds); -} -#endif +#endif // MUL_ACC_Q8_0 -#ifdef LEGACY_QUANTS -fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2, b_ds: B_DS_TYPE) -> f32 { - var row_sum = 0; - let a_repacked = repack_a(a_byte_base, b_inner_id); - - row_sum += dot4I8Packed(a_repacked[0], 
b_repacked[0]); - row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); - - return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); -} - -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +#if defined(LEGACY_QUANTS) +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let b_inner_id = thread_id % THREADS_PER_BLOCK; - let b_block_idx = src1q_idx_base + block; - - let b_repacked = repack_b_qs(b_block_idx, b_inner_id); - let b_ds = repack_b_dm(b_block_idx); - + let inner_id = thread_id % THREADS_PER_BLOCK; for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, inner_id); + let da = get_dm(block_byte_base); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block; + let b_repacked = repack_b_qs(src1q_idx, inner_id); + let b_ds = repack_b_dm(src1q_idx); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]); + +#if defined(MUL_ACC_Q4_0) + acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_0 + +#if defined(MUL_ACC_Q4_1) + acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_1 + +#if defined(MUL_ACC_Q8_0) + acc[col][row] += f32(row_sum) * (da * b_ds); +#endif // MUL_ACC_Q8_0 + } } } } return acc; } -#endif +#endif // LEGACY_QUANTS #ifdef MUL_ACC_Q2_K 
#define BLOCK_SIZE_BYTES 84 @@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); return vec2(f32(scale & 0xFu), f32(scale >> 4u)); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let scale_q = i32(scale_min.x); - let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; - - let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) - + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; - let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) - + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); - - return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); -} -#endif +#endif // MUL_ACC_Q2_K #ifdef MUL_ACC_Q4_K #define BLOCK_SIZE_BYTES 144 @@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { return vec2(scale, min_val); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) - + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); - - // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. 
- return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); -} -#endif +#endif // MUL_ACC_Q4_K #ifdef K_QUANTS -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { - let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; - let b_repacked = repack_b_qs(src1q_idx, tid); - let b_ds = repack_b_dm(src1q_idx); - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, tid); + let dm = get_dm(block_byte_base); + let scale_min = get_scale_min(block_byte_base, tid); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + +#if defined(MUL_ACC_Q2_K) + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + 
dot4I8Packed(b_repacked[3], scale_m_i8x4); + + acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +#endif // MUL_ACC_Q2_K + +#if defined(MUL_ACC_Q4_K) + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +#endif // MUL_ACC_Q4_K + + } } } } return acc; } -#endif +#endif // K_QUANTS diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl index b3f1fa04b80..847b27ffada 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product; struct Params { offset_src1: u32, + stride_11: u32, stride_12: u32, stride_13: u32, ne0: u32, + ne1: u32, ne2: u32, ne3: u32, }; @@ -57,25 +59,28 @@ fn main( @builtin(num_workgroups) num_wg: vec3 ) { let thread_id = local_id.x; - let num_vec4 = params.ne0 / 4u; + let ne0_vec4 = params.ne0 / 4u; - let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; - let total_batches = wg_per_vec * params.ne2 * params.ne3; + let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3; let wg_linear = wg_id.y * num_wg.x + wg_id.x; if (wg_linear >= total_batches) { return; } - let src13_idx = wg_linear / (params.ne2 * wg_per_vec); - let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; - let src11_wg_idx = wg_linear % wg_per_vec; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let vec_idx = wg_linear / wg_per_vec; + let src13_idx = vec_idx / (params.ne2 * params.ne1); 
+ let vec_ne12_num = vec_idx % (params.ne2 * params.ne1); + let src12_idx = vec_ne12_num / params.ne1; + let src11_idx = vec_ne12_num % params.ne1; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11; let src1_idx_vec4_base = src1_idx_base / 4u; let blocks_per_row = params.ne0 / 32u; let blocks_per_wg = (WG_SIZE * 4u) / 32u; - let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row; + let src11_wg_idx = wg_linear % wg_per_vec; let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; let qs_idx = thread_id % 8u; @@ -85,7 +90,7 @@ fn main( var thread_amax = 0.0; let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; - let is_valid = src11_vec4_idx < num_vec4; + let is_valid = src11_vec4_idx < ne0_vec4; #ifdef USE_SUBGROUP_REDUCTION From c793db0750454a8bedec9a512589f13fa7e50634 Mon Sep 17 00:00:00 2001 From: Wyatt Caldwell <218154709+Detensable@users.noreply.github.com> Date: Tue, 23 Jun 2026 03:55:46 -0700 Subject: [PATCH 08/30] vulkan: link ggml-cpu when GGML_VULKAN_CHECK_RESULTS / RUN_TESTS are enabled (llama/24444) The result-checking and test debug paths in ggml-vulkan.cpp call ggml_graph_compute_with_ctx() to compute a CPU reference graph, but that symbol is defined in ggml-cpu, which ggml-vulkan does not link. Enabling -DGGML_VULKAN_CHECK_RESULTS=ON (or -DGGML_VULKAN_RUN_TESTS=ON) therefore fails to link with an unresolved external (e.g. LNK2019 on MSVC, undefined reference on GCC/Clang). This regressed after ggml-cpu was split into its own library. Link ggml-cpu under those two options so the debug builds link again. 
Signed-off-by: Wyatt Caldwell <218154709+Detensable@users.noreply.github.com> --- ggml/src/ggml-vulkan/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 2d9e85794ad..5aeb6e97b15 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -108,6 +108,9 @@ if (Vulkan_FOUND) if (GGML_VULKAN_CHECK_RESULTS) add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + # the result-checking path computes a CPU reference graph via + # ggml_graph_compute_with_ctx(), which is defined in ggml-cpu + target_link_libraries(ggml-vulkan PRIVATE ggml-cpu) endif() if (GGML_VULKAN_DEBUG) @@ -129,6 +132,8 @@ if (Vulkan_FOUND) if (GGML_VULKAN_RUN_TESTS) add_compile_definitions(GGML_VULKAN_RUN_TESTS) + # the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu) + target_link_libraries(ggml-vulkan PRIVATE ggml-cpu) endif() # Set up toolchain for host compilation whether cross-compiling or not From 67fd5c7913fb55638fbbde24f09e967f11974c6b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 07:26:17 -0500 Subject: [PATCH 09/30] vulkan: make mul_mm ALIGNED a spec constant (llama/24689) This trims down some of the shader variant explosion and reduces binary size. 
--- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 76 +++++---- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 54 ++++--- .../vulkan-shaders/mul_mm_cm2.comp | 11 +- .../vulkan-shaders/mul_mm_funcs.glsl | 151 ++++++++++-------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 17 +- 5 files changed, 172 insertions(+), 137 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9a36b45de88..b3c269783e7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4074,19 +4074,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } #endif + auto const &ggml_vk_mul_mm_spec = [](std::vector spec, bool aligned) { + spec.push_back(aligned ? 1u : 0u); + return spec; + }; + const int mul_mat_id_param_count = 5; #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { + auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector spec, bool aligned, bool mul_mat_id) { + if (mul_mat_id && spec.size() > 5) { + spec.insert(spec.begin() + 5, aligned ? 1u : 0u); + } else { + spec.push_back(aligned ? 
1u : 0u); + } + if (mul_mat_id && spec.size() == 6) { + spec.push_back(32); + } + return spec; + }; // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT 
== mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ @@ -4161,17 +4177,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## 
_l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC 
"_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -4284,32 +4300,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? 
NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? 
NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? 
NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? 
NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ // bf16 scalar path promotes to f32, no dot2 variant #define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC 
## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## 
WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l_int[TYPE]) { \ @@ -4474,17 +4490,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## 
WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 
s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l_int[TYPE]) \ diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index f39410d74f0..57c0410e455 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -38,17 +38,7 @@ #define LOAD_VEC_B 1 #endif -// Load 2 values at once without affecting index calculations through LOAD_VEC -#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED) -#define LOAD_VEC_BATCH_A 2 -#else -#define LOAD_VEC_BATCH_A 1 -#endif -#if !defined(ALIGNED) -#define LOAD_VEC_BATCH_B 2 -#else -#define LOAD_VEC_BATCH_B 1 -#endif +layout (constant_id = 11) const uint ALIGNED = 0; #if !defined(TO_FLOAT_TYPE) #define TO_FLOAT_TYPE FLOAT_TYPE @@ -57,6 +47,13 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(DATA_A_F32) +layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];}; +#elif defined(DATA_A_F16) +layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];}; +#elif defined(DATA_A_BF16) +layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];}; +#endif #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; #endif @@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32 #endif layout (binding = 1) readonly buffer B 
{B_TYPE data_b[];}; +layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID @@ -194,13 +192,23 @@ void main() { const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); - const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); - const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); - const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); - const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) + const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1; + const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2; +#else + const uint LOAD_VEC_A_EFF = LOAD_VEC_A; + const uint LOAD_VEC_BATCH_A = 1; +#endif + const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1; + const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 
1 : 2; + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B); - const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK; - const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK; + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK; #ifdef MUL_MAT_ID #ifdef MUL_MAT_ID_USE_SUBGROUPS @@ -239,15 +247,15 @@ void main() { uint pos_a = #ifdef MUL_MAT_ID - expert_idx * (p.batch_stride_a / LOAD_VEC_A) + + expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) + #else - batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) + + batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) + #endif - (ir * BM * p.stride_a + start_k) / LOAD_VEC_A; + (ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF; #ifdef MUL_MAT_ID uint pos_b = 0; #else - uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF; #endif #ifdef COOPMAT @@ -287,8 +295,8 @@ void main() { barrier(); - pos_a += BK / LOAD_VEC_A; - pos_b += BK / LOAD_VEC_B; + pos_a += BK / LOAD_VEC_A_EFF; + pos_b += BK / LOAD_VEC_B_EFF; #ifdef COOPMAT [[unroll]] for (uint i = 0; i < BK; i += TK) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 2656fe1c3e9..a2e15f6f5ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if 
working wit layout (constant_id = 4) const bool enable_smaller_matrices = false; const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN; +layout (constant_id = 5) const uint ALIGNED = 0; layout (push_constant) uniform parameter { @@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { }; uint _ne1; -layout (constant_id = 5) const uint subgroup_size = 32; +layout (constant_id = 6) const uint subgroup_size = 32; shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) @@ -297,12 +298,12 @@ void main() { // Hint to the compiler that values are aligned (want 16B alignment). // Quants are always block-aligned, no alignment needed. -#if ALIGNED + if (ALIGNED != 0) { #if QUANT_K == 1 - stride_a &= ~7; -#endif - stride_b &= ~7; + stride_a &= ~7; #endif + stride_b &= ~7; + } // Create layouts for both clamped and unclamped accesses tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 73595168984..56a8a0f187f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -1,50 +1,57 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) { #if defined(DATA_A_F32) || defined(DATA_A_F16) #if LOAD_VEC_A == 8 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); - buf_a[buf_idx ] = aa[0].xy; - buf_a[buf_idx + 1] = aa[0].zw; - buf_a[buf_idx + 2] = aa[1].xy; - buf_a[buf_idx + 3] = aa[1].zw; + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const 
uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); + buf_a[buf_idx ] = aa[0].xy; + buf_a[buf_idx + 1] = aa[0].zw; + buf_a[buf_idx + 2] = aa[1].xy; + buf_a[buf_idx + 3] = aa[1].zw; + return; + } #elif LOAD_VEC_A == 4 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); - buf_a[buf_idx ] = aa.xy; - buf_a[buf_idx + 1] = aa.zw; -#else // LOAD_VEC_BATCH_A == 2 + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; + return; + } +#endif const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], - data_a[idx + 1]); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], + data_a_scalar[idx + 1]); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f); } else { buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif #elif defined(DATA_A_BF16) #if LOAD_VEC_A == 4 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); - buf_a[buf_idx ] = aa.xy; - buf_a[buf_idx + 1] = aa.zw; -#else // LOAD_VEC_BATCH_A == 2 + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; + return; + } +#endif const uint idx = pos_a + col * p.stride_a + row * 
2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), - TO_FLOAT_TYPE(data_a[idx + 1])); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), + TO_FLOAT_TYPE(data_a_scalar[idx + 1])); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f); } else { buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #if !defined(MUL_MAT_ID) void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) { #if LOAD_VEC_B == 8 - // Not supported for b_type bf16 because bf16mat2x4 does not exist - const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); - buf_b[buf_idx + 0] = bb[0].xy; - buf_b[buf_idx + 1] = bb[0].zw; - buf_b[buf_idx + 2] = bb[1].xy; - buf_b[buf_idx + 3] = bb[1].zw; + if (ALIGNED != 0) { + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; + return; + } #elif LOAD_VEC_B == 4 - const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + if (ALIGNED != 0) { + const uint idx = pos_b + col * p.stride_b / 
LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; + return; + } #endif - buf_b[buf_idx + 0] = bb.xy; - buf_b[buf_idx + 1] = bb.zw; -#else // LOAD_VEC_BATCH_B == 2 const uint idx = pos_b + col * p.stride_b + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_n < p.N && block + row * 2 + 1 < end_k) { - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), + TO_FLOAT_TYPE(data_b_scalar[idx + 1])); } else if (idx_n < p.N && block + row * 2 < end_k) { - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f); } else { buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif } #else void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) { #if LOAD_VEC_B == 8 - // Not supported for b_type bf16 because bf16mat2x4 does not exist - const u16vec2 row_idx = row_ids[col]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); - buf_b[buf_idx + 0] = bb[0].xy; - buf_b[buf_idx + 1] = bb[0].zw; - buf_b[buf_idx + 2] = bb[1].xy; - buf_b[buf_idx + 3] = bb[1].zw; + if (ALIGNED != 0) { + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * 
p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; + return; + } #elif LOAD_VEC_B == 4 - const u16vec2 row_idx = row_ids[col]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + if (ALIGNED != 0) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; + return; + } #endif - buf_b[buf_idx + 0] = bb.xy; - buf_b[buf_idx + 1] = bb.zw; -#else // LOAD_VEC_BATCH_B == 2 const uint row_i = ic * BN + col; const uint buf_idx = col * SHMEM_STRIDE + row; if (row_i < _ne1 && block + row * 2 + 1 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), + TO_FLOAT_TYPE(data_b_scalar[idx + 1])); } else if (row_i < _ne1 && block + row * 2 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = 
FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f); } else { buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ca6b4443141..f07583b6abc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -539,11 +539,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, 
{"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -565,8 +563,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c #endif { if (!dot2) { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? 
"bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); } } } @@ -583,8 +580,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } std::string data_a_key = "DATA_A_" + to_uppercase(tname); - // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; @@ -597,13 +592,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, 
{"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) From f8f62c78b5eed76702f03a41c7a93c0acf84e072 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 08:39:20 -0500 Subject: [PATCH 10/30] vulkan: support CONV_3D (llama/24612) * vulkan: support CONV_3D This is a pretty direct port of conv2d_mm.comp to CONV_3D, done by codex and cleaned up by me. 
* disable slower perf tests --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 244 +++++++++- .../ggml-vulkan/vulkan-shaders/conv3d_mm.comp | 431 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 25 + 3 files changed, 697 insertions(+), 3 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b3c269783e7..508d569f201 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state { } }; +struct vk_conv3d_pipeline_state { + vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2, + uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned) + : s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {} + + uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD; + uint32_t aligned; + + bool operator<(const vk_conv3d_pipeline_state &b) const { + return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) < + std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned); + } +}; + struct vk_solve_tri_pipeline_state { vk_solve_tri_pipeline_state(uint32_t N, uint32_t K) : N(N), K(K) {} @@ -924,6 +938,8 @@ struct vk_device_struct { std::map pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; std::map pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; std::map pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv3d_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; @@ -1669,6 +1685,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { init_fastdiv_values(p.OW*p.OH, p.OWOHmp, 
p.OWOHL); } +struct vk_op_conv3d_push_constants { + uint32_t OC; + uint32_t IC; + uint32_t N; + + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t OW; + uint32_t OH; + uint32_t OD; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t OWOHODmp; uint32_t OWOHODL; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) { + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); + init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL); +} + struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -5330,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - // conv2d, conv_transpose_2d + // conv2d, conv_transpose_2d, conv3d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { // smaller WG for the small-tile fallback gives more concurrent WGs per SM uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256; @@ -5393,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size; }; - // coopmat1 needs to store the output through shared memory, so check up front - // whether it'll fit and disable it before applying coopmat1 parameters. + // 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem + // layout. cm1 needs Csh for output, so check before applying cm1 params. 
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) { conv2d_use_cm1 = false; } @@ -5486,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } #undef CREATE_CONV #undef CREATE_CONVS + + std::vector conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD }; +#define CREATE_CONV3D(type_suffix, spv_suffix) \ + for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \ + const vk_conv3d_pipeline_state &state = c.first; \ + std::vector spec_constants_cpy = conv3d_spec_constants; \ + spec_constants_cpy.push_back(state.s0); \ + spec_constants_cpy.push_back(state.s1); \ + spec_constants_cpy.push_back(state.s2); \ + spec_constants_cpy.push_back(state.p0); \ + spec_constants_cpy.push_back(state.p1); \ + spec_constants_cpy.push_back(state.p2); \ + spec_constants_cpy.push_back(state.d0); \ + spec_constants_cpy.push_back(state.d1); \ + spec_constants_cpy.push_back(state.d2); \ + spec_constants_cpy.push_back(state.KW); \ + spec_constants_cpy.push_back(state.KH); \ + spec_constants_cpy.push_back(state.KD); \ + spec_constants_cpy.push_back(state.aligned); \ + spec_constants_cpy.push_back(conv2d_csh_store); \ + spec_constants_cpy.push_back(conv2d_WM); \ + spec_constants_cpy.push_back(conv2d_WN); \ + ggml_vk_create_pipeline( \ + device, c.second, "conv3d" #type_suffix, \ + conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \ + sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \ + } +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_CONV3D(_f32, _cm2) + CREATE_CONV3D(_f16_f32, _cm2) + } else +#endif +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (conv2d_use_cm1) { + CREATE_CONV3D(_f32, _cm1) + 
CREATE_CONV3D(_f16_f32, _cm1) + } else +#endif + if (conv2d_UNROLL) { + CREATE_CONV3D(_f32, _unroll) + CREATE_CONV3D(_f16_f32, _unroll) + } else { + CREATE_CONV3D(_f32, ) + CREATE_CONV3D(_f16_f32, ) + } +#undef CREATE_CONV3D } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -10901,6 +10999,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } } return nullptr; + case GGML_OP_CONV_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11); + const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9); + const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10); + const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0]; + const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ); + + const uint32_t KW = (uint32_t)src0->ne[0]; + const uint32_t KH = (uint32_t)src0->ne[1]; + const uint32_t KD = (uint32_t)src0->ne[2]; + const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0); + const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1); + const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2); + const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3); + const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4); + const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5); + const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6); + const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7); + const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8); + + const uint32_t CRS = IC * KW * KH * KD; + const uint32_t BS_K = vk_conv_block_sizes[shape].K; + const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS; + const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ; + const uint32_t aligned = ((OC % BS_K == 0) && + (CRS % BS_CRS == 0) && + (NPQ % 
BS_NPQ == 0)) ? 1u : 0u; + + vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned); + + std::map *pipelines = nullptr; + if (src0->type == GGML_TYPE_F32) { + pipelines = &ctx->device->pipeline_conv3d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape]; + } else { + return nullptr; + } + + vk_pipeline pipeline = nullptr; + + { + std::lock_guard guard(ctx->device->compile_mutex); + auto it = pipelines->find(conv3d_pipeline_state); + if (it != pipelines->end()) { + pipeline = it->second; + } else { + (*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared(); + } + } + + return pipeline; + } + return nullptr; case GGML_OP_ADD1: if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_add1_f16_f16; @@ -11236,6 +11389,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co GGML_ABORT("invalid push constant type for CONV_2D"); } break; + case GGML_OP_CONV_3D: + if constexpr (std::is_same_v) { + const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW; + const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ); + const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ); + + elements = { pc.OC, NPQ_blocks, 1 }; + if (elements[1] > 512) { + elements[2] = CEIL_DIV(elements[1], 512); + elements[1] = 512; + } + } else { + GGML_ABORT("invalid push constant type for CONV_3D"); + } + break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -13134,6 +13302,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p)); } +static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || 
src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv3d_push_constants p{}; + p.IC = static_cast(ggml_get_op_params_i32(dst, 9)); + p.N = static_cast(ggml_get_op_params_i32(dst, 10)); + p.OC = static_cast(ggml_get_op_params_i32(dst, 11)); + GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC); + GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N); + GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N); + + p.IW = static_cast(ne10); + p.IH = static_cast(ne11); + p.ID = static_cast(ne12); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + p.OD = static_cast(ne2); + + // the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the + // total input element count must fit in a uint32. + GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p)); +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -14531,6 +14744,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_TRANSPOSE_2D: ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_CONV_3D: + ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node); + break; case 
GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node); @@ -17301,6 +17518,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_is_contiguous(op->src[1]) && ggml_is_contiguous(op)); } + case GGML_OP_CONV_3D: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op); default: return false; } @@ -18144,6 +18368,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const int32_t d0 = tensor->op_params[4]; const int32_t d1 = tensor->op_params[5]; tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p2 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d2 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + const int32_t N = tensor->op_params[10]; + const int32_t OC = tensor->op_params[11]; + tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC); } else if (tensor->op == GGML_OP_CONV_2D_DW) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp new file mode 100644 index 00000000000..a9712eb3acf --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp @@ -0,0 +1,431 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#ifdef COOPMAT2 +#extension 
GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#include "types.glsl" + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, KD, IC*OC] + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, OD, OC*N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t OC; + uint32_t IC; + uint32_t N; + + // Tensor spatial sizes: input, output + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t OW; + uint32_t OH; + uint32_t OD; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // fastdiv helper values + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t OWOHODmp; uint32_t OWOHODL; +} + +p; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint SHMEM_PAD = 4; +// Stride, padding, dilation +layout(constant_id = 6) const uint s0 = 1; +layout(constant_id = 7) const uint s1 = 1; +layout(constant_id = 8) const uint s2 = 1; +layout(constant_id = 9) const uint p0 = 0; 
+layout(constant_id = 10) const uint p1 = 0; +layout(constant_id = 11) const uint p2 = 0; +layout(constant_id = 12) const uint d0 = 1; +layout(constant_id = 13) const uint d1 = 1; +layout(constant_id = 14) const uint d2 = 1; +// Kernel spatial sizes +layout(constant_id = 15) const uint KW = 1; +layout(constant_id = 16) const uint KH = 1; +layout(constant_id = 17) const uint KD = 1; +// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned) +layout(constant_id = 18) const uint aligned = 0; +// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this. +layout(constant_id = 19) const uint csh_store = 0; + +#ifdef COOPMAT +// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of +// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN == +// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size. +layout(constant_id = 20) const uint WM = 32; +layout(constant_id = 21) const uint WN = 32; +const uint TM = 16; +const uint TN = 16; +const uint TK = 16; +const uint cms_per_row = WM / TM; +const uint cms_per_col = WN / TN; +const uint warps_M = BS_K / WM; +const uint warps_N = BS_NPQ / WN; +#endif + +// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction +const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0); + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.OC; +uint32_t CRS = p.IC * KD * KH * KW; +uint32_t NPQ = p.N * p.OD * p.OH * p.OW; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +#if defined(COOPMAT2) || defined(COOPMAT) +#define SHMEM_TYPE float16_t +#else +#define SHMEM_TYPE float +#endif + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_len = BS_K * Ash_stride; +const uint32_t 
Bsh_len = BS_CRS * Bsh_stride; + +shared SHMEM_TYPE Ash[Ash_len]; // K x CRS +shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ + +#if defined(COOPMAT2) || defined(COOPMAT) +// stage matC through shmem so global stores are row-major (NPQ-contiguous) +const uint32_t Csh_stride = BS_NPQ; +#ifdef COOPMAT +const uint32_t Csh_len = BS_K * Csh_stride; +#else +const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +#endif +shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ +#endif + +// Threadtile sizes +const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=OC +C=IC +D,R,S=KD,KH,KW +Z,P,Q=OD,OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) { + const uint32_t KHKW = KH * KW; + const uint32_t KDKHKW = KD * KHKW; + ic = crs_idx / KDKHKW; + uint32_t rem = crs_idx - ic * KDKHKW; + kd = rem / KHKW; + rem = rem - kd * KHKW; + kh = rem / KW; + kw = rem - kh * KW; +} + +void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) { + const uint32_t OWOH = p.OW * p.OH; + n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL); + uint32_t rem = npq_idx - n * p.OD * OWOH; + od = fastdiv(rem, p.OWOHmp, p.OWOHL); + rem = rem - od * OWOH; + oh = fastdiv(rem, p.OWmp, p.OWL); + ow = rem - oh * p.OW; +} + +#ifdef 
COOPMAT2 +#define ACC_TYPE float16_t + +ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) +{ + uint32_t K_idx = B_idx_K * BS_K + r; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(elem); + } + return elem; +} +#endif + +void main() { + if (B_idx_NPQ * BS_NPQ >= NPQ) { + return; + } + +#ifdef COOPMAT2 + coopmat matC; + matC = coopmat(0.0); +#elif defined(COOPMAT) + coopmat sums[cms_per_row * cms_per_col]; + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0); + } + const uint warp_r = gl_SubgroupID / warps_N; + const uint warp_c = gl_SubgroupID % warps_N; +#else + float regC[TS_K][TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = 0.0; + } + } +#endif + /* Advance block in CRS dim */ + [[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac; + uint32_t IC_idx_a; + uint32_t KD_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a); + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03; + if (aligned == 0) { + knl_idx = min(knl_idx, K * CRS - 1); + } + float val = knl_data[knl_idx]; + if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) { + val = 0.0; + } + Ash[B_ly * Ash_stride + 
B_lx] = SHMEM_TYPE(val); + } + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + + uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; + uint32_t IC_idx_b; + uint32_t KD_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; + split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b); + + uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2; + uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1; + uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0; + + uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13; + // skip clamp when address can't go OOB + if (aligned == 0 || !dhw_in_bounds) { + src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1); + } + float val = src_data[src_idx]; + bool oob = false; + if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) { + oob = true; + } + // also catches lower-bound underflow (idx wraps to 0x80000000+) + if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) { + oob = true; + } + if (oob) { + val = 0.0; + } + Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); + } + barrier(); +#ifdef COOPMAT2 + coopmat matA; + coopmat matB; + + coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + matC = coopMatMulAdd(matA, matB, matC); +#elif defined(COOPMAT) + // each subgroup multiplies its grid of fragments per TK-sized CRS chunk + [[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) { + coopmat cache_a[cms_per_row]; + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint 
a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK; + coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + } + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopmat cache_b; + const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + if (T_y * TS_K < K) { + UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + float regA[TS_K]; + float regB[TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } + } + } + } +#endif + barrier(); + } + /* Save C* */ +#if defined(COOPMAT2) || defined(COOPMAT) + // stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores +#ifdef COOPMAT + const bool use_staged_store = true; +#else + const bool use_staged_store = (csh_store != 0); +#endif + if (use_staged_store) { +#ifdef COOPMAT + // cm1: each subgroup stores its fragment grid into its Csh slot + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN; + coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } +#else + coopMatStore(matC, Csh, 0, Csh_stride, 
gl_CooperativeMatrixLayoutRowMajor); +#endif + barrier(); + + // cooperative shmem->global: WG threads spread across BS_NPQ (the + // contiguous direction of dst), each iter covers store_rows_per_iter K-rows + const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ; + const uint32_t store_iters = BS_K / store_rows_per_iter; + const uint32_t k_thread_offset = tid / BS_NPQ; + const uint32_t npq_thread = tid % BS_NPQ; + [[unroll]] for (uint32_t i = 0; i < store_iters; i++) { + uint32_t k_local = i * store_rows_per_iter + k_thread_offset; + uint32_t K_idx = B_idx_K * BS_K + k_local; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]); + } + } + } +#ifdef COOPMAT2 + else { + coopMatPerElementNV(matC, matC, perElemOpStore); + } +#endif +#else + if (T_y * TS_K < K) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]); + } + } + } + } +#endif +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index f07583b6abc..2f5661f5485 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ 
-1053,6 +1053,31 @@ void process_shaders() { } } + for (auto unroll : {false, true}) { + for (auto a_f16 : {false, true}) { + std::map defines = { + {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, + {"UNROLL", unroll ? "[[unroll]]" : ""}, + }; + std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32"; + string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines); +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (unroll) { + auto cm2_defines = defines; + cm2_defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (unroll) { + auto cm1_defines = defines; + cm1_defines["COOPMAT"] = "1"; + string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false); + } +#endif + } + } + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); From d0aaddd3965ebe2d456cc3faa7d1d1bcd5636cfa Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 08:39:37 -0500 Subject: [PATCH 11/30] vulkan: Support GET_ROWS_BACK (llama/24883) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 32 +++++++++++++++++++ .../vulkan-shaders/get_rows_back.comp | 25 +++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 3 files changed, 58 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 508d569f201..d2827ad71f9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ 
b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -791,6 +791,7 @@ struct vk_device_struct { vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_get_rows_back_f32; vk_pipeline pipeline_acc_f32; vk_pipeline pipeline_set_f32; @@ -4946,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); @@ -10408,6 +10410,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_get_rows_f32[src0->type]; } return nullptr; + case GGML_OP_GET_ROWS_BACK: + if 
(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rows_back_f32; + } + return nullptr; case GGML_OP_ACC: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_acc_f32; @@ -11304,6 +11311,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; + case GGML_OP_GET_ROWS_BACK: + elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + break; case GGML_OP_ARGSORT: GGML_ASSERT(0); break; @@ -11564,6 +11575,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, }); } +static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) 
dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); @@ -14476,6 +14502,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_GET_ROWS: ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_GET_ROWS_BACK: + ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node); + break; case GGML_OP_ADD: if (ctx->num_additional_fused_ops) { @@ -17197,6 +17227,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } } + case GGML_OP_GET_ROWS_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SET_ROWS: { switch (op->type) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp new file mode 100644 index 00000000000..7e3d8a28197 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp @@ -0,0 +1,25 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint col = gl_GlobalInvocationID.x; + + if (col >= p.ne20) { + return; + } + + for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) { + float sum = 0.0f; + for (uint i = 0; i < p.ne10; ++i) { + if (data_b[get_boffset() + i*p.nb10] == int(row)) { + sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00]; + } + } + + data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum; + } +} 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 2f5661f5485..502602f799f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -843,6 +843,7 @@ void process_shaders() { string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}}); string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}}); From f55f4eb8c954cc27bd837313d76a6111eb290ff3 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 09:48:24 -0500 Subject: [PATCH 12/30] vulkan: support all backend tests for SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU/NORM (llama/24582) * vulkan: make SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU use unary.comp * vulkan: make NORM support noncontig * add noncontiguous row test cases for norm/l2_norm, handle this in the CPU backend and l2_norm.comp * fix supports_op for cuda and webgpu --- ggml/src/ggml-cpu/ops.cpp | 73 ++++++++++++------ ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 76 +++++++++++-------- .../src/ggml-vulkan/vulkan-shaders/clamp.comp | 17 ----- ggml/src/ggml-vulkan/vulkan-shaders/cos.comp | 17 ----- .../ggml-vulkan/vulkan-shaders/l2_norm.comp | 11 +-- .../vulkan-shaders/leaky_relu.comp | 22 ------ ggml/src/ggml-vulkan/vulkan-shaders/norm.comp | 20 ++--- ggml/src/ggml-vulkan/vulkan-shaders/sin.comp | 17 ----- ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp | 17 ----- .../ggml-vulkan/vulkan-shaders/square.comp | 17 ----- .../src/ggml-vulkan/vulkan-shaders/unary.comp | 24 ++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 23 +++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- 14 files changed, 
145 insertions(+), 193 deletions(-) delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/cos.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/sin.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/square.comp diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 74611dce7f1..6724686b8ae 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - const int ith = params->ith; const int nth = params->nth; @@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32( for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3; - float sum = 0.0; - ggml_vec_sum_f32(ne00, &sum, x); - float mean = sum/ne00; + if (nb00 == sizeof(float) && nb0 == sizeof(float)) { + const float * xf = (const float *) x; - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - float variance = 0; + float sum = 0.0; + ggml_vec_sum_f32(ne00, &sum, xf); + float mean = sum/ne00; + + float * yf = (float *) y; + float variance = 0; #ifdef GGML_USE_ACCELERATE - mean = -mean; - vDSP_vsadd(x, 1, &mean, y, 1, ne00); - vDSP_measqv(y, 1, &variance, ne00); + mean = -mean; + vDSP_vsadd(xf, 1, &mean, yf, 1, ne00); + vDSP_measqv(yf, 1, &variance, ne00); #else - variance = ggml_vec_cvar_f32(ne00, y, x, mean); + variance = ggml_vec_cvar_f32(ne00, yf, xf, mean); #endif 
//GGML_USE_ACCELERATE - const float scale = 1.0f/sqrtf(variance + eps); - ggml_vec_scale_f32(ne00, y, scale); + const float scale = 1.0f/sqrtf(variance + eps); + ggml_vec_scale_f32(ne00, yf, scale); + } else { + float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += *(const float *) (x + i00*nb00); + } + const float mean = sum/ne00; + + float variance = 0.0f; + for (int64_t i00 = 0; i00 < ne00; i00++) { + const float v = *(const float *) (x + i00*nb00) - mean; + *(float *) (y + i00*nb0) = v; + variance += v * v; + } + variance /= ne00; + + const float scale = 1.0f/sqrtf(variance + eps); + for (int64_t i00 = 0; i00 < ne00; i00++) { + *(float *) (y + i00*nb0) *= scale; + } + } } } } @@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - const int ith = params->ith; const int nth = params->nth; @@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32( for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; ggml_float sum = 0.0; for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)(x[i00] * x[i00]); + const float xi = *(const float *) (x + i00*nb00); + sum += (ggml_float)(xi * xi); } - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - const float scale = 1.0f/fmaxf(sqrtf(sum), eps); - ggml_vec_scale_f32(ne00, y, scale); + char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + + if (nb00 == sizeof(float) && nb0 == sizeof(float)) { + memcpy(y, x, ne00 * sizeof(float)); + ggml_vec_scale_f32(ne00, (float *) y, scale); + } else { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const float xi = *(const float *) (x + 
i00*nb00); + *(float *) (y + i00*nb0) = xi * scale; + } + } } } } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3d4b5f60565..cca70592f80 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return true; + return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]); break; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d2827ad71f9..f4a578b893d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -816,14 +816,10 @@ struct vk_device_struct { vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64; vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32; vk_pipeline pipeline_scale_f32; - vk_pipeline pipeline_sqr_f32; - vk_pipeline pipeline_sqrt_f32; - vk_pipeline pipeline_sin_f32; - vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_log[2]; vk_pipeline pipeline_tri[2]; vk_pipeline pipeline_diag[2]; - vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_clamp[2]; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32; @@ -855,6 +851,10 @@ struct vk_device_struct { vk_pipeline pipeline_gelu_quick[2]; vk_pipeline pipeline_silu[2]; vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_sqr[2]; + vk_pipeline pipeline_sqrt[2]; + vk_pipeline pipeline_sin[2]; + vk_pipeline pipeline_cos[2]; vk_pipeline pipeline_xielu[2]; vk_pipeline pipeline_neg[2]; vk_pipeline pipeline_tanh[2]; @@ -886,7 +886,7 @@ struct vk_device_struct { vk_pipeline pipeline_geglu_erf[2]; vk_pipeline pipeline_geglu_quick[2]; - vk_pipeline 
pipeline_leaky_relu_f32; + vk_pipeline pipeline_leaky_relu[2]; vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; @@ -4972,7 +4972,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); @@ -5092,11 +5092,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, 
sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -5106,8 +5101,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -5127,6 +5120,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_UNARY(gelu_quick) CREATE_UNARY(silu) CREATE_UNARY(relu) + CREATE_UNARY(sqr) + CREATE_UNARY(sqrt) + CREATE_UNARY(sin) + CREATE_UNARY(cos) + CREATE_UNARY(clamp) + CREATE_UNARY(leaky_relu) CREATE_UNARY(xielu) CREATE_UNARY(neg) CREATE_UNARY(tanh) @@ -5166,7 +5165,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_GLU(geglu_quick) #undef CREATE_GLU - 
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); @@ -10521,23 +10519,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SQR: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sqr_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_SQRT: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sqrt_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_SIN: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sin_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_COS: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_cos_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_LOG: @@ -10559,8 +10561,9 @@ static vk_pipeline 
ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_CLAMP: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_clamp_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_PAD: @@ -10928,8 +10931,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_LEAKY_RELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_leaky_relu_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_CONV_2D: @@ -11431,6 +11435,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_TRI: case GGML_OP_DIAG: case GGML_OP_CLAMP: + case GGML_OP_LEAKY_RELU: case GGML_OP_PAD: case GGML_OP_ROLL: case GGML_OP_REPEAT: @@ -12297,8 +12302,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p)); } static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -13399,7 +13406,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& 
subctx static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { const float * op_params = (const float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f }); + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p)); } #ifdef GGML_VULKAN_RUN_TESTS @@ -17325,12 +17335,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_TRANSPOSE: case GGML_OP_RMS_NORM: return true; - case GGML_OP_NORM: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_NORM: case GGML_OP_L2_NORM: - return ggml_is_contiguous_rows(op->src[0]) && - op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -17349,8 +17358,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LEAKY_RELU: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->type == op->src[0]->type; case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp deleted file mode 100644 index 653431895e7..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) 
in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp deleted file mode 100644 index db6865db981..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index f9af46744df..9039ed1ded3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -14,16 +14,13 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint i3 = row / (p.ne11 * p.ne12); - const uint i3_offset = i3 * p.ne12 * p.ne11; - const uint i2 = (row - i3_offset) / p.ne11; - const uint i2_offset = i2 * p.ne11; - const uint i1 = row - i3_offset - i2_offset; + const uint a_base = get_aoffset() + src0_idx(row * p.ne00); + const uint d_base = get_doffset() + dst_idx(row * p.ne10); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]); + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]); sum[tid] += xi * xi; } @@ -39,6 +36,6 @@ void main() { const 
FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1)); [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { - data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0])); + data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp deleted file mode 100644 index b281e855cb2..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +++ /dev/null @@ -1,22 +0,0 @@ -#version 450 - -#include "generic_head.glsl" -#include "types.glsl" - -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; - - if (i >= p.KX) { - return; - } - - const float val = float(data_a[i]); - data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp index cc3ea0b7606..792012d57e8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -1,26 +1,26 @@ #version 450 -#include "generic_head.glsl" #include "types.glsl" +#include "generic_unary_head.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared vec2 sum[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = 
gl_LocalInvocationID.x; + const uint a_base = get_aoffset() + src0_idx(row * p.ne00); + const uint d_base = get_doffset() + dst_idx(row * p.ne10); + sum[tid] = vec2(0.0f, 0.0f); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const float xi = float(data_a[row*p.KX + col]); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + const float xi = float(data_a[a_base + i0*p.nb00]); sum[tid].x += xi; sum[tid].y += xi * xi; } @@ -34,11 +34,11 @@ void main() { barrier(); } - const float mean = sum[0].x / p.KX; - const float var = sum[0].y / p.KX - mean * mean; + const float mean = sum[0].x / p.ne00; + const float var = sum[0].y / p.ne00 - mean * mean; const float inv_std = inversesqrt(var + p.param1); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp deleted file mode 100644 index 61f17b2f006..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp deleted file mode 100644 index 70daad6c5db..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - 
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp deleted file mode 100644 index 4eb56afcb1e..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp b/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp index 47a45739960..c62bce82555 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp @@ -17,6 +17,30 @@ float op_neg(float x) { return -x; } +float op_sqr(float x) { + return x * x; +} + +float op_sqrt(float x) { + return sqrt(x); +} + +float op_sin(float x) { + return sin(x); +} + +float op_cos(float x) { + return cos(x); +} + +float op_clamp(float x) { + return clamp(x, p.param1, p.param2); +} + +float op_leaky_relu(float x) { + return max(x, 0.0f) + min(x, 0.0f) * p.param1; +} + float op_step(float x) { return x >= 0.0f ? 
1.0f : 0.0f; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 502602f799f..3bd93d256c8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -849,16 +849,6 @@ void process_shaders() { string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}}); @@ -885,6 +875,18 @@ void process_shaders() { string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}}); string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}}); string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}}); + string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}}); + string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}}); + string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}}); + string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, 
{"D_TYPE", "float"}, {"OP", "op_sqrt"}}); + string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}}); + string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}}); + string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}}); + string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}}); + string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}}); + string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}}); + string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}}); + string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}}); string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}}); string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}}); string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}}); @@ -942,7 +944,6 @@ void process_shaders() { string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e8eafd185a4..f0ec18abd9a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp 
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -4270,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_RMS_NORM: case GGML_OP_NORM: case GGML_OP_L2_NORM: - supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0); break; case GGML_OP_ROPE: supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; From 376093c348a73d48d08030302f6d1fc30a4d1a42 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 23 Jun 2026 22:34:00 -0500 Subject: [PATCH 13/30] vulkan: Apply bias before softmax in FA, to avoid overflow (llama/24909) --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 1 + ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp | 1 + 2 files changed, 2 insertions(+) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 91fb07c93e7..3192130ccf5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -463,6 +463,7 @@ void main() { } rowmaxf = max(rowmaxf, float(Sf[r][c])); } + rowmaxf += FATTN_KQ_MAX_OFFSET; float Moldf = Mf[r]; // M = max(rowmax, Mold) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 23ae3833e52..16178e57702 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -352,6 +352,7 @@ void main() { } rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp])); } + rowmaxf += FATTN_KQ_MAX_OFFSET; float Moldf = Mf[r]; // Compute max across the row From 7f3f9fd1ea8a39411c7eb726a93ad90e090cedc9 Mon Sep 17 00:00:00 2001 From: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com> Date: Wed, 24 Jun 2026 17:42:03 +0800 Subject: 
[PATCH 14/30] vulkan: fail the build when a shader fails to compile (llama/24450) * vulkan-shaders-gen: fail the build when a shader fails to compile vulkan-shaders-gen did not detect shader-compile subprocess failures, so a broken libggml-vulkan could be produced while the build reported success and the breakage only surfaced at run time. execute_command() discarded the child exit code (POSIX waitpid passed nullptr for status; the Windows branch never called GetExitCodeProcess) and string_to_spv decided success only from whether stderr was empty, so a non-zero exit with empty stderr, or a subprocess that failed to launch, was treated as success. Return the child exit code from execute_command() (WEXITSTATUS on POSIX, GetExitCodeProcess on Windows), treat a non-zero exit or non-empty stderr or a launch exception as a failure, and record it in an atomic flag. main() checks the flag after process_shaders() and returns EXIT_FAILURE before writing the output files, so the build stops instead of emitting a broken backend. Fixes #24393 Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com> * vulkan-shaders-gen: simplify compile_failed access and drop unreachable return Address review feedback on #24450: - Access the std::atomic compile_failed directly (= / implicit bool) instead of .store()/.load(); the flag stays atomic because the worker threads in process_shaders() set it concurrently. - Remove the unreachable trailing return -1 in execute_command(): on POSIX the child _exit()s after execvp and the parent returns (fork()<0 throws); on Windows the block returns the exit code. 
Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com> --------- Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com> --- .../vulkan-shaders/vulkan-shaders-gen.cpp | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 3bd93d256c8..1925582ffed 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,9 @@ std::mutex lock; std::vector> shader_fnames; +// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the +// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393) +static std::atomic compile_failed{false}; std::locale c_locale("C"); std::string GLSLC = "glslc"; @@ -78,7 +82,7 @@ enum MatMulIdType { namespace { -void execute_command(std::vector& command, std::string& stdout_str, std::string& stderr_str) { +int execute_command(std::vector& command, std::string& stdout_str, std::string& stderr_str) { #ifdef _WIN32 HANDLE stdout_read, stdout_write; HANDLE stderr_read, stderr_write; @@ -127,8 +131,11 @@ void execute_command(std::vector& command, std::string& stdout_str, CloseHandle(stdout_read); CloseHandle(stderr_read); WaitForSingleObject(pi.hProcess, INFINITE); + DWORD exit_code = 1; + GetExitCodeProcess(pi.hProcess, &exit_code); CloseHandle(pi.hProcess); CloseHandle(pi.hThread); + return (int)exit_code; #else int stdout_pipe[2]; int stderr_pipe[2]; @@ -175,7 +182,9 @@ void execute_command(std::vector& command, std::string& stdout_str, close(stdout_pipe[0]); close(stderr_pipe[0]); - waitpid(pid, nullptr, 0); + int status = 0; + waitpid(pid, &status, 0); + return WIFEXITED(status) ? 
WEXITSTATUS(status) : -1; } #endif } @@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p // } // std::cout << std::endl; - execute_command(cmd, stdout_str, stderr_str); - if (!stderr_str.empty()) { - std::cerr << "cannot compile " << name << "\n\n"; + int exit_code = execute_command(cmd, stdout_str, stderr_str); + if (exit_code != 0 || !stderr_str.empty()) { + std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n"; for (const auto& part : cmd) { std::cerr << part << " "; } std::cerr << "\n\n" << stderr_str << std::endl; + compile_failed = true; return; } @@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p shader_fnames.push_back(std::make_pair(name, out_path)); } catch (const std::exception& e) { std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; + compile_failed = true; } } @@ -1271,6 +1282,11 @@ int main(int argc, char** argv) { process_shaders(); + if (compile_failed) { + std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl; + return EXIT_FAILURE; + } + write_output_files(); return EXIT_SUCCESS; From 21bedb35cef5bee2969880d1ccedecabaa1f880f Mon Sep 17 00:00:00 2001 From: Wagner Bruna Date: Wed, 24 Jun 2026 11:29:24 -0300 Subject: [PATCH 15/30] vulkan: allow reducing the graph submission batches to avoid timeouts (llama/24872) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f4a578b893d..5fbebc6d751 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -699,6 +699,7 @@ struct vk_device_struct { bool add_rms_fusion; uint32_t partials_binding_alignment; + uint32_t max_nodes_per_submit; bool shader_64b_indexing; @@ -5878,6 +5879,14 @@ static vk_device ggml_vk_get_device(size_t idx) 
{ device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote); + // Submit at least every 100 nodes, in case there are workloads without as much matmul. + device->max_nodes_per_submit = 100; + const char* GGML_VK_MAX_NODES_PER_SUBMIT = getenv("GGML_VK_MAX_NODES_PER_SUBMIT"); + if (GGML_VK_MAX_NODES_PER_SUBMIT != nullptr) { + uint32_t max_nodes_per_submit = std::stoul(GGML_VK_MAX_NODES_PER_SUBMIT); + device->max_nodes_per_submit = std::max(max_nodes_per_submit, 1u); + } + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; @@ -16173,8 +16182,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB // (and scaled down based on model size, so smaller models submit earlier). - // Also submit at least every 100 nodes, in case there are workloads without as much matmul. 
- int nodes_per_submit = 100; int submitted_nodes = 0; int submit_count = 0; uint64_t mul_mat_bytes = 0; @@ -16400,7 +16407,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; - bool submit = (submitted_nodes >= nodes_per_submit) || + bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) || (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) || (i + ctx->num_additional_fused_ops >= last_node) || (almost_ready && !ctx->almost_ready_fence_pending); From 393a2c8039267d9374b96e020fbea2bf45c511f2 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 24 Jun 2026 12:14:25 -0700 Subject: [PATCH 16/30] hexagon: MUL_MAT and MUL_MAT_ID rework : 32x32 tiled weight repack, kernel-params, cached graphs (llama/24954) * hex-mm: new weight layout and fusion updates * hvx-mm: unroll the new tiled vec_dots to optimize hvx register util * hex-mm: optimize dyn.quant format for q8_0 and q8_1 to reduce overhead in vec_dots. 
* hvx-mm: parallel quantizer per block for large rows * hvx-mm: simplify and further optimize dyn.quant and vec_dots * hvx-mm: keep intermediate per tile accumulators in fp16 * hmx-mm: optimize weight dequant by aligning the repacked tiles with the DMA * hmx-mm: remove qweight scratch and just use vtcm_weight * hmx-mm: remove all unused and obsolete code * hmx-mm: the new tiled repack format is here to stay -- rename all x4x2 to _tiled * hmx-mm: improve activation processing with dma prefetch * hex-mm: fix hmx/hvx fallback logic and MUL_MAT_ID allocation (unbreaks OLMoE) * hex-mm: align the weight tiles with dma just like we did in hmx-mm * hex-mm: factor out common mm bits into htp/matmul-ops.h * hex-mm: start moving mm kernel selection to the host * hex-mm: move all of the matmul param compute into the host * hmx-mm: restore pipelined mode * hmx-mm: unroll the dequant functions to optimize register usage * hmx-mm: further improve activation process * hex-mm: use vtcm_seq_alloc for all vtcm allocations and define more common functions * hex-mm: improve mm optimizer to account for number of activation threads * hex-mm: fix matmul-id kernel params selection (unbreaks OLMoE and LFM) * hexagon: remove support for arch < v73 since HMX is now required for most use-cases * hex-mm: cleanup naming for consistency * hex-mm: make sure matmul fusion accounts for vtcm allocation * hex-mm: minor cleanup for kernel_params definition * hex-mm: replace hardcoded limits with proper checks for vtcm requirements * hex-mm: add support for non-tiled mm as a fallback option and factor out hvx kernels into separate header * hex-mm: remove unused functions * hex-mm: add shorthand for MM_SELECT in run-tool script * hvx-mm: factor out hvx/hmx microkernels and unify matmul entry and dispatch * hex-mm: further cleanup matmul fallback path * hex-mm: refactor matmul entry point and dispatch a bit further * hexagon: update cmake build to enable hmx for everything * hex-ops: optimize kernel_param
updates and include summary in the logs * hex-mm: add support for GGML_HEXAGON_MM_SELECT * hex-mm: add hex-common header * hex-mm: pass correct number of tasks to workpool * hex-mm: add proper checks for no-work in dyn.quant tasks * hex-mm: convert all quantizers into a macro * hex-mm: fix hvx-flat fallback to pass all MUL_MAT tests * hex-mm: vectorize q8_1 quantizer * hex-mm: improve fused ffn mm stride handling * hex-mm: consistent use of n_threads and pipeline in kernel_params * hexagon: minor formatting * hex-mm: update MUL_MAT_ID kernel_param handling to make sure host/npu are in sync * hvx-mm: go back to accumulating in fp32 in tiled hvx kernels, more accurate and same perf * hvx-mm: unroll the loops and remove masking that is not needed for tiled accums * hmx-mm: optimize activation processing (split loops, some unrolling, etc) * hmx-mm: minor optimization for output processing * hex-mm: consistent use of uint32_t and size_t in mm kernels * hex-mm: remove legacy restrictions for rows to be multiple of 256 * hexagon: replace sprintf with snprintf * hex-mm: relax hardcoded nrows checks and rely on VTCM size requirements * hexagon: minor alignment fix * hexagon: fix trailing spaces * hex-mm: relax padding from 256 to 128 (leftovers) * hex-mm: remove redundant checks for weight align to 128 we always use 2D dma for the weights and align them properly * hmx-mm: MUL_MAT_ID better work distribution between hvx threads and hmx tracing * hex-mm: specialize per-token mmid activation handling * hex-profile: update python scripts to handle kernel-params section in the logging output * hex-trace: use easier to parse format, simplify and fix post-proc scripts * hmx-mm: relax 32 row limit for output processing which helps utilization * hmx-mm: use start-chunk idx for tracing info * hmx-mm: parameterize activation dma pipeline * hexagon: add support for simple graph caching to avoid
recomputing kernel-params * hex-mm: remove left-over repack functions * hex-mm: tighten n_prefetch asserts * hex-mm: remove duplicate round/align_up helper * hexagon: cleanup common header used in host/npu * hexagon: update early wakeup threshold * hmx-mm: define cost constants and update solver to assume that repacked ne[1] is padded to 32 * hmx-mm: make precompute_matmul a bit more readable (split into smaller functions, etc) * hex-mm: remove n_threads constraint * hex-mm: minor formatting updates * hex-mm: remove obsolete profiling logs * hex-mm: restore hardcode gate to refuse lm-head to avoid repacking that tensor --- ggml/CMakeLists.txt | 1 - ggml/src/ggml-hexagon/CMakeLists.txt | 4 - ggml/src/ggml-hexagon/ggml-hexagon.cpp | 2645 +++---- ggml/src/ggml-hexagon/htp-opnode.h | 206 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 52 +- .../ggml-hexagon/htp/cmake-toolchain.cmake | 28 +- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 5 +- ggml/src/ggml-hexagon/htp/hex-common.h | 80 + ggml/src/ggml-hexagon/htp/hex-dma.h | 6 +- ggml/src/ggml-hexagon/htp/hex-utils.h | 57 +- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 26 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 2080 ----- .../ggml-hexagon/htp/hmx-mm-kernels-tiled.h | 1306 ++++ ggml/src/ggml-hexagon/htp/hmx-ops.c | 6 - ggml/src/ggml-hexagon/htp/hmx-ops.h | 88 - ggml/src/ggml-hexagon/htp/htp-ctx.h | 12 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 23 +- ggml/src/ggml-hexagon/htp/htp_iface.idl | 3 +- ggml/src/ggml-hexagon/htp/hvx-base.h | 31 +- .../ggml-hexagon/htp/hvx-mm-kernels-flat.h | 1024 +++ .../ggml-hexagon/htp/hvx-mm-kernels-tiled.h | 1140 +++ ggml/src/ggml-hexagon/htp/main.c | 104 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 6820 +++++++---------- ggml/src/ggml-hexagon/htp/matmul-ops.h | 508 ++ ggml/src/ggml-hexagon/libggml-htp.inf | 4 - 25 files changed, 8459 insertions(+), 7800 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hex-common.h delete mode 100644 ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c 
create mode 100644 ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h delete mode 100644 ggml/src/ggml-hexagon/htp/hmx-ops.c delete mode 100644 ggml/src/ggml-hexagon/htp/hmx-ops.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h create mode 100644 ggml/src/ggml-hexagon/htp/matmul-ops.h diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 04069784f19..a0cd4e7158f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -266,7 +266,6 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING "ggml: OpenCL API version to target") option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF) -set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)") # toolchain for vulkan-shaders-gen set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index b82bae0c103..c6e49a71d11 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -25,7 +25,6 @@ include(ExternalProject) option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF) set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") -set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") add_library(htp_iface OBJECT ${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c) @@ -72,15 +71,12 @@ function(build_htp_skel V) -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT} -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT} -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} - -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE} -DDSP_VERSION=${V} 
-DPREBUILT_LIB_DIR="toolv19_${V}") list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so) set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE) endfunction() -build_htp_skel(v68) -build_htp_skel(v69) build_htp_skel(v73) build_htp_skel(v75) build_htp_skel(v79) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index e612ec392b2..3d41c47b651 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #ifdef _WIN32 # include @@ -41,6 +42,7 @@ #include "ggml-quants.h" #include "htp-opnode.h" #include "htp-ops.h" +#include "htp/matmul-ops.h" #include "htp_iface.h" #include "htp-drv.h" @@ -51,7 +53,7 @@ using u32vec = std::vector; static int opt_arch = 0; // autodetect static size_t opt_ndev = 1; static size_t opt_nhvx = 0; // use all -static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static int opt_nhmx = 1; // when set, enable HMX; when 0, use HVX only static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size static int opt_etm = 0; @@ -59,6 +61,8 @@ static int opt_verbose = 0; static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) static int opt_hostbuf = 1; // hostbuf ON by default +static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU + // Default PMU events, if profiling with PMU (mode=2) is enabled // See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html // https://docs.qualcomm.com/doc/80-N2040-61/topic/hvx-pmu-events.html @@ -68,22 +72,15 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches 
-static int opt_oppoll = 0; // polling for batch completions static int opt_optrace = 0; // trace buffer size per thread (0 means default) +static int opt_oppoll = 0; // polling for batch completions +static int opt_opfusion = 1; // enable/disable op fusion static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) -static inline uint64_t hex_is_aligned(void * addr, uint32_t align) { - return ((size_t) addr & (align - 1)) == 0; -} - -static inline size_t hex_round_up(size_t n, size_t m) { - return m * ((n + m - 1) / m); -} - static const char * status_to_str(uint32_t status) { switch (status) { case HTP_STATUS_OK: @@ -107,15 +104,15 @@ static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const htp_op if (!opt_verbose) return; htp_opformat fmt(node); - GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, req_flags); + GGML_LOG_DEBUG("ggml-hex: %s execute-op %s|%s|%s|%s|%s|%s|%s|flags 0x%x\n", sess_name.c_str(), + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, fmt.kparams, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { if (!opt_verbose) return; htp_opformat fmt(htp_opformat(htp_opnode{const_cast(op), {}, HTP_OP_INVALID})); - GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s|%s|%s|%s|%s|%s|%s\n", sess_name.c_str(), ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? 
"yes" : "no"); } @@ -144,16 +141,52 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_op char pmu_str[256] = ""; if (opt_profile == 2) { static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); - sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", + snprintf(pmu_str, sizeof(pmu_str), " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } htp_opformat fmt(node); float mhz = op_usec > 0 ? (float) op_cycles / op_usec : 0.0f; - GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(), - node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str); + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s|%s|%s|%s|%s|%s|usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(), + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.kparams, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str); +} + +// ** + +static inline bool ggml_hexagon_is_repack_type(enum ggml_type type) { + return type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || + type == GGML_TYPE_Q8_0 || type == GGML_TYPE_IQ4_NL || + type == GGML_TYPE_MXFP4; } +static inline bool ggml_hexagon_is_hmx_weight_type(enum ggml_type type) { + return type == GGML_TYPE_F16 || type == GGML_TYPE_F32 || ggml_hexagon_is_repack_type(type); +} + +struct htp_mm_kernel_params; +struct ggml_hexagon_session; +static void ggml_hexagon_precompute_matmul_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + struct htp_mm_kernel_params * kparams +); + +static void ggml_hexagon_precompute_fused_qkv_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct htp_mm_kernel_params * kparams +); + +static void 
ggml_hexagon_precompute_fused_ffn_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct htp_mm_kernel_params * kparams +); + // ** backend sessions struct ggml_hexagon_opbatch; @@ -180,6 +213,18 @@ struct ggml_hexagon_session { ggml_backend_buffer_type buffer_type = {}; ggml_backend_buffer_type repack_buffer_type = {}; + uint32_t n_threads = 0; + uint32_t n_hvx = 0; + uint32_t n_hmx = 0; + uint64_t vtcm_size = 0; + size_t max_vmem = 0; + size_t max_bufsize = 0; + + struct { + uint64_t uid = 0; + std::vector htp_nodes; + } cached_graph; + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); ~ggml_hexagon_session() noexcept(true); @@ -325,47 +370,7 @@ static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buf return GGML_STATUS_SUCCESS; } -// ======== Q4x4x2 ==================== -struct x2_q4 { - int v[2]; -}; - -static x2_q4 unpack_q4(uint8_t v) { - x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 }; - return x; -} - -static void dump_block_q4_0(const block_q4_0 * b, int i) { - HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0], - unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1], - unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1], - GGML_FP16_TO_FP32(b->d)); -} - -static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_Q4_0x4x2; - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_d = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); - - HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... 
%d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0], - unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0], - unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0], - GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3])); - - HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1], - unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1], - unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1], - GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); -} +// ** Repack helpers for tiled quantized weights static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) { static const int qk = QK4_0; @@ -388,300 +393,6 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi } } -static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q4_0(&x[i * 8 + 0], 0); - dump_block_q4_0(&x[i * 8 + 1], 1); - dump_block_q4_0(&x[i * 8 + 2], 2); - dump_block_q4_0(&x[i * 8 + 3], 3); - dump_block_q4_0(&x[i * 8 + 4], 4); - dump_block_q4_0(&x[i * 8 + 5], 5); - dump_block_q4_0(&x[i * 8 + 6], 6); - 
dump_block_q4_0(&x[i * 8 + 7], 7); - } - } - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - unpack_q4_0_quants(qs, &x[i * 8 + 0], 0); - unpack_q4_0_quants(qs, &x[i * 8 + 1], 1); - unpack_q4_0_quants(qs, &x[i * 8 + 2], 2); - unpack_q4_0_quants(qs, &x[i * 8 + 3], 3); - unpack_q4_0_quants(qs, &x[i * 8 + 4], 4); - unpack_q4_0_quants(qs, &x[i * 8 + 5], 5); - unpack_q4_0_quants(qs, &x[i * 8 + 6], 6); - unpack_q4_0_quants(qs, &x[i * 8 + 7], 7); - - bool partial = (nloe && i == nb-1); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - ggml_half * d = (ggml_half *) (y_d + i * dblk_size); - d[0] = x[i * 8 + 0].d; - d[1] = x[i * 8 + 1].d; - d[2] = x[i * 8 + 2].d; - d[3] = x[i * 8 + 3].d; - d[4] = x[i * 8 + 4].d; - d[5] = x[i * 8 + 5].d; - d[6] = x[i * 8 + 6].d; - d[7] = x[i * 8 + 7].d; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q4x4x2(y, i, k); - } - } -} - -static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q4x4x2(y, i, k); - } - } - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - 
- bool partial = (nloe && i == nb-1); - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0xf; - qs[j*2+1] = q[j] >> 4; - } else { - qs[j+000] = q[j] & 0xf; - qs[j+128] = q[j] >> 4; - } - } - - pack_q4_0_quants(&x[i * 8 + 0], qs, 0); - pack_q4_0_quants(&x[i * 8 + 1], qs, 1); - pack_q4_0_quants(&x[i * 8 + 2], qs, 2); - pack_q4_0_quants(&x[i * 8 + 3], qs, 3); - pack_q4_0_quants(&x[i * 8 + 4], qs, 4); - pack_q4_0_quants(&x[i * 8 + 5], qs, 5); - pack_q4_0_quants(&x[i * 8 + 6], qs, 6); - pack_q4_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); - x[i * 8 + 0].d = d[0]; - x[i * 8 + 1].d = d[1]; - x[i * 8 + 2].d = d[2]; - x[i * 8 + 3].d = d[3]; - x[i * 8 + 4].d = d[4]; - x[i * 8 + 5].d = d[5]; - x[i * 8 + 6].d = d[6]; - x[i * 8 + 7].d = d[7]; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q4_0(&x[i * 8 + 0], 0); - dump_block_q4_0(&x[i * 8 + 1], 1); - dump_block_q4_0(&x[i * 8 + 2], 2); - dump_block_q4_0(&x[i * 8 + 3], 3); - dump_block_q4_0(&x[i * 8 + 4], 4); - dump_block_q4_0(&x[i * 8 + 5], 5); - dump_block_q4_0(&x[i * 8 + 6], 6); - dump_block_q4_0(&x[i * 8 + 7], 7); - } - } -} - -static void init_row_q4x4x2(block_q4_0 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - // Init the quants such that they unpack into zeros - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - memset(qs, 8, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q4_0_quants(&x[i * 8 + 0], qs, 0); - pack_q4_0_quants(&x[i * 8 + 1], qs, 1); - pack_q4_0_quants(&x[i * 8 + 2], qs, 2); - pack_q4_0_quants(&x[i * 8 + 3], qs, 3); - 
pack_q4_0_quants(&x[i * 8 + 4], qs, 4); - pack_q4_0_quants(&x[i * 8 + 5], qs, 5); - pack_q4_0_quants(&x[i * 8 + 6], qs, 6); - pack_q4_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].d = 0; - x[i * 8 + 1].d = 0; - x[i * 8 + 2].d = 0; - x[i * 8 + 3].d = 0; - x[i * 8 + 4].d = 0; - x[i * 8 + 5].d = 0; - x[i * 8 + 6].d = 0; - x[i * 8 + 7].d = 0; - } -} - -// repack q4_0 data into q4x4x2 tensor -static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. 
Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. - memcpy(buf_pd, src, n_rem_bytes); - - // Repack the entire buffer - repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]); - - // Write only the corresponding remaining bytes to the destination tensor. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// repack q4x4x2 tensor into q4_0 data -static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. 
- const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_pd, src, row_size); - unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - - // But we only copy the remaining number of bytes to the destination. 
- memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { static const int qk = QK4_1; @@ -703,603 +414,19 @@ static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi } } -static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes - const int qblk_size = qk / 2; // int4 = 128 bytes - const int qrow_size = k / 2; // int4 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales/offsets - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); - unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); - unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); - unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); - unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); - unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); - unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); - unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); - - bool partial = (nloe && i == nb-1); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? 
(qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } - - // Repack the scales and offsets - for (int i = 0; i < nb; i++) { - ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size); - for (int j = 0; j < 8; j++) { - d_m[j * 2 + 0] = x[i * 8 + j].d; - d_m[j * 2 + 1] = x[i * 8 + j].m; - } - } -} - -static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes - const int qblk_size = qk / 2; // int4 = 128 bytes - const int qrow_size = k / 2; // int4 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales/offsets - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; - bool partial = (nloe && i == nb-1); - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0x0F; - qs[j*2+1] = q[j] >> 4; - } else { - qs[j+000] = q[j] & 0x0F; - qs[j+128] = q[j] >> 4; - } - } - - pack_q4_1_quants(&x[i * 8 + 0], qs, 0); - pack_q4_1_quants(&x[i * 8 + 1], qs, 1); - pack_q4_1_quants(&x[i * 8 + 2], qs, 2); - pack_q4_1_quants(&x[i * 8 + 3], qs, 3); - pack_q4_1_quants(&x[i * 8 + 4], qs, 4); - pack_q4_1_quants(&x[i * 8 + 5], qs, 5); - pack_q4_1_quants(&x[i * 8 + 6], qs, 6); - pack_q4_1_quants(&x[i * 8 + 7], qs, 7); - } - - // Unpack the scales and offsets - for (int i = 0; i < nb; i++) { - const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size); - for (int j = 0; j < 8; j++) { - x[i * 8 + j].d = d_m[j * 2 + 0]; - x[i * 8 + j].m = d_m[j * 2 + 1]; - } - } -} - -static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - memset(qs, 0, 
sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q4_1_quants(&x[i * 8 + 0], qs, 0); - pack_q4_1_quants(&x[i * 8 + 1], qs, 1); - pack_q4_1_quants(&x[i * 8 + 2], qs, 2); - pack_q4_1_quants(&x[i * 8 + 3], qs, 3); - pack_q4_1_quants(&x[i * 8 + 4], qs, 4); - pack_q4_1_quants(&x[i * 8 + 5], qs, 5); - pack_q4_1_quants(&x[i * 8 + 6], qs, 6); - pack_q4_1_quants(&x[i * 8 + 7], qs, 7); - } - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < 8; j++) { - x[i * 8 + j].d = 0; - x[i * 8 + j].m = 0; - } - } -} - -static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); - - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) 
t->data + (i * row_size); - - init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); - memcpy(buf_pd, src, n_rem_bytes); - repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros - - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_rp, src, row_size); - unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); - memcpy(dst, buf_pd, row_size); - } - - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. 
- memcpy(buf_rp, src, row_size); - unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); - memcpy(dst, buf_pd, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// ======== Q8x4x2 ==================== -static void dump_block_q8_0(const block_q8_0 * b, int i) { - HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], - b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d)); -} - -static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_Q8_0x4x2; - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_d = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); - - HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127], - GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3])); - - HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... 
%d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255], - GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); -} - -static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) { - static const int qk = QK8_0; - - for (unsigned int i = 0; i < qk; ++i) { - qs[bi * qk + i] = x->qs[i]; - } -} - -static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) { - static const int qk = QK8_0; - - for (unsigned int i = 0; i < qk; ++i) { - x->qs[i] = qs[bi * qk + i]; - } -} - -static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q8_0(&x[i * 8 + 0], 0); - dump_block_q8_0(&x[i * 8 + 1], 1); - dump_block_q8_0(&x[i * 8 + 2], 2); - dump_block_q8_0(&x[i * 8 + 3], 3); - dump_block_q8_0(&x[i * 8 + 4], 4); - dump_block_q8_0(&x[i * 8 + 5], 5); - dump_block_q8_0(&x[i * 8 + 6], 6); - dump_block_q8_0(&x[i * 8 + 7], 7); - } - } - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q8_0x4x2]; // unpacked quants - - unpack_q8_0_quants(qs, &x[i * 8 + 0], 0); - unpack_q8_0_quants(qs, &x[i * 8 + 1], 1); - unpack_q8_0_quants(qs, &x[i * 8 + 2], 2); - unpack_q8_0_quants(qs, &x[i * 8 + 3], 3); - unpack_q8_0_quants(qs, &x[i * 8 + 4], 4); - unpack_q8_0_quants(qs, &x[i * 8 + 5], 5); - unpack_q8_0_quants(qs, &x[i * 8 + 6], 6); - unpack_q8_0_quants(qs, &x[i * 8 + 7], 7); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk; j++) { - q[j] = qs[j]; - } - } - - // Repack 
the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - ggml_half * d = (ggml_half *) (y_d + i * dblk_size); - d[0] = x[i * 8 + 0].d; - d[1] = x[i * 8 + 1].d; - d[2] = x[i * 8 + 2].d; - d[3] = x[i * 8 + 3].d; - d[4] = x[i * 8 + 4].d; - d[5] = x[i * 8 + 5].d; - d[6] = x[i * 8 + 6].d; - d[7] = x[i * 8 + 7].d; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q8x4x2(y, i, k); - } - } -} - -static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q8x4x2(y, i, k); - } - } - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk; j++) { - qs[j] = q[j]; - } - - pack_q8_0_quants(&x[i * 8 + 0], qs, 0); - pack_q8_0_quants(&x[i * 8 + 1], qs, 1); - pack_q8_0_quants(&x[i * 8 + 2], qs, 2); - pack_q8_0_quants(&x[i * 8 + 3], qs, 3); - pack_q8_0_quants(&x[i * 8 + 4], qs, 4); - pack_q8_0_quants(&x[i * 8 + 5], qs, 5); - pack_q8_0_quants(&x[i * 8 + 6], qs, 6); - pack_q8_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. 
- for (int i = 0; i < nb; i++) { - // Unpack the scales - const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); - x[i * 8 + 0].d = d[0]; - x[i * 8 + 1].d = d[1]; - x[i * 8 + 2].d = d[2]; - x[i * 8 + 3].d = d[3]; - x[i * 8 + 4].d = d[4]; - x[i * 8 + 5].d = d[5]; - x[i * 8 + 6].d = d[6]; - x[i * 8 + 7].d = d[7]; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q8_0(&x[i * 8 + 0], 0); - dump_block_q8_0(&x[i * 8 + 1], 1); - dump_block_q8_0(&x[i * 8 + 2], 2); - dump_block_q8_0(&x[i * 8 + 3], 3); - dump_block_q8_0(&x[i * 8 + 4], 4); - dump_block_q8_0(&x[i * 8 + 5], 5); - dump_block_q8_0(&x[i * 8 + 6], 6); - dump_block_q8_0(&x[i * 8 + 7], 7); - } - } -} - -static void init_row_q8x4x2(block_q8_0 * x, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - // Init the quants such that they unpack into zeros - uint8_t qs[QK_Q8_0x4x2]; // unpacked quants - memset(qs, 0, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q8_0_quants(&x[i * 8 + 0], qs, 0); - pack_q8_0_quants(&x[i * 8 + 1], qs, 1); - pack_q8_0_quants(&x[i * 8 + 2], qs, 2); - pack_q8_0_quants(&x[i * 8 + 3], qs, 3); - pack_q8_0_quants(&x[i * 8 + 4], qs, 4); - pack_q8_0_quants(&x[i * 8 + 5], qs, 5); - pack_q8_0_quants(&x[i * 8 + 6], qs, 6); - pack_q8_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2) - // the last block is truncated and overridden by the scales. 
- for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].d = 0; - x[i * 8 + 1].d = 0; - x[i * 8 + 2].d = 0; - x[i * 8 + 3].d = 0; - x[i * 8 + 4].d = 0; - x[i * 8 + 5].d = 0; - x[i * 8 + 6].d = 0; - x[i * 8 + 7].d = 0; - } -} - -// repack q8_0 data into q8x4x2 tensor -static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) - - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. 
Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. - memcpy(buf_pd, src, n_rem_bytes); - - // Repack the entire buffer - repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]); - - // Write only the corresponding remaining bytes to the destination tensor. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// repack q8x4x2 tensor into q8_0 data -static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) - - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. 
- const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_pd, src, row_size); - unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - - // But we only copy the remaining number of bytes to the destination. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// ======== MXFP4x4x2 ==================== -struct x2_mxfp4 { - int v[2]; -}; - -static x2_mxfp4 unpack_mxfp4(uint8_t v) { - x2_mxfp4 x; - x.v[0] = kvalues_mxfp4[(v & 0x0f)]; - x.v[1] = kvalues_mxfp4[(v >> 4)]; - return x; -} - -static void dump_block_mxfp4(const block_mxfp4 * b, int i) { - HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... 
%d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0], - unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0], - unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1], - unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e)); -} - -static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_MXFP4x4x2; - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_e = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size); - - HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0], - unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0], - unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0], - unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]), - GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3])); - - HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... 
%d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1], - unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1], - unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1], - unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]), - GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7])); -} - static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) { static const int qk = QK_MXFP4; for (unsigned int i = 0; i < qk / 2; ++i) { - const uint8_t x0 = (x->qs[i] & 0x0F); - const uint8_t x1 = (x->qs[i] >> 4); + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); qs[bi * qk + i + 0] = x0; qs[bi * qk + i + qk / 2] = x1; } } static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) { - static const int qk = QK4_0; + static const int qk = QK_MXFP4; for (unsigned int i = 0; i < qk / 2; ++i) { const uint8_t x0 = qs[bi * qk + i + 0]; @@ -1308,299 +435,419 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int } } -static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_e = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_mxfp4(&x[i * 8 + 0], 0); - dump_block_mxfp4(&x[i * 8 + 1], 1); - dump_block_mxfp4(&x[i * 8 + 2], 2); - dump_block_mxfp4(&x[i * 8 + 3], 3); - dump_block_mxfp4(&x[i * 8 + 4], 4); - dump_block_mxfp4(&x[i * 8 + 5], 5); - dump_block_mxfp4(&x[i * 8 + 6], 6); 
- dump_block_mxfp4(&x[i * 8 + 7], 7); - } - } - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants - - unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0); - unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1); - unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2); - unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3); - unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4); - unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5); - unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6); - unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7); - - bool partial = (nloe && i == nb-1); +// repack q4_0 data into q4_0_tiled tensor +static void repack_q4_0_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q4_0 * src_matrix = (const block_q4_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q4_0 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_q4_0_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 8, 32); + } + } - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? 
(qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - uint8_t * e = (uint8_t *) (y_e + i * eblk_size); - e[0] = x[i * 8 + 0].e; - e[1] = x[i * 8 + 1].e; - e[2] = x[i * 8 + 2].e; - e[3] = x[i * 8 + 3].e; - e[4] = x[i * 8 + 4].e; - e[5] = x[i * 8 + 5].e; - e[6] = x[i * 8 + 6].e; - e[7] = x[i * 8 + 7].e; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_mxfp4x4x2(y, i, k); + ggml_half * scale_dst = (ggml_half *)(tile_dst + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? 
src_expert[r * (ne0 / 32) + kt].d : 0; + } + } + } } } } -static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) +// repack q4_0_tiled tensor into q4_0 data +static void repack_tiled_q4_0(void * data, const ggml_tensor * t, size_t size) { + block_q4_0 * dst_matrix = (block_q4_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q4_0 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_e = y + qrow_size; // then scales + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_q4_0_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - 
dump_packed_block_mxfp4x4x2(y, i, k); + const ggml_half * scale_src = (const ggml_half *)(tile_src + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[row]; + } + } + } + } } } +} - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants +// repack q4_1 data into q4_1_tiled tensor +static void repack_q4_1_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q4_1 * src_matrix = (const block_q4_1 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q4_1 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_q4_1_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 0, 32); + } + } - bool partial = (nloe && i == nb-1); + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0xf; - qs[j*2+1] = q[j] >> 4; - 
} else { - qs[j+000] = q[j] & 0xf; - qs[j+128] = q[j] >> 4; + ggml_half * scale_dst = (ggml_half *)(tile_dst + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + scale_dst[2 * row + 0] = src_expert[r * (ne0 / 32) + kt].d; + scale_dst[2 * row + 1] = src_expert[r * (ne0 / 32) + kt].m; + } else { + scale_dst[2 * row + 0] = 0; + scale_dst[2 * row + 1] = 0; + } + } + } } } - - pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); - pack_mxfp4_quants(&x[i * 8 + 1], qs, 1); - pack_mxfp4_quants(&x[i * 8 + 2], qs, 2); - pack_mxfp4_quants(&x[i * 8 + 3], qs, 3); - pack_mxfp4_quants(&x[i * 8 + 4], qs, 4); - pack_mxfp4_quants(&x[i * 8 + 5], qs, 5); - pack_mxfp4_quants(&x[i * 8 + 6], qs, 6); - pack_mxfp4_quants(&x[i * 8 + 7], qs, 7); - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size); - x[i * 8 + 0].e = e[0]; - x[i * 8 + 1].e = e[1]; - x[i * 8 + 2].e = e[2]; - x[i * 8 + 3].e = e[3]; - x[i * 8 + 4].e = e[4]; - x[i * 8 + 5].e = e[5]; - x[i * 8 + 6].e = e[6]; - x[i * 8 + 7].e = e[7]; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_mxfp4(&x[i * 8 + 0], 0); - dump_block_mxfp4(&x[i * 8 + 1], 1); - dump_block_mxfp4(&x[i * 8 + 2], 2); - dump_block_mxfp4(&x[i * 8 + 3], 3); - dump_block_mxfp4(&x[i * 8 + 4], 4); - dump_block_mxfp4(&x[i * 8 + 5], 5); - dump_block_mxfp4(&x[i * 8 + 6], 6); - dump_block_mxfp4(&x[i * 8 + 7], 7); - } } } -static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - // Init the quants such that they unpack into zeros - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants - memset(qs, 0, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - 
pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); - pack_mxfp4_quants(&x[i * 8 + 1], qs, 1); - pack_mxfp4_quants(&x[i * 8 + 2], qs, 2); - pack_mxfp4_quants(&x[i * 8 + 3], qs, 3); - pack_mxfp4_quants(&x[i * 8 + 4], qs, 4); - pack_mxfp4_quants(&x[i * 8 + 5], qs, 5); - pack_mxfp4_quants(&x[i * 8 + 6], qs, 6); - pack_mxfp4_quants(&x[i * 8 + 7], qs, 7); - } - - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].e = 0; - x[i * 8 + 1].e = 0; - x[i * 8 + 2].e = 0; - x[i * 8 + 3].e = 0; - x[i * 8 + 4].e = 0; - x[i * 8 + 5].e = 0; - x[i * 8 + 6].e = 0; - x[i * 8 + 7].e = 0; - } -} - -// repack mxfp4 data into mxfp4x4x2 tensor -static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. 
- const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, - size, t->ne[0], nrows, row_size); - - init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. 
- memcpy(buf_pd, src, n_rem_bytes); +// repack q4_1_tiled tensor into q4_1 data +static void repack_tiled_q4_1(void * data, const ggml_tensor * t, size_t size) { + block_q4_1 * dst_matrix = (block_q4_1 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q4_1 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } - // Repack the entire buffer (partial data + zero padding). - repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_q4_1_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } - // Write only the corresponding remaining bytes to the destination tensor. 
- memcpy(dst, buf_rp, n_rem_bytes); + const ggml_half * scale_src = (const ggml_half *)(tile_src + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[2 * row]; + dst_expert[r * (ne0 / 32) + kt].m = scale_src[2 * row + 1]; + } + } + } + } + } } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); } -// repack mxfp4x4x2 tensor into mxfp4 data -static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. 
- const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; +// repack q8_0 data into q8_0_tiled tensor +static void repack_q8_0_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q8_0 * src_matrix = (const block_q8_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q8_0 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; + + for (int cp = 0; cp < 16; cp++) { + int col0 = cp * 2; + int col1 = col0 + 1; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + const block_q8_0 * b = (r < ne1 && kt < ne0 / 32) ? &src_expert[r * (ne0 / 32) + kt] : NULL; + tile_dst[cp * 64 + 2 * row + 0] = b ? b->qs[col0] : 0; + tile_dst[cp * 64 + 2 * row + 1] = b ? b->qs[col1] : 0; + } + } - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); + ggml_half * scale_dst = (ggml_half *)(tile_dst + 1024); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? 
src_expert[r * (ne0 / 32) + kt].d : 0; + } + } + } + } + } +} - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); +// repack q8_0_tiled tensor into q8_0 data +static void repack_tiled_q8_0(void * data, const ggml_tensor * t, size_t size) { + block_q8_0 * dst_matrix = (block_q8_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q8_0 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + for (int cp = 0; cp < 16; cp++) { + int col0 = cp * 2; + int col1 = col0 + 1; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + block_q8_0 & b = dst_expert[r * (ne0 / 32) + kt]; + b.qs[col0] = tile_src[cp * 64 + 2 * row + 0]; + b.qs[col1] = tile_src[cp * 64 + 2 * row + 1]; + } + } + } - HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, - size, t->ne[0], nrows, row_size); + const ggml_half * scale_src = (const ggml_half *)(tile_src + 1024); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[row]; + } + } + } + } + } + } +} - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros +// repack mxfp4 data into 
mxfp4_tiled tensor +static void repack_mxfp4_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_mxfp4 * src_matrix = (const block_mxfp4 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_mxfp4 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_mxfp4_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 0, 32); + } + } - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } - memcpy(buf_pd, src, row_size); - unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); + uint8_t * scale_dst = tile_dst + 512; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? src_expert[r * (ne0 / 32) + kt].e : 0; + } + } + } + } } +} - // 2. 
Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); +// repack mxfp4_tiled tensor into mxfp4 data +static void repack_tiled_mxfp4(void * data, const ggml_tensor * t, size_t size) { + block_mxfp4 * dst_matrix = (block_mxfp4 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_mxfp4 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } - // We still need to read and unpack the entire source row because the format is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_mxfp4_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } - // But we only copy the remaining number of bytes to the destination to respect the size limit. 
- memcpy(dst, buf_rp, n_rem_bytes); + const uint8_t * scale_src = tile_src + 512; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].e = scale_src[row]; + } + } + } + } + } } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); } static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -1617,32 +864,32 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, case GGML_TYPE_Q4_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4_0_q4x4x2(tensor, data, size); + repack_q4_0_tiled(tensor, data, size); break; case GGML_TYPE_Q4_1: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4_1_q4x4x2(tensor, data, size); + repack_q4_1_tiled(tensor, data, size); break; case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q8_0_q8x4x2(tensor, data, size); + repack_q8_0_tiled(tensor, data, size); break; case GGML_TYPE_IQ4_NL: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); // IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16]) - repack_q4_0_q4x4x2(tensor, data, size); + repack_q4_0_tiled(tensor, data, size); break; case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_mxfp4_mxfp4x4x2(tensor, data, size); + repack_mxfp4_tiled(tensor, data, size); break; default: @@ -1665,31 +912,31 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, case GGML_TYPE_Q4_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_0(data, tensor, size); + repack_tiled_q4_0(data, tensor, size); break; case GGML_TYPE_Q4_1: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_1(data, tensor, size); + 
repack_tiled_q4_1(data, tensor, size); break; case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q8x4x2_q8_0(data, tensor, size); + repack_tiled_q8_0(data, tensor, size); break; case GGML_TYPE_IQ4_NL: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_0(data, tensor, size); + repack_tiled_q4_0(data, tensor, size); break; case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_mxfp4x4x2_mxfp4(data, tensor, size); + repack_tiled_mxfp4(data, tensor, size); break; default: @@ -1767,12 +1014,19 @@ static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer } static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) { + if (t->type == GGML_TYPE_Q4_0 || t->type == GGML_TYPE_Q4_1 || t->type == GGML_TYPE_Q8_0 || t->type == GGML_TYPE_IQ4_NL || t->type == GGML_TYPE_MXFP4) { + int64_t ne0 = hex_round_up(t->ne[0], 32); + int64_t ne1 = hex_round_up(t->ne[1], 32); + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + return ggml_row_size(t->type, ne0) * ne1 * ne2 * ne3; + } return ggml_nbytes(t); } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return opt_mbuf; // typically 1GB per buffer - GGML_UNUSED(buffer_type); + auto * context = static_cast(buffer_type->context); + return context->sess->max_bufsize; } static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) { @@ -1803,6 +1057,17 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; +static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { + return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment; +} + +static inline bool 
ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) { + if (!opt_hostbuf) { + return ggml_backend_buffer_is_hexagon(b); + } + return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; +} + struct ggml_hexagon_opbatch { ggml_hexagon_session* sess; @@ -1883,14 +1148,25 @@ struct ggml_hexagon_opbatch { b_vmem += b.size; - HEX_VERBOSE("ggml-hex: add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); + HEX_VERBOSE("ggml-hex: %s add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", sess->c_name(), bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); return bi; } bool same_shape(const htp_tensor * h, const ggml_tensor * t) const { - return (h->ne[0] == t->ne[0]) && (h->ne[1] == t->ne[1]) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && - (h->nb[0] == t->nb[0]) && (h->nb[1] == t->nb[1]) && (h->nb[2] == t->nb[2]) && (h->nb[3] == t->nb[3]); + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + const bool is_repack = ggml_backend_buffer_is_hexagon_repack(t->buffer) && ggml_hexagon_is_repack_type(t->type); + if (is_repack) { + ne0 = hex_round_up(ne0, 32); + ne1 = hex_round_up(ne1, 32); + } + int64_t nb1 = is_repack ? ggml_row_size(t->type, ne0) : t->nb[1]; + int64_t nb2 = is_repack ? nb1 * ne1 : t->nb[2]; + int64_t nb3 = is_repack ? 
nb2 * t->ne[2] : t->nb[3]; + + return (h->ne[0] == ne0) && (h->ne[1] == ne1) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && + (h->nb[0] == t->nb[0]) && (h->nb[1] == nb1) && (h->nb[2] == nb2) && (h->nb[3] == nb3); } // add tensor and return its index @@ -1921,19 +1197,35 @@ struct ggml_hexagon_opbatch { htp_tensor &h = h_tens[ti]; h.bi = add_buffer(sbuf); h.data = t_offset; - h.size = t_size; h.type = t->type; - h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; - h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + + const bool is_repack = ggml_backend_buffer_is_hexagon_repack(t->buffer) && ggml_hexagon_is_repack_type(t->type); + if (is_repack) { + h.ne[0] = hex_round_up(t->ne[0], 32); + h.ne[1] = hex_round_up(t->ne[1], 32); + h.ne[2] = t->ne[2]; + h.ne[3] = t->ne[3]; + + h.nb[0] = t->nb[0]; + h.nb[1] = ggml_row_size(t->type, h.ne[0]); + h.nb[2] = h.nb[1] * h.ne[1]; + h.nb[3] = h.nb[2] * h.ne[2]; + h.size = h.nb[3] * h.ne[3]; + t_size = h.size; + } else { + h.size = t_size; + h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; + h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + } h.flags = 0; if (ggml_backend_buffer_get_usage(t->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { h.flags |= HTP_TENSOR_COMPUTE; } - HEX_VERBOSE("ggml-hex: add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", + HEX_VERBOSE("ggml-hex: %s add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", sess->c_name(), ti, t->name, h.bi, (void*) t->data, (size_t) t_offset, t_size, h.flags, - (size_t) t->ne[0], (size_t) t->ne[1], (size_t) t->ne[2], (size_t) t->ne[3]); + (size_t) h.ne[0], (size_t) h.ne[1], (size_t) h.ne[2], (size_t) h.ne[3]); return ti; } @@ -1962,7 +1254,9 @@ struct ggml_hexagon_opbatch { for (const auto * src : node.get_inputs()) { fit_tensor(src); } - fit_tensor(node.dst()); + for (const auto * 
output : node.get_outputs()) { + fit_tensor(output); + } if ((extra_bufs + n_bufs) > n_bufs_max) return false; if ((extra_tens + n_tens) > n_tens_max) return false; @@ -1981,7 +1275,8 @@ struct ggml_hexagon_opbatch { ops[n] = node; htp_op_desc &o = h_ops[n]; - memcpy(&o.params, &node.node->op_params, sizeof(node.node->op_params)); + memcpy(o.params, node.node->op_params, sizeof(node.node->op_params)); + memcpy(o.kernel_params, node.kernel_params, sizeof(o.kernel_params)); o.opcode = node.opcode; o.flags = 0; @@ -1989,13 +1284,17 @@ struct ggml_hexagon_opbatch { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(sess->c_name(), node, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), ops[n], o.flags); auto inputs = node.get_inputs(); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { - o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; + o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; + } + + auto outputs = node.get_outputs(); + for (unsigned int i=0; i < HTP_OP_MAX_OUTPUTS; i++) { + o.dst[i] = (i < outputs.size() && outputs[i]) ? 
add_tensor(outputs[i]) : 0xffff; } - o.dst = add_tensor(node.dst()); } }; @@ -2006,14 +1305,14 @@ struct ggml_hexagon_opqueue { using opvec = std::vector; - std::queue done; // completed batch ids - std::vector op_cache; // per batch op cache - std::vector start_usec; // per batch start time + std::queue done; // completed batch ids + std::vector op_cache; // per batch op cache + std::vector start_usec; // per batch start time ggml_hexagon_opqueue(ggml_hexagon_session *sess, size_t batch_size, size_t depth) { size_t n_bufs = HTP_OP_MAX_BUFS; size_t n_ops = batch_size; - size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + size_t n_tensors = n_ops * HTP_OP_MAX_OUTPUTS + n_ops * HTP_OP_MAX_INPUTS; size_t tr_size = 0; if (opt_profile == 3) { @@ -2200,7 +1499,7 @@ struct ggml_hexagon_opqueue { char evt_str[256] = ""; if (opt_profile == 3) { - sprintf(evt_str, " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]", + snprintf(evt_str, sizeof(evt_str), " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]", rsp.n_traces[0], rsp.n_traces[1], rsp.n_traces[2], rsp.n_traces[3], rsp.n_traces[4], rsp.n_traces[5], rsp.n_traces[6], rsp.n_traces[7], rsp.n_traces[8], rsp.n_traces[9], rsp.n_traces[10]); @@ -2224,6 +1523,7 @@ void ggml_hexagon_session::flush_pending(bool all) { // Read response packet from queue const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT; + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo); if (err == AEE_EEXPIRED) { continue; @@ -2404,6 +1704,31 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_handle = true; + // Query HW info and resolve session options + this->max_bufsize = opt_mbuf; + { + unsigned int hw_n_threads = 0; + unsigned int hw_n_hvx = 0; + unsigned int hw_n_hmx = 0; + unsigned long long hw_vtcm_size = 0; + int hw_err = htp_iface_hwinfo(this->handle, &hw_n_threads, &hw_n_hvx, &hw_n_hmx, &hw_vtcm_size); + if (hw_err == 0) { + this->n_threads = opt_nhvx > 0 ? 
(uint32_t)opt_nhvx : (uint32_t)hw_n_threads; + this->n_hvx = opt_nhvx > 0 ? (uint32_t)opt_nhvx : (uint32_t)hw_n_hvx; + this->n_hmx = (opt_nhmx != 0) ? (uint32_t)hw_n_hmx : 0; + this->vtcm_size = (uint64_t)hw_vtcm_size; + GGML_LOG_INFO("ggml-hex: %s hwinfo: threads %u, hvx %u, hmx %u, vtcm %llu MB\n", + this->c_name(), this->n_threads, this->n_hvx, this->n_hmx, + (unsigned long long)(this->vtcm_size / (1024 * 1024))); + } else { + GGML_LOG_WARN("ggml-hex: %s failed to query hwinfo (0x%x), using defaults\n", this->c_name(), hw_err); + this->n_threads = opt_nhvx > 0 ? (uint32_t)opt_nhvx : 8; + this->n_hvx = opt_nhvx > 0 ? (uint32_t)opt_nhvx : 8; + this->n_hmx = (opt_nhmx != 0) ? 1 : 0; + this->vtcm_size = 8 * 1024 * 1024; + } + } + // Enable FastRPC QoS mode { struct remote_rpc_control_latency l; @@ -2468,11 +1793,12 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { opt_vmem = ggml_hexagon_measure_max_vmem(this); GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem); } + this->max_vmem = opt_vmem; - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem); + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, this->max_vmem); // Start dspqueue/opbatch processing - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem); + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_nhmx, this->max_vmem); if (err != 0) { GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); @@ -2553,16 +1879,6 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) { // ** backend interface -static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { - return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment; -} - -static inline bool ggml_backend_buffer_is_hexagon_repack(const struct 
ggml_backend_buffer * b) { - if (!opt_hostbuf) { - return ggml_backend_buffer_is_hexagon(b); - } - return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; -} static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; @@ -2653,6 +1969,640 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses return true; } +static bool ggml_hexagon_matmul_is_hmx_eligible( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int ne01_padded, + bool is_matmul_id, + bool is_batched +) { + const int ne00 = src0->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int wtype = src0->type; + + // HMX weight tile requires N to be 32-aligned. + if (ne01_padded % 32 != 0) { + return false; + } + + // HMX supports F16, F32, and repack quantized types. + if (!ggml_hexagon_is_hmx_weight_type((ggml_type) wtype)) { + return false; + } + + // HMX paths require K aligned to 32. + if (ne00 % 32 != 0) { + return false; + } + + // Quantized HMX kernels only handle flat 2D matmul (or matmul_id wrapping flat 2D matmuls). + if (!is_matmul_id && is_batched && wtype != GGML_TYPE_F16) { + return false; + } + + // HMX assumes contiguous row-major layout. + if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { + return false; + } + + // M alignment: Use HMX when M > HTP_MM_HMX_MIN_NROWS + const int m = is_matmul_id ? 
ne12 : ne11; + if (m <= HTP_MM_HMX_MIN_NROWS) { + return false; + } + + return true; +} + +static bool ggml_hexagon_precompute_hmx_mm_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int wtype, + int ne00_padded, + int ne01_padded, + int ne02, + int ne11, + int ne12, + int ne11_padded, + bool is_matmul_id, + bool is_batched, + size_t vtcm_budget, + struct htp_mm_kernel_params * kparams +) { + const int aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + const bool pipeline = is_matmul_id ? false : htp_mm_hmx_pipeline(ne11); + const int n_threads = (int)sess->n_threads; + const int ne10 = src1->ne[0]; + + const bool is_batched_val = is_matmul_id ? false : is_batched; + const int group_size = (ne02 > 0 ? ne12 / ne02 : 1); + + size_t m_chunk = 0; + size_t n_chunk = 0; + size_t vtcm_size = 0; + bool use_grouped = false; + int act_threads_selected = 0; + + if (is_batched_val && wtype == GGML_TYPE_F16 && group_size > 1) { + // Try grouped path first + const bool use_dma_activation = (src1->nb[1]/sizeof(float) > (size_t)ne00_padded); + size_t best_mblocks = SIZE_MAX; + int best_act_threads = 0; + size_t best_m_chunk = 0; + size_t best_n_chunk = 0; + size_t best_vtcm_size = 0; + + int act_threads = n_threads; + while (act_threads >= 1) { + const size_t f32_scratch_size = use_dma_activation ? 
hex_align_up(act_threads * HTP_MM_DMA_ACT_MULTIPLIER * ne00_padded * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + size_t group_overhead = 256 + f32_scratch_size; + size_t group_size_per_n, group_size_per_m, group_size_per_mn; + htp_mm_hmx_get_batched_chunk_costs(ne00_padded, group_size, &group_size_per_n, &group_size_per_m, &group_size_per_mn); + + size_t m_chunk_candidate = 0; + size_t n_chunk_candidate = 0; + size_t vtcm_size_candidate = 0; + + if (htp_mm_hmx_compute_chunks(vtcm_budget, group_overhead, group_size_per_n, group_size_per_m, group_size_per_mn, hex_align_up(ne11, 32), ne01_padded, + (size_t) ne01_padded * HTP_MM_HMX_COST_W_DEQUANT, (size_t) ne11 * HTP_MM_HMX_COST_A_CONVERT, + &m_chunk_candidate, &n_chunk_candidate, &vtcm_size_candidate) == 0) { + size_t exact_size = htp_mm_hmx_get_batched_vtcm_size(wtype, ne00_padded, m_chunk_candidate, n_chunk_candidate, group_size, use_dma_activation, pipeline, act_threads); + if (exact_size <= vtcm_budget) { + size_t mblocks = ((size_t) ne11 + m_chunk_candidate - 1) / m_chunk_candidate; + if (mblocks < best_mblocks || (mblocks == best_mblocks && act_threads > best_act_threads)) { + best_mblocks = mblocks; + best_act_threads = act_threads; + best_m_chunk = m_chunk_candidate; + best_n_chunk = n_chunk_candidate; + best_vtcm_size = exact_size; + } + } + } + if (act_threads == 1) { + act_threads = 0; + } else { + act_threads /= 2; + } + } + + if (best_act_threads > 0) { + m_chunk = best_m_chunk; + n_chunk = best_n_chunk; + vtcm_size = best_vtcm_size; + act_threads_selected = best_act_threads; + use_grouped = true; + } + } + + if (!use_grouped) { + // Fallback to simple 2D path (group_size = 1) + size_t best_mblocks = SIZE_MAX; + int best_act_threads = 0; + size_t best_m_chunk = 0; + size_t best_n_chunk = 0; + size_t best_vtcm_size = 0; + + // For MUL_MAT_ID the kernel runs one 2D matmul per expert, with M equal to the number of rows routed to that expert. 
+ // A single expert can receive up to all routed rows (dst->ne[1]*dst->ne[2] = n_expert_used*n_tokens), so size the chunk + // search for that upper bound rather than ne12 (token positions only). + // We recompute m_chunk per expert against the actual count in the NPU kernel. + const int m_id_rows = (int) ((size_t) dst->ne[1] * dst->ne[2]); + const int m_for_chunks = is_matmul_id ? hex_align_up(m_id_rows, 32) : ne11_padded; + const int m_for_cost = is_matmul_id ? m_id_rows : ne11; + + int act_threads = n_threads; + while (act_threads >= 1) { + const size_t act_f32_size = is_matmul_id ? 0 : hex_align_up(act_threads * HTP_MM_DMA_ACT_MULTIPLIER * ne00_padded * sizeof(float), HTP_MM_HMX_TILE_SIZE); + size_t simple_2d_overhead = 256 + act_f32_size; + size_t simple_2d_size_per_n, simple_2d_size_per_m, simple_2d_size_per_mn; + htp_mm_hmx_get_2d_chunk_costs(wtype, ne00_padded, pipeline, aligned_tile_size, &simple_2d_size_per_n, &simple_2d_size_per_m, &simple_2d_size_per_mn); + + size_t m_chunk_candidate = 0; + size_t n_chunk_candidate = 0; + size_t vtcm_size_candidate = 0; + + if (htp_mm_hmx_compute_chunks(vtcm_budget, simple_2d_overhead, simple_2d_size_per_n, simple_2d_size_per_m, simple_2d_size_per_mn, m_for_chunks, ne01_padded, + (size_t) ne01_padded * HTP_MM_HMX_COST_W_DEQUANT, (size_t) m_for_cost * HTP_MM_HMX_COST_A_CONVERT, + &m_chunk_candidate, &n_chunk_candidate, &vtcm_size_candidate) == 0) { + size_t exact_size = htp_mm_hmx_get_2d_vtcm_size(wtype, ne00_padded, m_chunk_candidate, n_chunk_candidate, pipeline, is_matmul_id ? 
0 : act_threads, aligned_tile_size); + if (exact_size <= vtcm_budget) { + size_t mblocks = ((size_t) m_for_cost + m_chunk_candidate - 1) / m_chunk_candidate; + if (mblocks < best_mblocks || (mblocks == best_mblocks && act_threads > best_act_threads)) { + best_mblocks = mblocks; + best_act_threads = act_threads; + best_m_chunk = m_chunk_candidate; + best_n_chunk = n_chunk_candidate; + best_vtcm_size = exact_size; + } + } + } + if (act_threads == 1) { + act_threads = 0; + } else { + act_threads /= 2; + } + } + + if (best_act_threads > 0) { + m_chunk = best_m_chunk; + n_chunk = best_n_chunk; + vtcm_size = best_vtcm_size; + act_threads_selected = best_act_threads; + } else { + return false; + } + } + + kparams->n_hmx = 1; + kparams->pipeline = pipeline ? 1 : 0; + kparams->m_chunk = m_chunk; + kparams->n_chunk = n_chunk; + kparams->n_threads = n_threads; + kparams->n_act_threads = act_threads_selected; + kparams->tile_size = htp_mm_get_weight_tile_size(wtype); + kparams->aligned_tile_size = aligned_tile_size; + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? 
htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = 0; + kparams->vtcm_src1_size = 0; + kparams->vtcm_dst_size = 0; + + if (is_batched && !is_matmul_id) { + kparams->kernel_type = HTP_MM_KERNEL_HMX_F16_BATCHED; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HMX_2D; + } + return true; +} + +static void ggml_hexagon_precompute_hvx_mm_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int wtype, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + bool is_matmul_id, + size_t vtcm_budget, + struct htp_mm_kernel_params * kparams +) { + kparams->n_hmx = 0; + + const bool is_quant = (wtype != GGML_TYPE_F16 && wtype != GGML_TYPE_F32); + const int src1_nrows = ne11 * ne12 * ne13; + + if (is_quant) { + // Quantized HVX + kparams->tile_size = htp_mm_get_weight_tile_size(wtype); + kparams->aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + + const bool k_align = (ne10 % 32 == 0); + + if (is_matmul_id) { + kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0; + uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 
2 : 16; + uint32_t best_n_prefetch = 2; + size_t total_size = 0; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + total_size = htp_mm_hvx_id_get_vtcm_sizes( + wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d, + &vtcm_src0_size, &vtcm_src1_size + ); + if (total_size <= vtcm_budget) { + best_n_prefetch = d; + break; + } + } + if (best_n_prefetch == 2 && total_size > vtcm_budget) { + total_size = htp_mm_hvx_id_get_vtcm_sizes( + wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2, + &vtcm_src0_size, &vtcm_src1_size + ); + } + kparams->n_prefetch = best_n_prefetch; + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = 0; + } else { + bool try_tiled = (k_align && opt_mm_select >= 2); + if (try_tiled) { + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + if (src1_nrows < (int)sess->n_threads) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_BLOCK; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + } + + uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 
2 : 16; + uint32_t best_n_prefetch = 2; + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t total_size = 0; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], d, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + if (total_size <= vtcm_budget) { + best_n_prefetch = d; + break; + } + } + if (best_n_prefetch == 2 && total_size > vtcm_budget) { + total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 2, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + } + + kparams->n_prefetch = best_n_prefetch; + + if (total_size <= vtcm_budget) { + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + goto done_quant; + } + HEX_VERBOSE("ggml-hex: %s HVX tiled path VTCM size needed (%zu) > budget (%zu), falling back to HVX flat\n", sess->name.c_str(), total_size, vtcm_budget); + } + + // Flat HVX fallback + { + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? 
htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + kparams->n_prefetch = 16; + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + } + } + + done_quant:; + } else if (wtype == GGML_TYPE_F16) { + // F16 HVX + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = ggml_is_permuted(src0) || ggml_is_permuted(src1); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t vtcm_size = htp_mm_hvx_get_vtcm_sizes( + HTP_MM_KERNEL_HVX_F16_F16_VTCM, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + if (!is_batched && !is_permuted && vtcm_size <= vtcm_budget) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F16_VTCM; + kparams->src1_row_size = hex_round_up(ne10 * 2, 128); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } else { + if (src1->type == GGML_TYPE_F32) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F32_DDR; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F16_DDR; + } + kparams->src1_row_size = src1->nb[1]; + size_t ddr_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + kparams->vtcm_size = ddr_size; + kparams->vtcm_src0_size = vtcm_src0_size; + 
kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } + } else { + // F32 HVX + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = ggml_is_permuted(src0) || ggml_is_permuted(src1); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t vtcm_size = htp_mm_hvx_get_vtcm_sizes( + HTP_MM_KERNEL_HVX_F32_F32_VTCM, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + if (!is_batched && !is_permuted && vtcm_size <= vtcm_budget) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F32_F32_VTCM; + kparams->src1_row_size = hex_round_up(ne10 * 4, 128); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F32_F32_DDR; + kparams->src1_row_size = src1->nb[1]; + size_t ddr_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + kparams->vtcm_size = ddr_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } + } +} + +static void ggml_hexagon_precompute_matmul_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; 
+ + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + const int ne00_padded = is_repack ? hex_round_up(ne00, 32) : ne00; + const int ne01_padded = is_repack ? hex_round_up(ne01, 32) : ne01; + const int ne11_padded = hex_round_up(ne11, 32); + + const bool is_matmul_id = (dst->op == GGML_OP_MUL_MAT_ID); + const bool is_batched = (ne02 * ne03 > 1 || ne12 * ne13 > 1); + + const size_t vtcm_budget = sess->vtcm_size; + + // Check HMX eligibility and try precomputing HMX parameters + bool hmx_enabled = (sess->n_hmx > 0) && (opt_mm_select >= 3); + if (hmx_enabled && ggml_hexagon_matmul_is_hmx_eligible(src0, src1, dst, ne01_padded, is_matmul_id, is_batched)) { + if (ggml_hexagon_precompute_hmx_mm_params(sess, src0, src1, dst, wtype, ne00_padded, ne01_padded, ne02, ne11, ne12, ne11_padded, is_matmul_id, is_batched, vtcm_budget, kparams)) { + goto finalize; + } + } + + // Fallback to HVX parameter computation + ggml_hexagon_precompute_hvx_mm_params(sess, src0, src1, dst, wtype, ne02, ne03, ne10, ne11, ne12, ne13, is_matmul_id, vtcm_budget, kparams); + +finalize: + kparams->div_ne12_ne1 = init_fastdiv_values(ne12 * ne11); + kparams->div_ne1 = init_fastdiv_values(ne11); + kparams->div_r2 = init_fastdiv_values(ne02 > 0 ? ne12 / ne02 : 1); + kparams->div_r3 = init_fastdiv_values(ne03 > 0 ? ne13 / ne03 : 1); + kparams->div_ne11 = init_fastdiv_values(ne11); +} + +static void ggml_hexagon_precompute_fused_qkv_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, // Wk + const struct ggml_tensor * src1, // x + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + + const int ne10 = src1->ne[0]; + const int src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + const size_t src1_row_size = (wtype == GGML_TYPE_Q4_1) ? 
htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + size_t src0_sz_per_thread = 0; + size_t src2_sz_per_thread = 0; + size_t src3_sz_per_thread = 0; + uint32_t best_n_prefetch = 16; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + size_t src1_sz = src1_sz_per_thread; + + const uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + best_n_prefetch = 2; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + size_t src0_sz = repacked_vtcm_size * sess->n_threads; + size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz; + + if (tiled_vtcm_size <= sess->vtcm_size) { + best_n_prefetch = d; + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(d * tile_row_size, 128); + src3_sz_per_thread = hex_round_up(d * tile_row_size, 128); + break; + } + } + if (best_n_prefetch == 2 && src0_sz_per_thread == 0) { + size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + } + } else { + 
best_n_prefetch = 16; + src0_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src2_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src3_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + } + + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + + size_t src0_sz = src0_sz_per_thread * sess->n_threads; + size_t src1_sz = src1_sz_per_thread; + size_t src2_sz = src2_sz_per_thread * sess->n_threads; + size_t src3_sz = src3_sz_per_thread * sess->n_threads; + + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz; + bool try_tiled = (opt_mm_select >= 2); + if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_src3_size = src3_sz; + kparams->vtcm_size = tiled_vtcm_size; + kparams->n_prefetch = best_n_prefetch; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + size_t flat_src1_row_size = (wtype == GGML_TYPE_Q4_1) ? 
htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + size_t flat_src1_sz = hex_round_up(flat_src1_row_size * src1_nrows, 128); + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = flat_src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_src3_size = src3_sz; + kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz; + kparams->n_prefetch = best_n_prefetch; + } +} + +static void ggml_hexagon_precompute_fused_ffn_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, // Wgate + const struct ggml_tensor * src1, // y + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + + const int ne10 = src1->ne[0]; + const int src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + const size_t src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + size_t src0_sz_per_thread = 0; + size_t src2_sz_per_thread = 0; + uint32_t best_n_prefetch = 16; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + size_t src1_sz = src1_sz_per_thread; + + const uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 
2 : 16; + best_n_prefetch = 2; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + size_t src0_sz = repacked_vtcm_size * sess->n_threads; + size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz; + + if (tiled_vtcm_size <= sess->vtcm_size) { + best_n_prefetch = d; + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(d * tile_row_size, 128); + break; + } + } + if (best_n_prefetch == 2 && src0_sz_per_thread == 0) { + size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + } + } else { + best_n_prefetch = 16; + src0_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src2_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + } + + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + + size_t src0_sz = src0_sz_per_thread * sess->n_threads; + size_t src1_sz = src1_sz_per_thread; + size_t src2_sz = src2_sz_per_thread * sess->n_threads; + + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz; + bool try_tiled = (opt_mm_select >= 2); + if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_size = tiled_vtcm_size; + kparams->n_prefetch = best_n_prefetch; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + size_t flat_src1_row_size = (wtype == GGML_TYPE_Q4_1) ? 
htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + size_t flat_src1_sz = hex_round_up(flat_src1_row_size * src1_nrows, 128); + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = flat_src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz; + kparams->n_prefetch = best_n_prefetch; + } +} + static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -2675,12 +2625,13 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; } - if (ggml_nrows(src0) > 16 * 1024) { - return false; // typically the lm-head which would be too large for VTCM + // hardcoded limit to refuse the lm-head for now + if (src0->ne[1] > 32768) { + return false; } - if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) { - return false; // no huge batches or broadcasting (for now) + if (src1->ne[2] != 1 || src1->ne[3] != 1) { + return false; // no broadcasting (for now) } // src0 (weights) must be repacked @@ -2691,16 +2642,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s case GGML_TYPE_F16: if (src0->nb[1] < src0->nb[0]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); return false; } - if (ggml_nrows(src1) > 1024) { - return false; // no huge batches (for now) - } break; case GGML_TYPE_F32: @@ -2708,22 +2654,24 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; } if (src0->nb[1] < src0->nb[0]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); return false; } if (src1->ne[2] < src0->ne[2] || 
src1->ne[3] < src0->ne[3]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); return false; } - if (ggml_nrows(src1) > 1024) { - return false; // no huge batches (for now) - } break; default: return false; } + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_matmul_params(sess, src0, src1, dst, &kparams); + if ((size_t)kparams.vtcm_size > sess->vtcm_size) { + HEX_VERBOSE("ggml-hex: %s supported MUL_MAT VTCM size needed (%d) > budget (%zu)\n", sess->c_name(), kparams.vtcm_size, sess->vtcm_size); + return false; + } + return true; } @@ -2757,6 +2705,13 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session return false; } + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_matmul_params(sess, src0, src1, dst, &kparams); + if ((size_t)kparams.vtcm_size > sess->vtcm_size) { + HEX_VERBOSE("ggml-hex: %s supported MUL_MAT_ID VTCM size needed (%d) > budget (%zu)\n", sess->c_name(), kparams.vtcm_size, sess->vtcm_size); + return false; + } + return true; } @@ -3288,47 +3243,172 @@ static inline bool op_is_compute(ggml_tensor *node) return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } +static bool is_hmx_eligible(const ggml_tensor * t) { + if (opt_nhmx == 0) { return false; } + + const ggml_tensor * src0 = t->src[0]; + const ggml_tensor * src1 = t->src[1]; + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + const bool is_matmul_id = (t->op == GGML_OP_MUL_MAT_ID); + const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); + + const int ne01_padded = is_repack ? 
hex_round_up(src0->ne[1], 32) : src0->ne[1]; + + return ggml_hexagon_matmul_is_hmx_eligible(src0, src1, t, ne01_padded, is_matmul_id, is_batched); +} + +static bool is_mergeable_mul_mat(const ggml_tensor * t) { + if (!t || t->op != GGML_OP_MUL_MAT) return false; + if (t->src[1]->type != GGML_TYPE_F32) return false; + return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t); +} + +static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) { + if (!is_mergeable_mul_mat(n1) || !is_mergeable_mul_mat(n2)) { + return false; + } + if (n1->src[1] != n2->src[1]) { + return false; + } + if (n1->src[0]->ne[0] != n2->src[0]->ne[0] || + n1->src[0]->ne[1] != n2->src[0]->ne[1]) { + return false; + } + if (n1->src[0]->type != n2->src[0]->type) { + return false; + } + return true; +} + +static bool is_qkv_mergeable(const ggml_tensor * n_q, const ggml_tensor * n_k, const ggml_tensor * n_v) { + if (!is_mergeable_mul_mat(n_q) || !is_mergeable_mul_mat(n_k) || !is_mergeable_mul_mat(n_v)) { + return false; + } + if (n_q->src[1] != n_k->src[1] || n_q->src[1] != n_v->src[1]) { + return false; + } + if (n_q->src[0]->type != n_k->src[0]->type || n_q->src[0]->type != n_v->src[0]->type) { + return false; + } + if (n_k->src[0]->ne[0] != n_v->src[0]->ne[0] || + n_k->src[0]->ne[1] != n_v->src[0]->ne[1]) { + return false; + } + if (n_q->src[0]->ne[0] != n_k->src[0]->ne[0]) { + return false; + } + return true; +} + +static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph * graph, int & i, std::vector & nodes) { + if (!opt_opfusion) { + return false; + } + + ggml_tensor * n = graph->nodes[i]; + ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? 
graph->nodes[i + 1] : nullptr; + + if (n->op == GGML_OP_RMS_NORM && next_node) { + if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + htp_opnode node(n, {}, HTP_OP_RMS_NORM_MUL); + node.add_fused(next_node); + nodes.push_back(std::move(node)); + i++; // skip the fused MUL node + return true; + } + } + + if (is_mergeable_mul_mat(n)) { + ggml_tensor * n1 = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + ggml_tensor * n2 = (i + 2 < graph->n_nodes) ? graph->nodes[i + 2] : nullptr; + if (is_qkv_mergeable(n, n1, n2)) { + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_fused_qkv_params(sess, n1->src[0], n1->src[1], &kparams); + if ((size_t)kparams.vtcm_size <= sess->vtcm_size) { + // Reorder to KVQ: K (n1), V (n2), Q (n) + htp_opnode node(n1, {}, HTP_OP_MUL_MAT_QKV); + node.add_fused(n2, true); + node.add_fused(n, true); + memcpy(node.kernel_params, &kparams, sizeof(kparams)); + nodes.push_back(std::move(node)); + i += 2; + return true; + } else { + HEX_VERBOSE("ggml-hex: skip QKV fusion because VTCM needed (%d) > budget (%zu)\n", + kparams.vtcm_size, sess->vtcm_size); + } + } + if (is_mergeable_mul_mat_pair(n, n1)) { + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_fused_ffn_params(sess, n->src[0], n->src[1], &kparams); + if ((size_t)kparams.vtcm_size <= sess->vtcm_size) { + htp_opnode node(n, {}, HTP_OP_MUL_MAT_FFN); + node.add_fused(n1, true); + memcpy(node.kernel_params, &kparams, sizeof(kparams)); + nodes.push_back(std::move(node)); + i += 1; + return true; + } else { + HEX_VERBOSE("ggml-hex: skip FFN fusion because VTCM needed (%d) > budget (%zu)\n", + kparams.vtcm_size, sess->vtcm_size); + } + } + } + + return false; +} + static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { auto sess = static_cast(backend->context); HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), 
graph->n_nodes); - std::vector nodes; - nodes.reserve(graph->n_nodes); - - // Fusion - for (int i = 0; i < graph->n_nodes; ++i) { - ggml_tensor * n = graph->nodes[i]; - if (!op_is_compute(n)) { - continue; - } + const std::vector * nodes_ptr = nullptr; + std::vector computed_nodes; - ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + // Check for cache hit + bool cache_hit = (graph->uid != 0 && sess->cached_graph.uid == graph->uid); + if (cache_hit) { + nodes_ptr = &sess->cached_graph.htp_nodes; + } else { + computed_nodes.reserve(graph->n_nodes); - htp_opnode node = { - /*.node =*/ n, - /*.fused =*/ {}, - /*.opcode =*/ HTP_OP_INVALID - }; + // Fuse and finalize + for (int i = 0; i < graph->n_nodes; ++i) { + ggml_tensor * n = graph->nodes[i]; + if (!op_is_compute(n)) { + continue; + } - if (n->op == GGML_OP_RMS_NORM && next_node) { - if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { - node.add_fused(next_node); - node.opcode = HTP_OP_RMS_NORM_MUL; - i++; // skip the fused MUL node + if (try_fuse_node(sess, graph, i, computed_nodes)) { + continue; } - } - if (node.opcode == HTP_OP_INVALID) { + htp_opnode node(n, {}, HTP_OP_INVALID); node.opcode = op_remap_to_htp(n); + if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID) { + ggml_hexagon_precompute_matmul_params(sess, + node.node->src[0], node.node->src[1], node.node, + (struct htp_mm_kernel_params *)node.kernel_params + ); + } + computed_nodes.push_back(std::move(node)); } - nodes.push_back(std::move(node)); + if (graph->uid != 0) { + sess->cached_graph.uid = graph->uid; + sess->cached_graph.htp_nodes = std::move(computed_nodes); + nodes_ptr = &sess->cached_graph.htp_nodes; + } else { + nodes_ptr = &computed_nodes; + } } // Queue and execute if (opt_opstage & HTP_OPSTAGE_QUEUE) { - for (const auto & node : nodes) { + for (const auto & node : *nodes_ptr) { sess->enqueue_op(node); } } @@ 
-3991,16 +4071,19 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL"); - const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE"); + const char * str_opfusion = getenv("GGML_HEXAGON_OPFUSION"); const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); + const char * str_nhmx = getenv("GGML_HEXAGON_NHMX"); + const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT"); const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_arch = getenv("GGML_HEXAGON_ARCH"); const char * str_vmem = getenv("GGML_HEXAGON_VMEM"); const char * str_mbuf = getenv("GGML_HEXAGON_MBUF"); + const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE"); // Init Arch first since it affects other defaults if (!str_arch) { @@ -4029,12 +4112,14 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; opt_optrace = str_optrace ? strtoul(str_optrace, NULL, 0) : (opt_opbatch * 128); + opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; + opt_opfusion = str_opfusion ? atoi(str_opfusion) : opt_opfusion; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? 
atoi(str_use_hmx) : opt_nhmx); + opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf; diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 52c727c6206..6fe23b0d6aa 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -5,10 +5,12 @@ #include "ggml-backend-impl.h" #include "ggml-common.h" +#include #include #include #include #include "htp-ops.h" +#include "htp/matmul-ops.h" struct htp_opnode { ggml_tensor * node = nullptr; @@ -17,6 +19,13 @@ struct htp_opnode { htp_op_code opcode = HTP_OP_INVALID; + std::vector extra_dsts; + + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0}; + + htp_opnode(ggml_tensor * node = nullptr, std::vector fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector extra_dsts = {}) + : node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {} + ggml_op op() const { return node->op; } @@ -25,6 +34,26 @@ struct htp_opnode { return fused.empty() ? 
node : fused.back(); } + void add_fused(ggml_tensor * t, bool extra_dst = false) { + fused.push_back(t); + if (extra_dst) { + extra_dsts.push_back(t); + } + } + + std::vector get_outputs() const { + std::vector res; + if (extra_dsts.empty()) { + res.push_back(dst()); + } else { + res.push_back(node); + for (const auto * x : extra_dsts) { + res.push_back(x); + } + } + return res; + } + const ggml_tensor * src0() const { return node->src[0]; } @@ -37,10 +66,6 @@ struct htp_opnode { return ggml_op_is_empty(node->op); } - void add_fused(ggml_tensor * t) { - fused.push_back(t); - } - bool stackable() const { switch (this->op()) { case GGML_OP_MUL_MAT: @@ -131,87 +156,117 @@ struct htp_opformat { char types[16 * GGML_MAX_SRC]; char buffs[64 * GGML_MAX_SRC]; char names[64 * GGML_MAX_SRC]; + char kparams[128]; - int format_tensor_dims(char * str, const struct ggml_tensor * t) { + int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) { if (!t) { - return sprintf(str, "NONE"); + return snprintf(str, max_size, "NONE"); } if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]); } else { - return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); } } - void format_op_dims(char * str, const htp_opnode & node) { + void format_op_dims(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += format_tensor_dims(p, inputs[0]); + p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p)); for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += format_tensor_dims(p, inputs[i]); + if (p < p_end) { + p += 
std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } char self[64]; - format_tensor_dims(self, node.dst()); - p += sprintf(p, "%s", self); + format_tensor_dims(self, sizeof(self), node.dst()); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p)); + } } - int format_tensor_strides(char * str, const struct ggml_tensor * t) { + int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) { if (!t) { - return sprintf(str, "NONE"); + return snprintf(str, max_size, "NONE"); } const char * c = ggml_is_contiguous(t) ? "" : "!"; if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); } else { - return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); + return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); } } - void format_op_strides(char * str, const htp_opnode & node) { + void format_op_strides(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += format_tensor_strides(p, inputs[0]); + p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p)); for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += format_tensor_strides(p, inputs[i]); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += 
std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } char self[64]; - format_tensor_strides(self, node.dst()); - p += sprintf(p, "%s", self); + format_tensor_strides(self, sizeof(self), node.dst()); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p)); + } } - void format_op_types(char * str, const htp_opnode & node) { + void format_op_types(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p)); + } for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? 
ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", ggml_type_name(node.dst()->type)); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p)); + } } const char * tensor_buff_name(const struct ggml_tensor * t) { @@ -221,51 +276,102 @@ struct htp_opformat { return "NONE"; } - void format_op_buffs(char * str, const htp_opnode & node) { + void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", tensor_buff_name(inputs[0])); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p)); + } for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", tensor_buff_name(inputs[i])); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", tensor_buff_name(node.dst())); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p)); + } } - void format_op_names(char * str, const htp_opnode & node) { + void format_op_names(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0] ? 
inputs[0]->name : "NONE"); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p)); + } for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", node.dst()->name); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p)); + } + } + void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) { + if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID || + node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) { + const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params; + const char * path = "unknown"; + int32_t type = kparams->kernel_type; + if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) { + path = "hmx-tiled"; + } else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM || + type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) { + path = "hvx-tiled"; + } else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR || + type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR || + type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + path = "hvx-flat"; + } + snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size); + } else { + snprintf(str, max_size, "----"); + } } void format(const htp_opnode & node) { - format_op_dims(dims, node); - format_op_strides(strides, 
node); - format_op_types(types, node); - format_op_buffs(buffs, node); - format_op_names(names, node); + format_op_dims(dims, sizeof(dims), node); + format_op_strides(strides, sizeof(strides), node); + format_op_types(types, sizeof(types), node); + format_op_buffs(buffs, sizeof(buffs), node); + format_op_names(names, sizeof(names), node); + format_kernel_params(kparams, sizeof(kparams), node); } - htp_opformat() {} + htp_opformat() { + strides[0] = '\0'; + dims[0] = '\0'; + types[0] = '\0'; + buffs[0] = '\0'; + names[0] = '\0'; + kparams[0] = '\0'; + } htp_opformat(const htp_opnode & node) { format(node); } }; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 31ba5276231..c48a5b86e3b 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -19,43 +19,9 @@ add_library(${HTP_LIB} SHARED htp_iface_skel.c worker-pool.c hex-dma.c -) - -target_compile_definitions(${HTP_LIB} PRIVATE - $,HTP_DEBUG=1,NDEBUG=1> - $,FARF_HIGH=1,> - FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -if (GGML_HEXAGON_FA_EXP2_HF) - message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") - target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) -endif() - -# HMX acceleration: available on v73+ architectures -set(HTP_HMX_VERSIONS v73 v75 v79 v81) -list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) - -if (_hmx_idx GREATER_EQUAL 0) - target_sources(${HTP_LIB} PRIVATE - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - ) - - # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) - set_source_files_properties( - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - PROPERTIES COMPILE_OPTIONS "-mhmx" - ) - - target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) -endif() - -build_idl(htp_iface.idl ${HTP_LIB}) - -target_sources(${HTP_LIB} PRIVATE + hmx-queue.c + flash-attn-ops.c + hmx-flash-attn-ops.c 
matmul-ops.c binary-ops.c unary-ops.c @@ -63,7 +29,6 @@ target_sources(${HTP_LIB} PRIVATE softmax-ops.c act-ops.c rope-ops.c - flash-attn-ops.c set-rows-ops.c get-rows-ops.c cpy-ops.c @@ -79,6 +44,17 @@ target_sources(${HTP_LIB} PRIVATE pad-ops.c ) +target_compile_definitions(${HTP_LIB} PRIVATE + $,HTP_DEBUG=1,NDEBUG=1> + $,FARF_HIGH=1,>) + +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + +build_idl(htp_iface.idl ${HTP_LIB}) + set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake index ed5c198468c..3eff2a3986e 100644 --- a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED) endif() set(HEXAGON_TOOLCHAIN_INCLUDED true) -#Cross Compiling for Hexagon +# Cross Compiling for Hexagon set(HEXAGON TRUE) set(CMAKE_SYSTEM_NAME QURT) set(CMAKE_SYSTEM_PROCESSOR Hexagon) @@ -14,7 +14,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) set(CUSTOM_RUNELF_PATH "") -#To fix backward compatibility with EAI addon. 
if (NOT HEXAGON_SDK_ROOT) set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT}) endif() @@ -31,7 +30,6 @@ endif() file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) -#Get the Binary extension of the Hexagon Toolchain if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) set(HEXAGON_TOOLCHAIN_SUFFIX .exe) endif() @@ -48,12 +46,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES HEXAGON_TOOLS_ROOT ) -#QURT Related includes and linker flags +# QURT Related includes and linker flags set(V_ARCH ${HEXAGON_ARCH}) set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}") set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}") -if( ${TREE} MATCHES PAKMAN ) +if (${TREE} MATCHES PAKMAN) set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}") endif() message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}") @@ -83,11 +81,9 @@ set(QURT_START_LINK_LIBS ) STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}") -set(QURT_END_LINK_LIBS - ${TARGET_DIR}/fini.o - ) +set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o) -#Non QURT related includes and linker flags +# Non QURT related includes and linker flags set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}") @@ -99,8 +95,10 @@ if (NOT NO_WRAP_MEM_API) set(WRAP_MEMALIGN -Wl,--wrap=memalign) endif() +set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx") + set(PIC_SHARED_LD_FLAGS - -mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} + ${ARCH_FLAGS} -G0 -fpic -Wl,-Bsymbolic @@ -120,13 +118,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}") set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}") -#System include paths +# System include paths include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs) include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef) include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs) -#LLVM toolchain setup 
-#Compiler paths, options and architecture +# LLVM toolchain setup +# Compiler paths, options and architecture set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX}) set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX}) set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX}) @@ -137,8 +135,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon) set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") -#Compiler Options -set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") +# Compiler Options +set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index b7511cd6442..65f7844ae33 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -18,7 +18,8 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" -#include "hmx-ops.h" + +int hmx_flash_attn_ext(struct htp_ops_context * octx); // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) @@ -633,7 +634,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } -#ifdef HTP_HAS_HMX // HMX path: head_dim multiple of 64, F16 KV, and no sinks if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { int ret = hmx_flash_attn_ext(octx); @@ -642,7 +642,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } // VTCM too small or other failure -> fall through to HVX path 
} -#endif struct htp_fa_context factx; factx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/hex-common.h b/ggml/src/ggml-hexagon/htp/hex-common.h new file mode 100644 index 00000000000..4714486a042 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-common.h @@ -0,0 +1,80 @@ +#ifndef HEX_COMMON_H +#define HEX_COMMON_H + +#include +#include +#include + +#ifndef SIZE_MAX +#define SIZE_MAX ((size_t)-1) +#endif + +#ifndef MAX +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +static inline uint32_t hex_ceil_pow2(uint32_t x) { + if (x <= 1) { return 1; } + int p = 2; + x--; + while (x >>= 1) { p <<= 1; } + return p; +} + +static inline size_t hmx_ceil_div(size_t num, size_t den) { + return (num + den - 1) / den; +} + +static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { + return ((size_t) addr & (align - 1)) == 0; +} + +static inline size_t hex_align_up(size_t v, size_t align) { + return hmx_ceil_div(v, align) * align; +} + +static inline size_t hex_align_down(size_t v, size_t align) { + return (v / align) * align; +} + +static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { + uint32_t left_off = (size_t) addr & (chunk_size - 1); + uint32_t right_off = left_off + n; + return right_off <= chunk_size; +} + +static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { + return m * ((n + m - 1) / m); +} + +static inline size_t hex_smin(size_t a, size_t b) { + return a < b ? a : b; +} + +static inline size_t hex_smax(size_t a, size_t b) { + return a > b ? 
a : b; +} + +static inline void hex_swap_ptr(void ** p1, void ** p2) { + void * t = *p1; + *p1 = *p2; + *p2 = t; +} + +static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) { + if (a != 0 && b > SIZE_MAX / a) return true; + *out = a * b; + return false; +} + +static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) { + if (a > SIZE_MAX - b) return true; + *out = a + b; + return false; +} + +#endif // HEX_COMMON_H diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index 93c21ebe5ee..8031a5679c4 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -5,6 +5,7 @@ #include #include #include +#include "hex-utils.h" #include "hex-profile.h" @@ -127,13 +128,8 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src) return p; } -#if __HVX_ARCH__ < 73 -static const uint32_t dma_src_l2_bypass_on = 1; -static const uint32_t dma_dst_l2_bypass_on = 0; -#else static const uint32_t dma_src_l2_bypass_on = 1; static const uint32_t dma_dst_l2_bypass_on = 1; -#endif static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) { if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index 8e6e3ea7506..07930bef6ec 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -11,14 +11,7 @@ #include "hex-fastdiv.h" #include "hex-dump.h" - -#ifndef MAX -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#endif - -#ifndef MIN -#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) -#endif +#include "hex-common.h" static inline uint64_t hex_get_cycles() { uint64_t cycles = 0; @@ -32,54 +25,6 @@ static inline uint64_t hex_get_pktcnt() { return pktcnt; } -static inline uint32_t hex_ceil_pow2(uint32_t x) { - if (x <= 1) { return 1; } - int p = 2; - x--; - while (x >>= 1) { p <<= 1; } - return p; -} - -static inline size_t hmx_ceil_div(size_t num, size_t den) { - return (num + den - 1) / den; -} - -static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { - return ((size_t) addr & (align - 1)) == 0; -} - -static inline size_t hex_align_up(size_t v, size_t align) { - return hmx_ceil_div(v, align) * align; -} - -static inline size_t hex_align_down(size_t v, size_t align) { - return (v / align) * align; -} - -static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { - uint32_t left_off = (size_t) addr & (chunk_size - 1); - uint32_t right_off = left_off + n; - return right_off <= chunk_size; -} - -static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { - return m * ((n + m - 1) / m); -} - -static inline size_t hex_smin(size_t a, size_t b) { - return a < b ? a : b; -} - -static inline size_t hex_smax(size_t a, size_t b) { - return a > b ? a : b; -} - -static inline void hex_swap_ptr(void ** p1, void ** p2) { - void * t = *p1; - *p1 = *p2; - *p2 = t; -} - static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 986dde148dd..996fd597570 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -49,7 +49,7 @@ // g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. 
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales // Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) { const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong @@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, + k_dma_size * 2 // K DMA x2 + v_dma_size * 2 // V DMA x2 + k_tile_size * 1 // K tiles - + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + + v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + s_tile_size * 2 // S + P + d_tile_size * 1 // D (diagonal matrix) + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum @@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = struct hmx_fa_context { const struct htp_ops_context * octx; - bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 + bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 uint32_t n_threads; // Op parameters @@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) return; } - __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + __fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; struct htp_thread_trace * tr = factx->octx->ctx ? 
&factx->octx->ctx->trace[i] : NULL; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start); @@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 - const uint32_t n_threads = use_pipeline ? n_threads_init : 1; + const uint32_t n_threads = pipeline ? n_threads_init : 1; FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", - neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); + neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget); // ======== Build context ======== struct hmx_fa_context factx; @@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.n_kv_blocks = n_kv_blocks; factx.is_q_fp32 = (q->type == HTP_TYPE_F32); factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); - factx.use_pipeline = use_pipeline; + factx.pipeline = pipeline; factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) @@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); - if (use_pipeline) { + if (pipeline) { factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); } else { factx.vtcm_v_tiles[1] = NULL; @@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ======== HMX lock strategy 
======== // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. // Fallback: main thread holds the lock (original behavior). - if (!factx.use_pipeline) { + if (!factx.pipeline) { HAP_compute_res_hmx_lock(ctx->vtcm_rctx); } @@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); - if (factx.use_pipeline) { + if (factx.pipeline) { // ================================================================== // Pipeline path: HVX phases ‖ HMX queue worker // ================================================================== @@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); // HMX: O_final = diag(1/l) @ O_prev - if (factx.use_pipeline) { + if (factx.pipeline) { on_job.o_curr = o_tile_curr; on_job.o_prev = o_tile_prev; on_job.d_tiles = factx.vtcm_d_tiles; @@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { } // end KV head loop } // end batch loop - if (factx.use_pipeline) { + if (factx.pipeline) { hmx_queue_suspend(ctx->hmx_queue); } else { HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c deleted file mode 100644 index 5c37f24ff00..00000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ /dev/null @@ -1,2080 +0,0 @@ -#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include -#include -#include -#include -#include - -#include -#include - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" - -#include "hex-dma.h" -#include "hex-fastdiv.h" -#include "worker-pool.h" - -#include "hvx-utils.h" 
-#include "hvx-dump.h" -#include "htp-ctx.h" -#include "htp-ops.h" - -#include "hmx-ops.h" -#include "hmx-utils.h" -#include "hmx-queue.h" -#include "hex-profile.h" - -#include "vtcm-utils.h" - -static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, -}; - -static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, -}; - -// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value -// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 -static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, -}; - -static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, - 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, -}; - -// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes -#define HMX_X4X2_SCALES_PER_BLK 8 -#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) -#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) - -// Compute the byte stride of one row in x4x2 format. -// Numerically equals ggml_row_size(type, k) when k is 256-aligned, because -// x4x2 packing has the same density as block_q4_0 / block_q8_0. -// Layout per row: [quants: nb*128 (Q4) or nb*256 (Q8)][scales: nb*16 bytes] -// Total per row = nb * (128+16) = 144*nb (Q4) or nb * (256+16) = 272*nb (Q8). -// Callers must ensure k is a multiple of 256 (enforced by proc_hmx_matmul_req). 
-static inline size_t get_x4x2_row_stride(int weight_type, int k) { - int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - switch (weight_type) { - case HTP_TYPE_Q4_0: - case HTP_TYPE_IQ4_NL: - return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb - case HTP_TYPE_Q4_1: - return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb - case HTP_TYPE_Q8_0: - return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb - case HTP_TYPE_MXFP4: - return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb - case HTP_TYPE_F16: - return (size_t) k * sizeof(__fp16); - case HTP_TYPE_F32: - return (size_t) k * sizeof(float); - default: - return 0; - } -} - -// --- Overflow-safe arithmetic for VTCM budget calculation --- - -static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) { - if (a != 0 && b > SIZE_MAX / a) return true; - *out = a * b; - return false; -} - -static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { - if (a > SIZE_MAX - b) return true; - *out = a + b; - return false; -} - -// Search for optimal (mc, nc) chunk sizes within VTCM budget. -// -// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead -// -// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost. -// All matmul paths repeat weight processing per M-block and activation loading -// per N-block, so discrete block counts drive total overhead. -// Tie-break: when cost is equal, prefer larger mc * nc. -// -// Caller-provided coefficients: -// m_block_cost: penalty per extra M-block (weight redundancy, scales with n). -// n_block_cost: penalty per extra N-block (activation redundancy, scales with m). -// -// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. -// Returns 0 on success, -1 if VTCM is insufficient. 
-static int hmx_compute_chunks(size_t vtcm_total, - size_t overhead, - size_t per_n_cost, - size_t per_m_cost, - size_t per_mn_cost, - int m, - int n, - size_t m_block_cost, - size_t n_block_cost, - size_t * m_chunk_out, - size_t * n_chunk_out, - size_t * total_out) { - if (m <= 0 || n <= 0) return -1; - if (vtcm_total <= overhead) return -1; - if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; - - const size_t usable = vtcm_total - overhead; - - size_t best_cost = SIZE_MAX; - size_t best_mn = 0; - size_t best_m = 0, best_n = 0; - - const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); - for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { - size_t n_fixed = 0, ncmn = 0, mc_denom = 0; - if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; - if (n_fixed >= usable) goto next_nc; - - if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; - if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; - - { - size_t remain = usable - n_fixed; - size_t mc = remain / mc_denom; - mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); - mc = hex_smin(mc, (size_t)m); - - if (mc == 0) { - goto next_nc; - } - - size_t mblocks = ((size_t) m + mc - 1) / mc; - size_t nblocks = ((size_t) n + nc - 1) / nc; - size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; - size_t mn = mc * nc; - if (cost < best_cost || (cost == best_cost && mn > best_mn)) { - best_cost = cost; - best_mn = mn; - best_m = mc; - best_n = nc; - } - } - -next_nc: - if (nc == HMX_FP16_TILE_N_COLS) break; // avoid size_t underflow - } - - if (best_m == 0 || best_n == 0) return -1; - - // Compute exact total (with overflow checks) - size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; - if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1; - if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1; - if (hmx_mul_overflow(best_m, best_n, &mn)) return -1; - if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1; - if 
(hmx_add_overflow(t0, t1, &total)) return -1; - if (hmx_add_overflow(total, t2, &total)) return -1; - if (hmx_add_overflow(total, overhead, &total)) return -1; - - *m_chunk_out = best_m; - *n_chunk_out = best_n; - *total_out = total; - return 0; -} - -// --- x4x2 format dequantizers --- - -// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. -// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles -// of the same 32 packed bytes. -static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using -// full HVX vector width. -// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. 
-static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); - - HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8); - HVX_Vector v_lo = Q6_V_lo_W(vp_int16); - HVX_Vector v_hi = Q6_V_hi_W(vp_int16); - - v_lo = Q6_Vhf_equals_Vh(v_lo); - v_hi = Q6_Vhf_equals_Vh(v_hi); - - HVX_Vector vscale = hvx_vmemu(scales_4); - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_dm = hvx_vmemu(scale_offset); - HVX_Vector v_scales = hvx_vec_repl_f16(v_dm); - HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2)); - - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); -} - -static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - 
HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants); - HVX_Vector v_lo = Q6_V_lo_W(vp_int16); - HVX_Vector v_hi = Q6_V_hi_W(vp_int16); - - v_lo = Q6_Vhf_equals_Vh(v_lo); - v_hi = Q6_Vhf_equals_Vh(v_hi); - - HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); - HVX_Vector vd = Q6_V_lo_W(dm_deal); - HVX_Vector vm = Q6_V_hi_W(dm_deal); - - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4)); - - HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm); - HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -// LUT-based dequantizers for non-linear IQ4_NL format. 
-static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - v_quants = Q6_Vb_vshuff_Vb(v_quants); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); - HVX_Vector v_hi = Q6_V_hi_W(vp); - - HVX_Vector vscale = hvx_vmemu(scales_4); - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. 
-static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { - HVX_Vector vq = hvx_vmemu(quants_32); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -// --- MXFP4 E8M0 scale conversion and dequantization --- -// -// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack. -// Scalar loads from the stack array execute on the scalar pipeline, in parallel -// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop. -// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10 -// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15. - -typedef struct { - __fp16 v[8] __attribute__((aligned(16))); -} mxfp4_scales_t; - -static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) { - mxfp4_scales_t s; - HVX_Vector v = hvx_vmemu(e8m0_8); - HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); - vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); - vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); - vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); - vh = Q6_Vh_vasl_VhR(vh, 10); - hvx_vec_store_u(s.v, 16, vh); - return s; -} - -static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) { - return hvx_vec_splat_f16(scales.v[idx]); -} - -// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16. -static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32, - bool upper_nibbles, - int sub_blk, - const HVX_Vector vlut_cvt, - mxfp4_scales_t scales) { - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = upper_nibbles ? 
Q6_Vub_vlsr_VubR(vq, 4) : vq; - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc)); -} - -// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). -static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, - bool upper_nibbles, - int sub_blk_base, - const HVX_Vector vlut_cvt, - mxfp4_scales_t scales) { - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); - HVX_Vector v_hi = Q6_V_hi_W(vp); - - HVX_VectorPred q64 = Q6_Q_vsetq_R(64); - HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0), - mxfp4_extract_splat(scales, sub_blk_base + 1)); - HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2), - mxfp4_extract_splat(scales, sub_blk_base + 3)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; - return r; -} - -typedef struct { - __fp16 *dst; - const uint8_t *src; - int n_cols; - int k_block; - size_t row_stride; - int weight_type; - int n_tot_tiles; - int n_tiles_per_task; - int n_tasks; - int n_k_tiles; - struct fastdiv_values n_k_tiles_div; - struct htp_thread_trace * traces; -} x4x2_dequantize_state_t; - -// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. 
-// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. -// Output: vtcm_dst in tile-major FP16 layout. - -#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \ -static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \ - const x4x2_dequantize_state_t *state, \ - int start_tile, int end_tile) { \ - \ - const int n_k_tiles = state->n_k_tiles; \ - const int qrow_size = (unsigned)state->k_block / 2; \ - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \ - const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \ - \ - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \ - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \ - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \ - \ - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \ - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \ - \ - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \ - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \ - \ - if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \ - unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \ - unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \ - bool upper = (sub_blk_base >= 4); \ - unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \ - unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \ - \ - __fp16 *tile_bases[4]; \ - for (unsigned g = 0; g < 4; g++) { \ - tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \ - } \ - \ - HVX_Vector v_off = v_scat_base; \ - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ - \ - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \ - const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ - const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ - \ - HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ - r0 + 
packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \ - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - \ - HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ - r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \ - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \ - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - } \ - \ - for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \ - t += 4; kt += 4; \ - continue; \ - } \ - \ - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \ - { \ - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \ - unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \ - bool upper = (sub_blk >= 4); \ - unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \ - unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \ - \ - HVX_Vector v_off = v_scat_base; \ - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \ - \ - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \ - const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ - const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ - \ - HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \ - r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ - HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \ - ? 
dequantize_x4x2_##helper_prefix##_group_hvx( \ - r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \ - : Q6_V_vzero(); \ - \ - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - } \ - (void) *(volatile HVX_Vector *)(tile_base); \ - } \ - ++t; ++kt; \ - } \ - \ - if (start_tile < end_tile) { \ - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \ - } \ -} \ - \ -static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \ - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \ - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \ - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ - int start = task_id * state->n_tiles_per_task; \ - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ - dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \ - } \ - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ -} - -DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) -DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4) -DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) - -static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const int qrow_size = (unsigned)state->k_block / 2; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); - - const HVX_Vector v_scat_base = 
hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - // Batch-4 fast path for MXFP4 - if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { - int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; - bool upper = (sub_blk_base >= 4); - int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; - - __fp16 * tile_bases[4]; - for (int g = 0; g < 4; g++) { - tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; - } - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - const uint8_t * r0 = state->src + row0 * state->row_stride; - const uint8_t * r1 = state->src + row1 * state->row_stride; - - mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - - HVX_Vector_x4 dv0, dv1; - dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); - if (row1 < state->n_cols) { - mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); - } else { - dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); - } - - for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); - } - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); - } - v_off = 
Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - - for (int g = 0; g < 4; g++) { - (void) *(volatile HVX_Vector *) (tile_bases[g]); - } - - t += 4; kt += 4; - continue; - } - - // Single-tile fallback - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; - bool upper = (sub_blk >= 4); - int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t * r0 = state->src + row0 * state->row_stride; - const uint8_t * r1 = state->src + row1 * state->row_stride; - - mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - - HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); - HVX_Vector v1; - if (row1 < state->n_cols) { - mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); - } else { - v1 = Q6_V_vzero(); - } - - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *) (tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? 
&state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const int qrow_size = state->k_block; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int blk_idx = (kt * 32) / QK_Q8_0x4x2; - int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; - int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; - int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); - HVX_Vector v1 = (row1 < state->n_cols) ? 
dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void convert_f16_weight_to_fp16_tiles_task( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int byte_off = kt * 32 * 
sizeof(__fp16); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); - HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? 
&state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - convert_f16_weight_to_fp16_tiles_task(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void quantize_f32_weight_to_fp16_tiles_task( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int byte_off = kt * 32 * sizeof(float); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); - HVX_Vector v1_f32 = (row1 < state->n_cols) ? 
hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - - HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - quantize_f32_weight_to_fp16_tiles_task(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - - -static void dequantize_x4x2_weight_chunk_to_fp16_tiles( - struct htp_context *ctx, __fp16 *vtcm_dst, - const void *vtcm_src, int n_cols, int k_block, - size_t row_stride, int weight_type, - int n_k_tiles, struct fastdiv_values n_k_tiles_div, - worker_callback_t dequant_worker_fn, int n_threads) { - - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - assert(k_block % HMX_FP16_TILE_N_COLS == 0); - - size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; - size_t n_tot_tiles = n_col_tiles * n_k_tiles; - - size_t n_tiles_per_task = (n_threads == 1) ? 
n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); - - x4x2_dequantize_state_t state; - state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; - state.n_tot_tiles = n_tot_tiles; - state.n_tiles_per_task = n_tiles_per_task; - state.dst = vtcm_dst; - state.src = (const uint8_t *)vtcm_src; - state.n_cols = n_cols; - state.k_block = k_block; - state.row_stride = row_stride; - state.weight_type = weight_type; - state.n_k_tiles = n_k_tiles; - state.n_k_tiles_div = n_k_tiles_div; - state.traces = ctx ? ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - dequant_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); - } -} - -// --- End x4x2 dequantizers --- - -#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics - -// requires external HMX lock -static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, - int n_row_tiles, int n_col_tiles, int n_dot_tiles) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); - - Q6_bias_mxmem2_A((void *)scales); - for (int r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < n_col_tiles; ++c) { - Q6_mxclracc_hf(); - - const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - - for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { - k_block = hex_smin(n_dot_tiles - k, 32); - const uint32_t range = 2048u * (uint32_t)k_block - 1; - Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); - row_tiles += k_block * HMX_FP16_TILE_N_ELMS; - col_tiles += k_block * HMX_FP16_TILE_N_ELMS; - } - - __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; - 
Q6_mxmem_AR_after_hf(out_tile, 0); - } - } -} - -// --- Async HMX matmul job (for pipeline overlap) --- - -typedef struct { - __fp16 * output; - const __fp16 * activation; - const __fp16 * weight; - const __fp16 * scales; - uint32_t n_row_tiles; - uint32_t n_col_tiles; - uint32_t n_dot_tiles; -} hmx_matmul_job_t; - -static void hmx_matmul_worker_fn(void * data) { - hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; - FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); - core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); -} - -static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, - __fp16 * output, - const __fp16 * activation, - const __fp16 * weight, - const __fp16 * scales, - int n_row_tiles, - int n_col_tiles, - int n_dot_tiles) { - job->output = output; - job->activation = activation; - job->weight = weight; - job->scales = scales; - job->n_row_tiles = n_row_tiles; - job->n_col_tiles = n_col_tiles; - job->n_dot_tiles = n_dot_tiles; -} - -// output : fp16 -> f32p - -static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - - const HVX_Vector one = hvx_vec_splat_f16(1.0); - - for (size_t r = 0; r < n_rows; r += 2) { - const size_t r0 = r / HMX_FP16_TILE_N_ROWS; - const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile - const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - float *output_row_base = dst + r * n; // global memory row base for row r (and r+1) - - #pragma unroll(4) - for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { - const size_t c0 = c / HMX_FP16_TILE_N_COLS; - const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; - 
HVX_Vector v = ((const HVX_Vector *) tile)[r1]; - HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); - - volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0); - volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory - - *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); - if (r + 1 < n_rows) { - *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); - } - } - } -} - -typedef struct { - const __fp16 *vtcm_src; - float *dst; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int n_cols; - int n; // DDR row stride (total output columns) - struct htp_thread_trace * traces; -} output_transfer_task_state_t; - -static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i); - - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - - float *dst = st->dst + chunk_idx * st->n; - const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols; - transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i); -} - -static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, - int n_rows, int n_cols, int n, int n_threads) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = (n_threads == 1) ? 
n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) - - output_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.vtcm_src = vtcm_src; - state.n_cols = n_cols; - state.n = n; - state.traces = ctx ? ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - transfer_output_chunk_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); - } -} - -// activations : fp32 -> fp16 - -static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { - const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; - - int r = 0; - - #pragma unroll(2) - for (r = 0; r < n_rows_tiled; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); - const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = *pv_in1++; - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } - - for (; r < n_rows_padded; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const bool row0_valid = r < n_rows; - const bool row1_valid = (r + 1) < n_rows; - - const HVX_Vector 
*pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL; - const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL; - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); - HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } -} - -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - int k_stride; - struct htp_thread_trace * traces; -} activation_transfer_task_state_t; - -static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; - struct htp_thread_trace * tr = st->traces ? 
&st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i); - - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - // one chunk: one row - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - - __fp16 *dst = st->dst + chunk_idx * st->k_block; - const float *src = st->src + chunk_idx * st->k_stride; - transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i); -} - -static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { - assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); - assert(VLEN == 32 * sizeof(float)); - - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address - - activation_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.src = src; - state.k_block = k_block; - state.k_stride = k_stride; - state.traces = ctx ? 
ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - transfer_activation_chunk_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); - } -} - -// C += AB -static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, - const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, - int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); - - Q6_bias_mxmem2_A((void *)col_scales); - - const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t i = 0; i < n_row_tiles; ++i) { - const __fp16 *row_base = a + i * dot_tile_stride; - __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t j = 0; j < n_col_tiles; ++j) { - Q6_mxclracc_hf(); - - const __fp16 *col_tiles = b + j * dot_tile_stride; - const __fp16 *row_tiles = row_base; - __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; - if (!zero_init) { - Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); - } - - for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { - k_block = hex_smin(n_dot_tiles - k, 32); - const uint32_t range = 2048u * (uint32_t)k_block - 1; - Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); - row_tiles += k_block * HMX_FP16_TILE_N_ELMS; - col_tiles += k_block * HMX_FP16_TILE_N_ELMS; - } - - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } -} - -int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const uint8_t *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride, int weight_type) { - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || 
!hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - worker_callback_t dequant_worker_fn = NULL; - switch (weight_type) { - case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; - case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; - case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; - case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; - case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; - case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; - case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; - default: - return -1; - } - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - - // --- Dynamic Mode Configuration --- - const bool use_pipeline = (m > 32); - const int num_threads = (m <= 32) ? 1 : ctx->n_threads; - - // --- Dynamic VTCM layout --- - const size_t vec_dot_size = k * sizeof(__fp16); - const size_t vtcm_budget = ctx->vtcm_size; - size_t vtcm_used = 0; - - // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) - const size_t size_per_mn = (use_pipeline ? 
2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); - return -1; - } - - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); - const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - - size_t scratch0_size, scratch1_size, scratch2_size; - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 - scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; - void *vtcm_scratch2 = scratch2_size ? 
vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - - vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; - if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); - return -1; - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", - m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); - - - - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - - if (use_pipeline) { - // --- Asynchronous Pipelined Loop (Current implementation) --- - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); - } - - { - const float *activation_chunk = activation + mr * act_stride; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - } - - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - 
if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); - } - - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - } - } - - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); - } - - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - 
(__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } - - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); - - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - } - } - } - hmx_queue_suspend(ctx->hmx_queue); - } else { - // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - - // Load Activation - const float *activation_chunk = activation + mr * act_stride; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - // A: DMA Load Weight - const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); - dma_queue_pop(ctx->dma[0]); - - // B: Dequantize / Convert Weight - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, 
weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - // C: HMX Compute (Synchronous) - { - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - // D: Output Store - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); - } - } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } - - - - return 0; -} - -// - -static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { - return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; -} - -static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { - return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; -} - -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - const int r2 = hmx_matmul_batch_r2(params); - const int r3 = hmx_matmul_batch_r3(params); - return (const __fp16 *) ((const uint8_t *) params->permuted_weight + - (size_t) (dst_b2 / r2) * params->src0_nb2 + - (size_t) (dst_b3 / r3) * params->src0_nb3); -} - -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (const float *) ((const uint8_t *) params->activation + - (size_t) dst_b2 * params->src1_nb2 + - (size_t) dst_b3 * params->src1_nb3); -} - -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (float *) ((uint8_t *) params->dst + - (size_t) dst_b2 * params->dst_nb2 + - (size_t) dst_b3 * params->dst_nb3); -} - -static int 
hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_f16_f32_batched_params_t *params) { - int ret = 0; - for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { - for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); - } - } - return ret; -} - -int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { - if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } - if (!params->m || !params->k || !params->n) { return -1; } - if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } - if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } - if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } - if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } - - if (!hex_is_aligned(params->dst, VLEN) || - !hex_is_aligned(params->activation, VLEN) || - !hex_is_aligned(params->permuted_weight, VLEN)) { - return -1; - } - - const int group_size = hmx_matmul_batch_r2(params); - - if (group_size <= 1) { - FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - // Grouped path: reuse interleaved weight across all q_heads sharing a - // kv_head. Each q_head gets its own activation buffer in VTCM (so - // activation is loaded once per m_chunk and reused across all n_chunks), - // and each q_head is computed individually to avoid tile-major packing - // issues. 
m_chunk_n_rows is always a multiple of 32 (from - // hmx_compute_chunks), so per-head tile arrays don't overlap. - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = params->k * sizeof(__fp16); - - // When the activation has a large stride (e.g. permuted Q tensor with - // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. - // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather - // strided rows into a contiguous block before the F32->F16 conversion. - const bool use_dma_activation = (params->act_stride > params->k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, - /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), - hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n, - /*m_block_cost=*/(size_t) params->n, - /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? 
hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - - if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { - FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, params->m, params->k, params->n, group_size, params->ne13, - m_chunk_n_rows, n_chunk_n_cols, - (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - - - const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (int b3 = 0; b3 < params->ne13; ++b3) { - for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { - const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); - - for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); - - // Pre-load activations for all heads in the group (once 
per m_chunk). - // When the source is strided (permuted Q), use 2D DMA to gather - // contiguous rows into a VTCM scratch buffer first, then HVX - // converts from the contiguous VTCM buffer. This avoids L2 cache - // thrashing from HVX loads at large strides. - for (int g = 0; g < group_size; ++g) { - const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; - __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) params->k * sizeof(float); - const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - vtcm_f32_act, (int) n_rows, - params->k, params->k, ctx->n_threads); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - activation_chunk, (int) n_rows, - params->k, params->act_stride, ctx->n_threads); - } - } - - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - { - const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); - - { - dma_queue_pop(ctx->dma[0]); - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < (size_t) params->n) { - const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, 
weight_row_bytes, fp16_row_bytes, n_cols_next); - } - - hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, - 0, n_cols); - hex_swap_ptr(&buf_curr, &buf_next); - } - - // Reuse the interleaved weight for every q_head in this GQA group - for (int g = 0; g < group_size; ++g) { - { - const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, - params->k / 32); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - { - float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); - } - } - } - } - } - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - - - - return 0; -} - -int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const __fp16 *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, - act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); -} - -struct mmid_row_mapping { - uint32_t i1; - uint32_t i2; -}; - -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - const struct mmid_row_mapping *matrix_rows; - int cur_a; - int mapping_stride; - int ne11; - struct fastdiv_values ne11_div; - size_t nb11; - size_t nb12; - int start_row; - int cne1; - struct htp_thread_trace *traces; -} activation_transfer_gathered_task_state_t; - -typedef struct { - 
const __fp16 *vtcm_src; - float *dst; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int n_cols; - const struct mmid_row_mapping *matrix_rows; - int cur_a; - int mapping_stride; - size_t dst_nb1; - size_t dst_nb2; - int start_row; - int cne1; - struct htp_thread_trace *traces; -} output_transfer_scattered_task_state_t; - -static void transfer_activation_chunk_fp32_to_fp16_gathered( - __fp16 *restrict vtcm_dst, - const float *restrict src, - int start_row, - int n_rows, - int k_block, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - int ne11, - const struct fastdiv_values * ne11_div, - size_t nb11, - size_t nb12, - int cne1) { - const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; - - int r = 0; - - #pragma unroll(2) - for (r = 0; r < n_rows_tiled; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - int r_idx0 = start_row + r + 0; - int r_idx1 = start_row + r + 1; - - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; - - int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); - int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); - - const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); - const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); - - const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; - const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; - - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = *pv_in1++; - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - 
HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } - - for (; r < n_rows_padded; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const bool row0_valid = (start_row + r + 0) < cne1; - const bool row1_valid = (start_row + r + 1) < cne1; - - const float *row0_ptr = NULL; - const float *row1_ptr = NULL; - - if (row0_valid) { - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; - int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); - row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); - } - if (row1_valid) { - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; - int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); - row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); - } - - const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; - const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; - - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); - HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } -} - -static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_gathered_task_state_t *st = data; - struct htp_thread_trace * tr = st->traces ? 
&st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i); - - int chunk_idx = i; - int chunk_size = st->n_chunks_per_task; - int start_row = st->start_row + chunk_idx * chunk_size; - int n_rows = hex_smin(st->cne1 - start_row, chunk_size); - if (n_rows > 0) { - __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; - transfer_activation_chunk_fp32_to_fp16_gathered( - dst, st->src, start_row, n_rows, st->k_block, - st->matrix_rows, st->cur_a, st->mapping_stride, - st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i); -} - -static void transfer_activation_chunk_gathered_threaded( - struct htp_context *ctx, - __fp16 *dst, - const float *src, - int start_row, - int n_rows, - int k_block, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - int ne11, - size_t nb11, - size_t nb12, - int cne1, - int n_threads) { - if (n_rows <= 0) return; - int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); - chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); - - int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); - - activation_transfer_gathered_task_state_t state = { - .dst = dst, - .src = src, - .n_tasks = actual_threads, - .n_tot_chunks = n_rows, - .n_chunks_per_task = chunks_per_thread, - .k_block = k_block, - .matrix_rows = matrix_rows, - .cur_a = cur_a, - .mapping_stride = mapping_stride, - .ne11 = ne11, - .ne11_div = init_fastdiv_values(ne11), - .nb11 = nb11, - .nb12 = nb12, - .start_row = start_row, - .cne1 = cne1, - .traces = ctx ? 
ctx->trace : NULL, - }; - - if (actual_threads <= 1) { - transfer_activation_chunk_gathered_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); - } -} - -static void transfer_output_chunk_fp16_to_fp32_scattered( - float *restrict dst, - const __fp16 *restrict vtcm_src, - int start_row, - int n_rows, - int n_cols, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - size_t dst_nb1, - size_t dst_nb2, - int cne1) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - - const HVX_Vector one = hvx_vec_splat_f16(1.0); - - for (size_t r = 0; r < n_rows; r += 2) { - const size_t r0 = r / HMX_FP16_TILE_N_ROWS; - const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile - const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - - int r_idx0 = start_row + (int)r + 0; - int r_idx1 = start_row + (int)r + 1; - - if (r_idx0 >= cne1) break; - - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; - float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); - - float *output_row1 = NULL; - if (r_idx1 < cne1) { - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; - output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); - } - - #pragma unroll(4) - for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { - const size_t c0 = c / HMX_FP16_TILE_N_COLS; - const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; - HVX_Vector v = ((const HVX_Vector *) tile)[r1]; - HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); - - volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); - volatile HVX_Vector *pv_out1 = output_row1 ? 
(volatile HVX_Vector *) (output_row1 + c) : NULL; - - *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); - if (pv_out1) { - *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); - } - } - } -} - -static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { - output_transfer_scattered_task_state_t *st = data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i); - - int chunk_idx = i; - int chunk_size = st->n_chunks_per_task; - int start_row = st->start_row + chunk_idx * chunk_size; - int n_rows = hex_smin(st->cne1 - start_row, chunk_size); - if (n_rows > 0) { - const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; - transfer_output_chunk_fp16_to_fp32_scattered( - st->dst, src, start_row, n_rows, st->n_cols, - st->matrix_rows, st->cur_a, st->mapping_stride, - st->dst_nb1, st->dst_nb2, st->cne1); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i); -} - -static void transfer_output_chunk_scattered_threaded( - struct htp_context *ctx, - float *dst, - const __fp16 *vtcm_src, - int start_row, - int n_rows, - int n_cols, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - size_t dst_nb1, - size_t dst_nb2, - int cne1, - int n_threads) { - if (n_rows <= 0) return; - int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); - chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); - - int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); - - output_transfer_scattered_task_state_t state = { - .vtcm_src = vtcm_src, - .dst = dst, - .n_tasks = actual_threads, - .n_tot_chunks = n_rows, - .n_chunks_per_task = chunks_per_thread, - .n_cols = n_cols, - .matrix_rows = matrix_rows, - .cur_a = cur_a, - .mapping_stride = mapping_stride, - .dst_nb1 = dst_nb1, - .dst_nb2 = dst_nb2, - .start_row = start_row, - .cne1 = cne1, - .traces = ctx ? 
ctx->trace : NULL, - }; - - if (actual_threads <= 1) { - transfer_output_chunk_scattered_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); - } -} - -int hmx_matmul_id_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int ne11, - size_t act_nb1, size_t act_nb2, - size_t dst_nb1, size_t dst_nb2, - int weight_stride, - int weight_type, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride) { - const int cne1 = m; - const int m_padded = hex_align_up(m, 32); - - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - worker_callback_t dequant_worker_fn = NULL; - switch (weight_type) { - case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; - case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; - case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; - case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; - case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; - case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; - case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; - default: - return -1; - } - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - - const int num_threads = ctx->n_threads; - - const size_t vec_dot_size = k * sizeof(__fp16); - const size_t vtcm_budget = ctx->vtcm_size; - size_t vtcm_used = 0; - - const size_t size_per_n = row_stride + vec_dot_size; - const size_t 
size_per_mn = sizeof(__fp16); - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, - m_padded, n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); - return -1; - } - - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); - const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - - size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - - vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; - if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); - return -1; - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - - transfer_activation_chunk_gathered_threaded( - ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, - matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); - - for 
(size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); - dma_queue_pop(ctx->dma[0]); - - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - { - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - transfer_output_chunk_scattered_threaded( - ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, - matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); - } - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - return 0; -} diff --git a/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h b/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h new file mode 100644 index 00000000000..b7fba22a87f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h @@ -0,0 +1,1306 @@ +#include "hmx-utils.h" +#include "hmx-queue.h" + +// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value +// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 +static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, +}; + +static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, 
-10, 0, + 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, +}; + +// --- tiled format dequantizers --- + +typedef struct { + struct htp_context * ctx; + struct htp_thread_trace * traces; + __fp16 * dst; + const uint8_t * src; + + struct fastdiv_values n_k_tiles_div; + uint32_t n_k_tiles; + uint32_t n_tot_tiles; + uint32_t n_tiles_per_task; + uint32_t tile_size; + uint32_t aligned_tile_size; + uint32_t n_tasks; + uint32_t n_cols; + uint32_t k_block; + size_t row_stride; + uint32_t weight_type; +} tiled_dequantize_state_t; + +// Dequantize a single tile from tiled weight data (already in VTCM) to tile-major FP16. +static void dequantize_tiled_weight_to_fp16_task_q4_0( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 512); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Offsetting (-8) + v_lo0 = Q6_Vb_vsub_VbVb(v_lo0, i8); + v_hi0 = Q6_Vb_vsub_VbVb(v_hi0, i8); + v_lo1 = Q6_Vb_vsub_VbVb(v_lo1, i8); + v_hi1 = Q6_Vb_vsub_VbVb(v_hi1, i8); + 
v_lo2 = Q6_Vb_vsub_VbVb(v_lo2, i8); + v_hi2 = Q6_Vb_vsub_VbVb(v_hi2, i8); + v_lo3 = Q6_Vb_vsub_VbVb(v_lo3, i8); + v_hi3 = Q6_Vb_vsub_VbVb(v_hi3, i8); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Unpack to 16-bit + HVX_VectorPair vp_int16_lo0 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_VectorPair vp_int16_hi0 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_VectorPair vp_int16_lo1 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_VectorPair vp_int16_hi1 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_VectorPair vp_int16_lo2 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_VectorPair vp_int16_hi2 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_VectorPair vp_int16_lo3 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_VectorPair vp_int16_hi3 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf3)); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo0)), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo0)), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi0)), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi0)), v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo1)), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo1)), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi1)), v_scale_duplicated)); + HVX_Vector v_grp1_3 = 
Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi1)), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo2)), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo2)), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi2)), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi2)), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo3)), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo3)), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi3)), v_scale_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi3)), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_q4_1( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 
= Q6_Vb_vsplat_R(0x0F); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector vscale_offset = hvx_vmem(tile_src + 512); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); + HVX_Vector vd = Q6_V_lo_W(dm_deal); + HVX_Vector vm = Q6_V_hi_W(dm_deal); + + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vd, vd, -2)); + HVX_Vector v_offset_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vm, vm, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Unpack to 16-bit + HVX_VectorPair vp_int16_lo0 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_VectorPair vp_int16_hi0 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_VectorPair vp_int16_lo1 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_VectorPair vp_int16_hi1 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_VectorPair vp_int16_lo2 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_VectorPair vp_int16_hi2 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_VectorPair vp_int16_lo3 = 
Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_VectorPair vp_int16_hi3 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf3)); + + // Convert, multiply, add offset + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi0)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi1)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_2 = 
Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi2)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi3)), v_scale_duplicated), v_offset_duplicated)); + + // Parallel Stores + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_iq4_nl( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector vlut_cvt = hvx_vmem(iq4_nl_to_fp16_lut); + + for 
(uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 512); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Shuffle for LUT lookup + HVX_Vector v_q_lo0 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_Vector v_q_hi0 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_Vector v_q_lo1 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_Vector v_q_hi1 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_Vector v_q_lo2 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_Vector v_q_hi2 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_Vector v_q_lo3 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_Vector v_q_hi3 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf3)); + + // LUT lookup + HVX_VectorPair vp_lo0 = Q6_Wh_vlut16_VbVhR(v_q_lo0, vlut_cvt, 0); + HVX_VectorPair vp_hi0 = Q6_Wh_vlut16_VbVhR(v_q_hi0, vlut_cvt, 0); + HVX_VectorPair vp_lo1 = Q6_Wh_vlut16_VbVhR(v_q_lo1, vlut_cvt, 0); + HVX_VectorPair vp_hi1 = 
Q6_Wh_vlut16_VbVhR(v_q_hi1, vlut_cvt, 0); + HVX_VectorPair vp_lo2 = Q6_Wh_vlut16_VbVhR(v_q_lo2, vlut_cvt, 0); + HVX_VectorPair vp_hi2 = Q6_Wh_vlut16_VbVhR(v_q_hi2, vlut_cvt, 0); + HVX_VectorPair vp_lo3 = Q6_Wh_vlut16_VbVhR(v_q_lo3, vlut_cvt, 0); + HVX_VectorPair vp_hi3 = Q6_Wh_vlut16_VbVhR(v_q_hi3, vlut_cvt, 0); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi0), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi0), v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi1), v_scale_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi1), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi2), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi2), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi3), v_scale_duplicated)); + HVX_Vector v_grp3_3 = 
Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi3), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_mxfp4( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v = hvx_vmem(tile_src + 512); + HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); + vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); + vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); + vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); + vh = Q6_Vh_vasl_VhR(vh, 10); + + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vh, vh, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + 
HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Shuffle for LUT lookup + HVX_Vector v_q_lo0 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_Vector v_q_hi0 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_Vector v_q_lo1 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_Vector v_q_hi1 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_Vector v_q_lo2 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_Vector v_q_hi2 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_Vector v_q_lo3 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_Vector v_q_hi3 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf3)); + + // LUT lookup + HVX_VectorPair vp_lo0 = Q6_Wh_vlut16_VbVhR(v_q_lo0, vlut_cvt, 0); + HVX_VectorPair vp_hi0 = Q6_Wh_vlut16_VbVhR(v_q_hi0, vlut_cvt, 0); + HVX_VectorPair vp_lo1 = Q6_Wh_vlut16_VbVhR(v_q_lo1, vlut_cvt, 0); + HVX_VectorPair vp_hi1 = Q6_Wh_vlut16_VbVhR(v_q_hi1, vlut_cvt, 0); + HVX_VectorPair vp_lo2 = Q6_Wh_vlut16_VbVhR(v_q_lo2, vlut_cvt, 0); + HVX_VectorPair vp_hi2 = Q6_Wh_vlut16_VbVhR(v_q_hi2, vlut_cvt, 0); + HVX_VectorPair vp_lo3 = Q6_Wh_vlut16_VbVhR(v_q_lo3, vlut_cvt, 0); + HVX_VectorPair vp_hi3 = Q6_Wh_vlut16_VbVhR(v_q_hi3, vlut_cvt, 0); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi0), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi0), 
v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi1), v_scale_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi1), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi2), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi2), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi3), v_scale_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi3), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void 
dequantize_tiled_weight_to_fp16_task_q8_0( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 1024); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load groups 0-3 in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + HVX_VectorPair vp_int16_0 = Q6_Wh_vunpack_Vb(vq0); + HVX_VectorPair vp_int16_1 = Q6_Wh_vunpack_Vb(vq1); + HVX_VectorPair vp_int16_2 = Q6_Wh_vunpack_Vb(vq2); + HVX_VectorPair vp_int16_3 = Q6_Wh_vunpack_Vb(vq3); + + // Load groups 4-7 in parallel + HVX_Vector vq4 = hvx_vmem(tile_src + 4 * 128); + HVX_Vector vq5 = hvx_vmem(tile_src + 5 * 128); + HVX_Vector vq6 = hvx_vmem(tile_src + 6 * 128); + HVX_Vector vq7 = hvx_vmem(tile_src + 7 * 128); + + HVX_VectorPair vp_int16_4 = Q6_Wh_vunpack_Vb(vq4); + HVX_VectorPair vp_int16_5 = Q6_Wh_vunpack_Vb(vq5); + HVX_VectorPair vp_int16_6 = Q6_Wh_vunpack_Vb(vq6); + HVX_VectorPair vp_int16_7 = Q6_Wh_vunpack_Vb(vq7); + + // Convert and scale multiply for groups 0-3 + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_0)), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_0)), v_scale_duplicated)); + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_1)), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_1)), v_scale_duplicated)); + HVX_Vector v_grp2_0 = 
Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_2)), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_2)), v_scale_duplicated)); + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_3)), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_3)), v_scale_duplicated)); + + // Store groups 0-3 + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 3 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 4 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 7 * 64) = v_grp3_1; + + // Convert and scale multiply for groups 4-7 + HVX_Vector v_grp4_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_4)), v_scale_duplicated)); + HVX_Vector v_grp4_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_4)), v_scale_duplicated)); + HVX_Vector v_grp5_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_5)), v_scale_duplicated)); + HVX_Vector v_grp5_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_5)), v_scale_duplicated)); + HVX_Vector v_grp6_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_6)), v_scale_duplicated)); + HVX_Vector v_grp6_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_6)), v_scale_duplicated)); + HVX_Vector v_grp7_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_7)), v_scale_duplicated)); + HVX_Vector v_grp7_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_7)), v_scale_duplicated)); + + // Store groups 4-7 + hvx_vmem(dst_ptr + 8 * 64) = v_grp4_0; + hvx_vmem(dst_ptr + 
9 * 64) = v_grp4_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp5_0; + hvx_vmem(dst_ptr + 11 * 64) = v_grp5_1; + hvx_vmem(dst_ptr + 12 * 64) = v_grp6_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp6_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp7_0; + hvx_vmem(dst_ptr + 15 * 64) = v_grp7_1; + } +} + +static void convert_f16_weight_to_fp16_tiles_task( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const uint32_t n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + { + uint32_t byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (uint32_t r = 0; r < HTP_MM_HMX_TILE_N_ROWS; r += 2) { + uint32_t row0 = ct * HTP_MM_HMX_TILE_N_COLS + r; + uint32_t row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? 
hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HTP_MM_HMX_TILE_N_ELMS); + } +} + +static void quantize_f32_weight_to_fp16_tiles_task( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const uint32_t n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + { + uint32_t byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (uint32_t r = 0; r < HTP_MM_HMX_TILE_N_ROWS; r += 2) { + uint32_t row0 = ct * HTP_MM_HMX_TILE_N_COLS + r; + uint32_t row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmem((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? 
hvx_vmem((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HTP_MM_HMX_TILE_N_ELMS); + } +} + +// --- End tiled dequantizers --- + +// requires external HMX lock +static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, + uint32_t n_row_tiles, uint32_t n_col_tiles, uint32_t n_dot_tiles) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)scales); + for (uint32_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = activation + r * n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + const __fp16 *col_tiles = weight + c * n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + + for (uint32_t k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + col_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + } + + __fp16 *out_tile = output + (r * n_col_tiles + c) * HTP_MM_HMX_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +// C += AB +static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 
*restrict b, + const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, + uint32_t n_row_tiles, uint32_t n_col_tiles, uint32_t n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)col_scales); + + const size_t dot_tile_stride = n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HTP_MM_HMX_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); + + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HTP_MM_HMX_TILE_N_ELMS; + if (!zero_init) { + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); + } + + for (uint32_t k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + col_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + } + + Q6_mxmem_AR_after_hf(accum_tile, 0); + } + } +} + +// --- Async HMX matmul job (for pipeline overlap) --- + +typedef struct { + __fp16 * output; + const __fp16 * activation; + const __fp16 * weight; + const __fp16 * scales; + uint32_t n_row_tiles; + uint32_t n_col_tiles; + uint32_t n_dot_tiles; +} hmx_matmul_job_t; + +static void hmx_matmul_worker_fn(void * data) { + hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; + FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); + core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); +} + +static 
inline void hmx_matmul_job_init(hmx_matmul_job_t * job, + __fp16 * output, + const __fp16 * activation, + const __fp16 * weight, + const __fp16 * scales, + uint32_t n_row_tiles, + uint32_t n_col_tiles, + uint32_t n_dot_tiles) { + job->output = output; + job->activation = activation; + job->weight = weight; + job->scales = scales; + job->n_row_tiles = n_row_tiles; + job->n_col_tiles = n_col_tiles; + job->n_dot_tiles = n_dot_tiles; +} + +// output : fp16 -> f32p + +static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + const size_t limit_c = hex_smin(n_cols, dst_cols); + const size_t limit_c_aligned = (limit_c & ~31); + + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r_idx0 = start_row + r + 0; + const size_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; + const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1) + + #pragma unroll(4) + for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) { + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0); + HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride); + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (r + 1 < n_rows) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + + if (limit_c_aligned < limit_c) { + size_t c = 
limit_c_aligned; + size_t valid_c = limit_c - c; + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp))); + if (r + 1 < n_rows) { + hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp))); + } + } + } +} + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t n_cols; + uint32_t dst_stride; // DDR row stride + uint32_t dst_cols; // Actual output columns + struct htp_thread_trace * traces; +} output_transfer_task_state_t; + +// activations : fp32 -> fp16 + +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, uint32_t n_rows, uint32_t k_block, uint32_t k_stride, uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r0 = r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const float *ptr_in0 = src + (r + 0) * k_stride; + const float *ptr_in1 = src + (r + 1) * k_stride; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(ptr_in0 + c); + HVX_Vector v1 = *(const HVX_Vector *)(ptr_in1 + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + 
} + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(ptr_in0 + c); + HVX_Vector v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r0 = r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = r < n_rows; + const bool row1_valid = (r + 1) < n_rows; + + const float *ptr_in0 = row0_valid ? (src + (r + 0) * k_stride) : NULL; + const float *ptr_in1 = row1_valid ? (src + (r + 1) * k_stride) : NULL; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? 
rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +typedef struct { + __fp16 *dst; + const float *src; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t k_block; + uint32_t k_stride; + uint32_t k_valid; + struct htp_thread_trace * traces; + struct htp_context * ctx; + float * vtcm_f32_act; +} activation_transfer_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_dma_pipelined( + dma_queue *dma_q, + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t n_rows, + uint32_t k_block, + uint32_t k_stride, + uint32_t k_valid, + float *thread_f32_act) { + + const uint32_t R = HTP_MM_DMA_ACT_ROWS_PER_STEP; + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + + const uint32_t n_steps = n_rows_padded / R; + + // pre-fetch step 0 + if (n_steps > 0 && n_rows > 0) { + uint32_t nrows_to_fetch = hex_smin(n_rows, R); + dma_queue_push(dma_q, dma_make_ptr(thread_f32_act, src), + k_block * sizeof(float), k_stride * sizeof(float), k_valid * sizeof(float), nrows_to_fetch); + } + + for (uint32_t s = 0; s < n_steps; ++s) { + uint32_t r = R * s; + float *curr_buf = thread_f32_act + (s % 2) * R * k_block; + + if (r < n_rows) { + dma_queue_pop(dma_q); + } + + uint32_t next_s = s + 1; + uint32_t next_r = R * next_s; + if (next_r < n_rows) { + uint32_t nrows_to_fetch = hex_smin(n_rows - next_r, R); + const float *next_src = src + next_r * k_stride; + float *next_buf = thread_f32_act + (next_s % 2) * R * k_block; + dma_queue_push(dma_q, dma_make_ptr(next_buf, next_src), + k_block * sizeof(float), k_stride * sizeof(float), 
k_valid * sizeof(float), nrows_to_fetch); + } + + #pragma unroll + for (uint32_t i = 0; i < HTP_MM_DMA_ACT_ROWS_PER_STEP; i += 2) { + uint32_t curr_r = r + i; + const bool row0_valid = (curr_r < n_rows); + const bool row1_valid = (curr_r + 1) < n_rows; + + const float *ptr_in0 = curr_buf + i * k_block; + const float *ptr_in1 = curr_buf + (i + 1) * k_block; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t r0 = curr_r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = curr_r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? 
rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t r0 = curr_r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = curr_r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + } +} + +typedef struct { + const struct mmid_row_mapping *matrix_rows; + __fp16 *dst; + const float *src; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t k_block; + uint32_t cur_a; + uint32_t mapping_stride; + uint32_t ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + uint32_t start_row; + uint32_t cne1; + uint32_t k_valid; + struct htp_thread_trace *traces; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const struct mmid_row_mapping *matrix_rows; + const __fp16 *vtcm_src; + float *dst; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t n_cols; + uint32_t cur_a; + uint32_t mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + uint32_t start_row; + uint32_t cne1; + struct htp_thread_trace *traces; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t start_row, + uint32_t n_rows, + uint32_t k_block, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + uint32_t ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + uint32_t cne1, + uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * 
HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + uint32_t i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + uint32_t i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? 
rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r_idx0 = start_row + r; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + uint32_t i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + uint32_t i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector 
*)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_activation_chunk_fp32_to_fp16_gathered_flat( + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t start_row, + uint32_t n_rows, + uint32_t k_block, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + size_t nb12, + uint32_t cne1, + uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + const float *row0_ptr = (const float *) ((const uint8_t *) src + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + mapping1.i2 * nb12); + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t 
tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r_idx0 = start_row + r; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + row0_ptr = (const float *) ((const uint8_t *) src + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + row1_ptr = (const float *) ((const uint8_t *) src + mapping1.i2 * nb12); + } + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector 
*tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + uint32_t start_row, + uint32_t n_rows, + uint32_t n_cols, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + uint32_t cne1) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (size_t r = 0; r < n_rows; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + const size_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; + const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + + if (r_idx0 >= cne1) break; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); + + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; 
+ output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HTP_MM_HMX_TILE_N_COLS) { + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + HVX_Vector *pv_out0 = (HVX_Vector *) (output_row0 + c); + HVX_Vector *pv_out1 = output_row1 ? (HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c deleted file mode 100644 index 114d8c14811..00000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.c +++ /dev/null @@ -1,6 +0,0 @@ -// HMX operations compiled as a single translation unit. -// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. - -#include "hmx-queue.c" -#include "hmx-matmul-ops.c" -#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h deleted file mode 100644 index a67842f3ffc..00000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ /dev/null @@ -1,88 +0,0 @@ -// HMX operation entry-point declarations. -// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). 
(https://github.com/haozixu/htp-ops-lib) - -#ifndef HMX_OPS_H -#define HMX_OPS_H - -#include -#include - -#include "htp-ops.h" - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct { - float *dst; - const float *activation; - const __fp16 *permuted_weight; - int m; - int k; - int n; - int act_stride; - int weight_stride; - int dst_stride; - int ne02; - int ne03; - int ne12; - int ne13; - size_t src0_nb2; - size_t src0_nb3; - size_t src1_nb2; - size_t src1_nb3; - size_t dst_nb2; - size_t dst_nb3; -} hmx_matmul_f16_f32_batched_params_t; - -// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output -// act_stride: activation row stride in elements (= k for contiguous, or -// nb[1]/sizeof(float) for permuted tensors like attention Q). -// weight_stride: weight row stride in elements (= k for compact weights, or -// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). -int hmx_matmul_f16_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const __fp16 *permuted_weight, - int m, int k, int n, - int act_stride, - int weight_stride); - -// Batched F16 wrapper over hmx_mat_mul_f16_f32. -// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. 
-int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); - -// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) -int hmx_matmul_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int act_stride, - int weight_stride, - int weight_type); - -struct mmid_row_mapping; - -int hmx_matmul_id_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int ne11, - size_t act_nb1, size_t act_nb2, - size_t dst_nb1, size_t dst_nb2, - int weight_stride, - int weight_type, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride); - -// HMX flash attention -int hmx_flash_attn_ext(struct htp_ops_context * octx); - -#ifdef __cplusplus -} -#endif - -#endif // HMX_OPS_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index cbb5d08786b..6ad77d3daa3 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -13,7 +13,9 @@ #include #include +#ifndef HTP_MAX_NTHREADS #define HTP_MAX_NTHREADS 10 +#endif #define HTP_MAX_MMAPS 16 // Memory mapping @@ -42,9 +44,13 @@ struct htp_ops_context { enum htp_op_code op; // FIXME: rename to opcode int32_t op_params[HTP_OP_MAX_PARAMS]; + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; const struct htp_tensor * src[HTP_OP_MAX_INPUTS]; - const struct htp_tensor * dst; + union { + const struct htp_tensor * dst; + const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS]; + }; // TODO convert these to an array struct htp_spad src0_spad; @@ -87,13 +93,13 @@ struct htp_context { struct htp_ops_context octx; -#ifdef HTP_HAS_HMX struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap -#endif }; int op_matmul(struct htp_ops_context * octx); int op_matmul_id(struct htp_ops_context * octx); +int 
op_matmul_qkv(struct htp_ops_context * octx); +int op_matmul_ffn(struct htp_ops_context * octx); int op_binary(struct htp_ops_context * octx); int op_unary(struct htp_ops_context * octx); int op_sum_rows(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 0f4b74a93ac..d0409013578 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -28,18 +28,19 @@ enum htp_data_type { HTP_TYPE_MXFP4 = 39, // types used internally for repack, dyn.quant, etc - HTP_TYPE_Q4_0x4x2 = 200, - HTP_TYPE_Q4_1x4x2, - HTP_TYPE_Q8_0x4x2, - HTP_TYPE_MXFP4x4x2, + HTP_TYPE_Q4_0_TILED = 200, + HTP_TYPE_Q4_1_TILED, + HTP_TYPE_Q8_0_TILED, + HTP_TYPE_MXFP4_TILED, HTP_TYPE_INVALID }; // Constats for internal types -#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) -#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks -#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks +#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout +#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout +#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout + // Mask to enable various stages of the Ops. @@ -57,6 +58,8 @@ enum htp_op_code { HTP_OP_DIV = 3, HTP_OP_MUL_MAT, HTP_OP_MUL_MAT_ID, + HTP_OP_MUL_MAT_QKV, + HTP_OP_MUL_MAT_FFN, HTP_OP_RMS_NORM, HTP_OP_RMS_NORM_MUL, HTP_OP_UNARY_SILU, @@ -99,7 +102,9 @@ enum htp_op_code { #define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS #define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS +#define HTP_OP_MAX_OUTPUTS 4 #define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS +#define HTP_OP_MAX_KERN_PARAMS 32 #define HTP_OP_MAX_BUFS 16 #define HTP_OP_MAX_REQS 256 @@ -142,8 +147,10 @@ struct htp_op_desc { uint32_t opcode; // GGML/HTP Op uint32_t flags; // Op flags int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. 
epsilon of RMS norm + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices - uint16_t dst; // Output tensor index + uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices + uint16_t pad[2]; // padding to align to 64 bits }; #ifndef HTP_MAX_NTHREADS diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index d696a5fba0c..47693d8b8b2 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -11,12 +11,13 @@ struct htp_iface_pmu_conf { }; interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem); AEEResult stop(); AEEResult mmap(in uint32 fd, in uint32 size); AEEResult munmap(in uint32 fd); AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); AEEResult etm(in uint32 enable); + AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size); }; #endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index f6cb02951d0..493b26c6e75 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -170,25 +170,7 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { } #endif -/* Q6_Vsf_equals_Vw is only available on v73+.*/ -#if __HVX_ARCH__ < 73 -static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) -{ - HVX_Vector const vzero = Q6_V_vzero(); - HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); - HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); - HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); - HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); - HVX_Vector mant = 
Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); - HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); - return ret; -} -static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) -{ - return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in)); -} -#endif static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { // This looks complicated. @@ -305,4 +287,17 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { #endif // __HVX_ARCH__ < 79 +static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) { + if (kt % 4 == 0) { + *v_act_all = hvx_vmem(y_q + kt * 32); + return *v_act_all; + } else if (kt % 4 == 1) { + return Q6_V_vror_VR(*v_act_all, 32); + } else if (kt % 4 == 2) { + return Q6_V_vror_VR(*v_act_all, 64); + } else { + return Q6_V_vror_VR(*v_act_all, 96); + } +} + #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h new file mode 100644 index 00000000000..52351b1039c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h @@ -0,0 +1,1024 @@ +// Dynamic quantizers that produce flat (non-tiled) activations + +static inline void quantize_block_f32_q8_0_flat( + float * restrict x, + uint8_t * restrict y_quants, + __fp16 * restrict y_scales, + uint32_t block_idx +) { + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector 
vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + * (HVX_Vector *) (y_quants + block_idx * 128) = vx_i8; + + HVX_VectorPair vp1 = Q6_W_vshuff_VVR(vd23_hf, vd01_hf, -2); + HVX_VectorPair vp2 = Q6_W_vshuff_VVR(Q6_V_hi_W(vp1), Q6_V_lo_W(vp1), -2); + HVX_Vector v_scales = Q6_V_lo_W(vp2); + hvx_vec_store_u(y_scales + block_idx * 4, 8, v_scales); +} + +static inline void quantize_block_f32_q8_1_flat( + float * restrict x, + uint8_t * restrict y_quants, + __fp16 * restrict y_scales, + uint32_t block_idx +) { + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); 
+ HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); + v_sums = Q6_Vw_vadd_VwVw(v_sums, 
Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + * (HVX_Vector *) (y_quants + block_idx * 128) = vx_i8; + + HVX_VectorPair vp1 = Q6_W_vshuff_VVR(vd23_hf, vd01_hf, -2); + HVX_VectorPair vp2 = Q6_W_vshuff_VVR(Q6_V_hi_W(vp1), Q6_V_lo_W(vp1), -2); + HVX_Vector v_scales = Q6_V_lo_W(vp2); + + HVX_VectorPair v_deal1 = Q6_W_vdeal_VVR(v_sums, v_sums, -4); + HVX_Vector v_even1 = Q6_V_lo_W(v_deal1); + HVX_VectorPair v_deal2 = Q6_W_vdeal_VVR(v_even1, v_even1, -4); + HVX_Vector v_even2 = Q6_V_lo_W(v_deal2); + HVX_VectorPair v_deal3 = Q6_W_vdeal_VVR(v_even2, v_even2, -4); + HVX_Vector v_sums_shuffled = Q6_V_lo_W(v_deal3); + + HVX_Vector v_sums_sf = Q6_Vsf_equals_Vw(v_sums_shuffled); + HVX_Vector v_sums_hf = hvx_vec_f32_to_f16(v_sums_sf, Q6_V_vzero()); + + HVX_Vector v_prod = hvx_vec_mul_f16_f16(v_scales, v_sums_hf); + + HVX_VectorPair vp_scales = Q6_W_vshuff_VVR(v_prod, v_scales, -2); + HVX_Vector v_final = Q6_V_lo_W(vp_scales); + + hvx_vec_store_u(y_scales + block_idx * 8, 16, v_final); +} + +static inline void quantize_row_f32_q8_0_flat(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t quants_size = hex_round_up(k, 128); + uint8_t * restrict y_quants = y; + __fp16 * restrict y_scales = (__fp16 *) (y + quants_size); + + const uint32_t nb = (k + 127) / 128; + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_0_flat(x + i * 128, y_quants, y_scales, i); + } +} + +static inline void quantize_row_f32_q8_1_flat(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t quants_size = hex_round_up(k, 128); + uint8_t * restrict y_quants = y; + __fp16 * restrict y_scales = (__fp16 *) (y + quants_size); + + const uint32_t nb = (k + 127) / 128; + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_1_flat(x + i * 128, y_quants, y_scales, i); + } +} + +static inline void 
quantize_f32_q8_0_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_0_flat((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_1_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1_flat((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_f32_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +static inline void quantize_f32_f16_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, 
+ uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +static inline void quantize_f16_f16_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +// Dot kernels that consume flat (non-tiled) activations + +static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act_rep, i8); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + 
const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = 
Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0_rep, v_act1_rep, i8); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + 
HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector 
v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act_rep, Q6_V_vzero()); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + __fp16 scale_a_val = y_scales[kt * 2 + 0]; + __fp16 sum_a_val = y_scales[kt * 2 + 1]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + HVX_Vector v_sum_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a); + HVX_Vector v_offset_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a); + + HVX_Vector v_scaled_dot = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + HVX_Vector v_sum_scaled = hvx_vec_add_f32_f32(v_scaled_dot, v_offset_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const 
uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = 
Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0_rep, v_act1_rep, Q6_V_vzero()); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + __fp16 scale_a0_val = y0_scales[kt * 2 + 0]; + __fp16 sum_a0_val = y0_scales[kt * 2 + 1]; + __fp16 scale_a1_val = y1_scales[kt * 2 + 0]; + __fp16 sum_a1_val = y1_scales[kt * 2 + 1]; + + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_sum_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + HVX_Vector v_sum_a1 = 
hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a0); + HVX_Vector v_offset_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a1); + HVX_Vector v_offset_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a1); + + HVX_Vector v_scaled_dot_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c0 = hvx_vec_add_f32_f32(v_scaled_dot_c0, v_offset_comb_c0); + + HVX_Vector v_scaled_dot_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + HVX_Vector v_sum_scaled_c1 = hvx_vec_add_f32_f32(v_scaled_dot_c1, v_offset_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_q8_0_32x1(vptr, v_act_rep); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[8]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * 
restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = 
Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_q8_0_32x2(vptr, v_act0_rep, v_act1_rep); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[8]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = 
hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + 
kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act_rep, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 
0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + 
v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0_rep, v_act1_rep, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), 
v_sum_float_c1); +} + +static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + 
v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act_rep, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + HVX_VectorPair p_scale_a_f32 = hvx_vec_f16_to_f32(v_scale_a_f16); + HVX_Vector v_scale_a = Q6_V_lo_W(p_scale_a_f32); + + HVX_Vector v_scale_comb = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + 
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = 
Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0_rep, v_act1_rep, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + HVX_VectorPair p_scale_a0_f32 = hvx_vec_f16_to_f32(v_scale_a0_f16); + HVX_VectorPair p_scale_a1_f32 = hvx_vec_f16_to_f32(v_scale_a1_f16); + HVX_Vector v_scale_a0 = Q6_V_lo_W(p_scale_a0_f32); + HVX_Vector v_scale_a1 = 
Q6_V_lo_W(p_scale_a1_f32); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f)); + v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} diff --git a/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h new file mode 100644 index 00000000000..bcb0b8f9e47 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h @@ -0,0 +1,1140 @@ +// Dynamic quantizers that produce tiled activations + +static inline void quantize_block_f32_q8_1_tiled(float * restrict x, uint8_t * restrict y_block) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_block % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = 
Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + float vmax0[32] __attribute__((aligned(128))); + float vmax1[32] __attribute__((aligned(128))); + float vmax2[32] __attribute__((aligned(128))); + float vmax3[32] __attribute__((aligned(128))); + int32_t sums[32] __attribute__((aligned(128))); + + hvx_vec_store_u(vmax0, 128, vmax0_sf); + hvx_vec_store_u(vmax1, 128, vmax1_sf); + hvx_vec_store_u(vmax2, 128, vmax2_sf); + 
hvx_vec_store_u(vmax3, 128, vmax3_sf); + hvx_vec_store_u(sums, 128, v_sums); + + float d0 = vmax0[0] / 127.0f; + float d1 = vmax1[0] / 127.0f; + float d2 = vmax2[0] / 127.0f; + float d3 = vmax3[0] / 127.0f; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + for (int b = 0; b < 4; b++) { + HVX_Vector v_act = Q6_V_vror_VR(vx_i8, b * 32); + + HVX_Vector r0 = Q6_V_vdelta_VV(v_act, v_repl_ctrl); + HVX_Vector r1 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 4), v_repl_ctrl); + HVX_Vector r2 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 8), v_repl_ctrl); + HVX_Vector r3 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 12), v_repl_ctrl); + HVX_Vector r4 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 16), v_repl_ctrl); + HVX_Vector r5 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 20), v_repl_ctrl); + HVX_Vector r6 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 24), v_repl_ctrl); + HVX_Vector r7 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 28), v_repl_ctrl); + + __fp16 scale_h, offset_h; + if (b == 0) { + scale_h = (__fp16) d0; + offset_h = (__fp16) (sums[0] * d0); + } else if (b == 1) { + scale_h = (__fp16) d1; + offset_h = (__fp16) (sums[8] * d1); + } else if (b == 2) { + scale_h = 
(__fp16) d2; + offset_h = (__fp16) (sums[16] * d2); + } else { + scale_h = (__fp16) d3; + offset_h = (__fp16) (sums[24] * d3); + } + + HVX_Vector r_scale = Q6_Vh_vsplat_R(*(int16_t *)&scale_h); + HVX_Vector r_offset = Q6_Vh_vsplat_R(*(int16_t *)&offset_h); + + HVX_Vector * restrict dst = (HVX_Vector *) (y_block + b * 1280); + dst[0] = r0; + dst[1] = r1; + dst[2] = r2; + dst[3] = r3; + dst[4] = r4; + dst[5] = r5; + dst[6] = r6; + dst[7] = r7; + dst[8] = r_scale; + dst[9] = r_offset; + } +} + +static inline void quantize_block_f32_q8_0_tiled(float * restrict x, uint8_t * restrict y_block) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_block % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); + + HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); + HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + + HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + HVX_Vector r_scale = hvx_vec_repl_f16(vd_hf); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + for (int b = 0; b < 4; b++) { + HVX_Vector v_act = Q6_V_vror_VR(vx_i8, b * 32); + + HVX_Vector r0 = Q6_V_vdelta_VV(v_act, v_repl_ctrl); + HVX_Vector r1 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 4), v_repl_ctrl); + HVX_Vector r2 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 8), v_repl_ctrl); + HVX_Vector r3 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 12), v_repl_ctrl); + HVX_Vector r4 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 16), v_repl_ctrl); + HVX_Vector r5 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 20), v_repl_ctrl); + HVX_Vector r6 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 24), v_repl_ctrl); + HVX_Vector r7 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 28), v_repl_ctrl); + + HVX_Vector * restrict dst = (HVX_Vector *) (y_block + b * 1152); + dst[0] = r0; + dst[1] = r1; + dst[2] = r2; + dst[3] = r3; + dst[4] = r4; + dst[5] = r5; + dst[6] = r6; + dst[7] = r7; + dst[8] = r_scale; + } +} + +static void quantize_row_f32_q8_0_tiled(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (k + qk - 1) / qk; + + for (uint32_t i = 0; i < nb; i++) { + uint8_t * restrict y_block = y + i * 4 * 1152; + quantize_block_f32_q8_0_tiled(x + i * 
qk, y_block); + } +} + +static void quantize_row_f32_q8_1_tiled(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (k + qk - 1) / qk; + + for (uint32_t i = 0; i < nb; i++) { + uint8_t * restrict y_block = y + i * 4 * 1280; + quantize_block_f32_q8_1_tiled(x + i * qk, y_block); + } +} + +// Dot kernels & helpers that consume tiled activations + +static inline HVX_Vector hvx_vec_mul_f16_f16_to_f32_lower32(HVX_Vector v1, HVX_Vector v2) { +#if __HVX_ARCH__ >= 79 + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v1, v2); + return Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_hi_W(p), Q6_V_lo_W(p), -4)); +#else + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v1, v2); + HVX_Vector hi = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); + HVX_Vector lo = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); + return Q6_V_lo_W(Q6_W_vshuff_VVR(hi, lo, -4)); +#endif +} + +static inline HVX_Vector unpack_and_interleave_4bit(HVX_Vector v_a, HVX_Vector v_b, HVX_Vector mask_h4) { + HVX_Vector v_W0 = Q6_V_vand_VV(v_a, mask_h4); + HVX_Vector v_W1 = Q6_Vub_vlsr_VubR(v_a, 4); + HVX_Vector v_W2 = Q6_V_vand_VV(v_b, mask_h4); + HVX_Vector v_W3 = Q6_Vub_vlsr_VubR(v_b, 4); + + HVX_VectorPair v01_pair = Q6_W_vshuff_VVR(v_W1, v_W0, -1); + HVX_VectorPair v23_pair = Q6_W_vshuff_VVR(v_W3, v_W2, -1); + HVX_VectorPair v0123_pair = Q6_W_vshuff_VVR(Q6_V_lo_W(v23_pair), Q6_V_lo_W(v01_pair), -2); + return Q6_V_lo_W(v0123_pair); +} + +static inline HVX_VectorPair unpack_and_interleave_4bit_x2(HVX_Vector v_src, HVX_Vector mask_h4) { + HVX_Vector v_lo = Q6_V_vand_VV(v_src, mask_h4); + HVX_Vector v_hi = Q6_Vub_vlsr_VubR(v_src, 4); + HVX_VectorPair v01_pair = Q6_W_vshuff_VVR(v_hi, v_lo, -1); + HVX_Vector v01_lo = Q6_V_lo_W(v01_pair); + HVX_Vector v01_hi = Q6_V_hi_W(v01_pair); + + HVX_Vector v23_lo = Q6_V_valign_VVR(v01_hi, v01_lo, 64); + HVX_Vector v_W0 = Q6_V_lo_W(Q6_W_vshuff_VVR(v23_lo, v01_lo, -2)); + + HVX_Vector v67_lo = Q6_V_valign_VVR(v01_lo, v01_hi, 64); + HVX_Vector v_W1 = 
Q6_V_lo_W(Q6_W_vshuff_VVR(v67_lo, v01_hi, -2)); + + return Q6_W_vcombine_VV(v_W1, v_W0); +} + +static inline HVX_Vector accum_4bit_32x1( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act, + HVX_Vector i8 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v_W_pair), i8); + HVX_Vector v_W1 = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v_W_pair), i8); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act[i * 2 + 1]); + } + + return Q6_Vw_vadd_VwVw(v_sum0, v_sum1); +} + +static inline HVX_Vector accum_4bit_32x1_lut( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act, + HVX_Vector mask_h4, + HVX_Vector lut +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v_W_pair), lut, 0); + HVX_Vector v_W1 = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v_W_pair), lut, 0); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act[i * 2 + 1]); + } + + return Q6_Vw_vadd_VwVw(v_sum0, v_sum1); +} + +static inline HVX_VectorPair accum_4bit_32x2( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1, + HVX_Vector i8 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v_W_pair), i8); + HVX_Vector v_W1 = 
Q6_Vb_vsub_VbVb(Q6_V_hi_W(v_W_pair), i8); + + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act0[i * 2 + 0]); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W1, v_act0[i * 2 + 1]); + + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W0, v_act1[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act1[i * 2 + 1]); + } + + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static inline HVX_VectorPair accum_4bit_32x2_lut( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1, + HVX_Vector mask_h4, + HVX_Vector lut +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v_W_pair), lut, 0); + HVX_Vector v_W1 = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v_W_pair), lut, 0); + + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act0[i * 2 + 0]); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W1, v_act0[i * 2 + 1]); + + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W0, v_act1[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act1[i * 2 + 1]); + } + + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static inline HVX_Vector accum_q8_0_32x1( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act +) { + HVX_Vector v_sum = Q6_V_vzero(); + #pragma unroll + for (int g = 0; g < 8; g++) { + HVX_Vector v_rot = Q6_V_vror_VR(vptr[g], 64); + HVX_Vector v_W = Q6_V_lo_W(Q6_W_vshuff_VVR(v_rot, vptr[g], -2)); + v_sum = Q6_Vw_vrmpyacc_VwVbVb(v_sum, v_W, v_act[g]); + } + return v_sum; +} + +static inline HVX_VectorPair accum_q8_0_32x2( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + #pragma unroll + for (int g = 0; g < 8; g++) { + HVX_Vector v_rot = Q6_V_vror_VR(vptr[g], 64); + HVX_Vector v_W = 
Q6_V_lo_W(Q6_W_vshuff_VVR(v_rot, vptr[g], -2)); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W, v_act0[g]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W, v_act1[g]); + } + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act, i8); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = 
(const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2(vptr0, v_act0_0, v_act1_0, i8); + HVX_VectorPair v_sums1 = accum_4bit_32x2(vptr1, v_act0_1, v_act1_1, i8); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[4]; + HVX_Vector v_scale_w1 = vptr1[4]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + 
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0, v_act1, i8); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1280); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act, Q6_V_vzero()); 
+ HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_sum_a = v_act[9]; + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a); + HVX_Vector v_offset_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a); + + HVX_Vector v_scaled_dot = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + HVX_Vector v_sum_scaled = hvx_vec_add_f32_f32(v_scaled_dot, v_offset_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1280); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1280); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1280); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1280); + + HVX_VectorPair v_sums0 = accum_4bit_32x2(vptr0, v_act0_0, v_act1_0, Q6_V_vzero()); + HVX_VectorPair v_sums1 = accum_4bit_32x2(vptr1, v_act0_1, v_act1_1, Q6_V_vzero()); + + 
HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_offset0 = vptr0[4]; + HVX_VectorPair p_deal0 = Q6_W_vdeal_VVR(v_scale_offset0, v_scale_offset0, -2); + HVX_Vector v_scale0 = Q6_V_lo_W(p_deal0); + HVX_Vector v_offset0 = Q6_V_hi_W(p_deal0); + + HVX_Vector v_scale_offset1 = vptr1[4]; + HVX_VectorPair p_deal1 = Q6_W_vdeal_VVR(v_scale_offset1, v_scale_offset1, -2); + HVX_Vector v_scale1 = Q6_V_lo_W(p_deal1); + HVX_Vector v_offset1 = Q6_V_hi_W(p_deal1); + + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_sum_a_c0_0 = v_act0_0[9]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_sum_a_c1_0 = v_act1_0[9]; + + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_sum_a_c0_1 = v_act0_1[9]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + HVX_Vector v_sum_a_c1_1 = v_act1_1[9]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale0, v_scale_a_c0_0); + HVX_Vector v_offset_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset0, v_sum_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale0, v_scale_a_c1_0); + HVX_Vector v_offset_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset0, v_sum_a_c1_0); + + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale1, v_scale_a_c0_1); + HVX_Vector v_offset_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset1, v_sum_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale1, v_scale_a_c1_1); + HVX_Vector v_offset_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset1, v_sum_a_c1_1); + + HVX_Vector v_scaled_dot_c0_0 = 
hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_add_f32_f32(v_scaled_dot_c0_0, v_offset_comb_c0_0); + + HVX_Vector v_scaled_dot_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_add_f32_f32(v_scaled_dot_c1_0, v_offset_comb_c1_0); + + HVX_Vector v_scaled_dot_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_add_f32_f32(v_scaled_dot_c0_1, v_offset_comb_c0_1); + + HVX_Vector v_scaled_dot_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_add_f32_f32(v_scaled_dot_c1_1, v_offset_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1280); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1280); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0, v_act1, Q6_V_vzero()); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_sum_a_c0 = v_act0[9]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + HVX_Vector v_sum_a_c1 = v_act1[9]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a_c0); + HVX_Vector v_offset_comb_c0 = 
hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a_c1); + HVX_Vector v_offset_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a_c1); + + HVX_Vector v_scaled_dot_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c0 = hvx_vec_add_f32_f32(v_scaled_dot_c0, v_offset_comb_c0); + + HVX_Vector v_scaled_dot_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + HVX_Vector v_sum_scaled_c1 = hvx_vec_add_f32_f32(v_scaled_dot_c1, v_offset_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_q8_0_32x1(vptr, v_act); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[8]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * 
restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 1152); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 1152); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_q8_0_32x2(vptr0, v_act0_0, v_act1_0); + HVX_VectorPair v_sums1 = accum_q8_0_32x2(vptr1, v_act0_1, v_act1_1); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[8]; + HVX_Vector v_scale_w1 = vptr1[8]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, 
v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_q8_0_32x2(vptr, v_act0, v_act1); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[8]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); 
+} + +static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const 
HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2_lut(vptr0, v_act0_0, v_act1_0, mask_h4, lut); + HVX_VectorPair v_sums1 = accum_4bit_32x2_lut(vptr1, v_act0_1, v_act1_1, mask_h4, lut); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[4]; + HVX_Vector v_scale_w1 = vptr1[4]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = 
hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0, v_act1, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const 
HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector v_scale_a_f16 = v_act[8]; + HVX_VectorPair p_scale_a_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_f16); + HVX_Vector v_scale_a = Q6_V_lo_W(p_scale_a_f32); + + HVX_Vector v_scale_comb = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 
1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2_lut(vptr0, v_act0_0, v_act1_0, mask_h4, lut); + HVX_VectorPair v_sums1 = accum_4bit_32x2_lut(vptr1, v_act0_1, v_act1_1, mask_h4, lut); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = hvx_vmem(tile_ptr + (kt + 0) * 640 + 512); + HVX_Vector r0_d0 = Q6_V_vdelta_VV(v_scale_w0, expand); + r0_d0 = Q6_V_vand_VV(r0_d0, e8m0_mask); + HVX_Vector v_scale_w_f32_0 = Q6_Vw_vasl_VwR(r0_d0, 23); + + HVX_Vector v_scale_w1 = hvx_vmem(tile_ptr + (kt + 1) * 640 + 512); + HVX_Vector r0_d1 = Q6_V_vdelta_VV(v_scale_w1, expand); + r0_d1 = Q6_V_vand_VV(r0_d1, e8m0_mask); + HVX_Vector v_scale_w_f32_1 = Q6_Vw_vasl_VwR(r0_d1, 23); + + HVX_Vector v_scale_a_c0_f16_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_f16_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_f16_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_f16_1 = v_act1_1[8]; + + HVX_VectorPair p_scale_a_c0_f32_0 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16_0); + HVX_VectorPair p_scale_a_c1_f32_0 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16_0); + HVX_VectorPair p_scale_a_c0_f32_1 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16_1); + HVX_VectorPair p_scale_a_c1_f32_1 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16_1); + + HVX_Vector v_scale_a_c0_0 = Q6_V_lo_W(p_scale_a_c0_f32_0); + HVX_Vector v_scale_a_c1_0 = Q6_V_lo_W(p_scale_a_c1_f32_0); + HVX_Vector 
v_scale_a_c0_1 = Q6_V_lo_W(p_scale_a_c0_f32_1); + HVX_Vector v_scale_a_c1_1 = Q6_V_lo_W(p_scale_a_c1_f32_1); + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f32_f32(v_scale_w_f32_0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f32_f32(v_scale_w_f32_0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f32_f32(v_scale_w_f32_1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f32_f32(v_scale_w_f32_1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0, v_act1, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector v_scale_a_c0_f16 = v_act0[8]; + HVX_Vector v_scale_a_c1_f16 = v_act1[8]; + + HVX_VectorPair p_scale_a_c0_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16); + 
HVX_VectorPair p_scale_a_c1_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16); + + HVX_Vector v_scale_a_c0 = Q6_V_lo_W(p_scale_a_c0_f32); + HVX_Vector v_scale_a_c1 = Q6_V_lo_W(p_scale_a_c1_f32); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f)); + v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static inline void quantize_f32_q8_0_tiled_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_0_tiled((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_1_tiled_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + 
hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1_tiled((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_0_tiled_block_kernel( + const float * restrict src, + uint8_t * restrict dst, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t ib_first, + uint32_t ib_last, + size_t src_row_size, + size_t dst_row_size, + uint32_t r, + uint32_t c +) { + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne0 + qk - 1) / qk; + + for (uint32_t ib = ib_first; ib < ib_last; ++ib) { + const uint8_t * restrict src_ptr = (const uint8_t *) src + r * src_row_size + c * qk * sizeof(float); + uint8_t * restrict dst_ptr = dst + r * dst_row_size + c * 4 * 1152; + + hex_l2fetch(src_ptr, qk * sizeof(float), qk * sizeof(float), 1); + + if (c == nb - 1) { + uint32_t active_elements = ne0 - c * qk; + hvx_splat_f32_a(tmp_data, 0.0f, qk); + hvx_copy_f32_aa(tmp_data, src_ptr, active_elements); + } else { + hvx_copy_f32_aa(tmp_data, src_ptr, qk); + } + + quantize_block_f32_q8_0_tiled((float *) tmp_data, dst_ptr); + + c++; + if (c == nb) { + c = 0; + r++; + } + } +} + +static inline void quantize_f32_q8_1_tiled_block_kernel( + const float * restrict src, + uint8_t * restrict dst, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t ib_first, + uint32_t ib_last, + size_t src_row_size, + size_t dst_row_size, + uint32_t r, + uint32_t c +) { + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne0 + qk - 1) / qk; + + for (uint32_t ib = ib_first; ib < ib_last; ++ib) { + const uint8_t * restrict src_ptr = (const uint8_t *) src + r * src_row_size + c * qk * sizeof(float); + uint8_t * restrict dst_ptr = dst + r * dst_row_size + c * 4 * 1280; + + hex_l2fetch(src_ptr, qk * sizeof(float), qk * sizeof(float), 1); + 
+ if (c == nb - 1) { + uint32_t active_elements = ne0 - c * qk; + hvx_splat_f32_a(tmp_data, 0.0f, qk); + hvx_copy_f32_aa(tmp_data, src_ptr, active_elements); + } else { + hvx_copy_f32_aa(tmp_data, src_ptr, qk); + } + + quantize_block_f32_q8_1_tiled((float *) tmp_data, dst_ptr); + + c++; + if (c == nb) { + c = 0; + r++; + } + } +} diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 53ab33c07bd..d76512ea4a3 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) { +AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -395,10 +395,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que return AEE_ENOMEMORY; } -#ifdef HTP_HAS_HMX - ctx->hmx_enabled = use_hmx; + ctx->hmx_enabled = n_hmx; ctx->hmx_queue = NULL; - if (use_hmx) { + if (n_hmx) { ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx); if (ctx->hmx_queue) { ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS]; @@ -407,8 +406,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->hmx_enabled = false; } } - FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); -#endif + FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? 
"enabled" : "disabled", n_hmx); qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_get_max_hw_threads(&hw_threads); @@ -481,13 +479,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) { dma_queue_delete(ctx->dma[i]); } -#ifdef HTP_HAS_HMX if (ctx->hmx_queue) { hmx_queue_delete(ctx->hmx_queue); ctx->hmx_queue = NULL; } ctx->hmx_enabled = false; -#endif vtcm_free(ctx); @@ -500,6 +496,36 @@ AEEResult htp_iface_stop(remote_handle64 handle) { return AEE_SUCCESS; } +AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) { + (void)handle; + if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) { + return AEE_EBADPARM; + } + + qurt_sysenv_max_hthreads_t hw_threads; + qurt_sysenv_get_max_hw_threads(&hw_threads); + uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; + + uint32_t n_hvx_val = hw_nhvx; + if (n_hvx_val > hw_threads.max_hthreads) { + n_hvx_val = hw_threads.max_hthreads; + } + if (n_hvx_val > HTP_MAX_NTHREADS) { + n_hvx_val = HTP_MAX_NTHREADS; + } + + // for now we force n_threads == n_hvx + *n_threads = n_hvx_val; + *n_hvx = n_hvx_val; + *n_hmx = 1; + + uint32_t vtcm_sz = 8 * 1024 * 1024; // 8MB default fallback + HAP_compute_res_query_VTCM(0, (unsigned int *)&vtcm_sz, NULL, NULL, NULL); + *vtcm_size = vtcm_sz; + + return AEE_SUCCESS; +} + static void htp_error_callback(dspqueue_t queue, int error, void * context) { // No errors expected on the DSP. 
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error); @@ -554,6 +580,12 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_MUL_MAT_ID: return op_matmul_id(octx); + case HTP_OP_MUL_MAT_QKV: + return op_matmul_qkv(octx); + + case HTP_OP_MUL_MAT_FFN: + return op_matmul_ffn(octx); + case HTP_OP_MUL: case HTP_OP_ADD: case HTP_OP_SUB: @@ -762,8 +794,9 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str } } -static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { +static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { memcpy(octx->op_params, op->params, sizeof(octx->op_params)); + memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params)); octx->flags = op->flags; octx->op = op->opcode; @@ -785,22 +818,41 @@ static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, src->ne[0], src->ne[1], src->ne[3], src->ne[3]); } - // Prep output tensor - struct htp_tensor *dst = tens + op->dst; + // Prep output tensors + for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) { + uint16_t dst_idx = op->dst[i]; + if (dst_idx == 0xffff) { + octx->dsts[i] = NULL; + continue; + } + struct htp_tensor *dst = tens + dst_idx; + octx->dsts[i] = dst; - octx->dst = dst; + FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + } - FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, - dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + int status = execute_op(octx); - (void) execute_op(octx); + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->src3_spad.src = NULL; + octx->dst_spad.src = NULL; // flush buffers on output - hex_l2flush((void *) dst->data, dst->size); - dst->flags |= 
HTP_TENSOR_FLUSHED; + for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) { + if (octx->dsts[i]) { + struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i]; + hex_l2flush((void *) dst->data, dst->size); + dst->flags |= HTP_TENSOR_FLUSHED; + + FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + } + } - FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, - dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + return status; } #define DSPQUEUE_POLL_TIMEOUT_USEC 100 @@ -892,20 +944,26 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { } } + int op_status = HTP_STATUS_OK; + uint32_t op_wakeup = n_ops / 2; // half-way throgh the batch + for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; - if (i == (n_ops-1)) { - // wake up the host before starting the last op + if (i == op_wakeup) { dspqueue_write_early_wakeup_noblock(queue, 0, 0); } profile_start(ctx->profiler, &prof); - proc_op_req(octx, tens, i, &ops[i]); + op_status = proc_op_req(octx, tens, i, &ops[i]); profile_stop(ctx->profiler, &prof); + if (op_status != HTP_STATUS_OK) { + break; + } + if (ctx->profiler) { pds[i].opcode = ops[i].opcode; pds[i].usecs = prof.usecs; @@ -919,7 +977,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_opbatch_rsp rsp; rsp.id = req.id; - rsp.status = HTP_STATUS_OK; + rsp.status = op_status; rsp.n_bufs = n_bufs; rsp.n_tensors = n_tens; rsp.n_ops = n_ops; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 8e016c1be5d..81a0ffbebb8 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -17,2906 +18,120 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "htp-ops.h" -#include 
"hmx-ops.h" - -#define MM_SPAD_SRC0_NROWS 16 -#define MM_SPAD_SRC1_NROWS 16 -#define MM_SPAD_DST_NROWS 2 - -struct htp_matmul_context { - const char * type; - struct htp_ops_context * octx; - - void (*vec_dot_1x1)(const int n, float * restrict s0, - const void * restrict vx0, - const void * restrict vy0); - - void (*vec_dot_2x1)(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0); - - void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1); - - void (*vec_dot_4x1)(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0); - - // Precomputed values - uint32_t src0_nrows_per_thread; - uint32_t src1_nrows_per_thread; - - struct fastdiv_values mm_div_ne12_ne1; - struct fastdiv_values mm_div_ne1; - struct fastdiv_values mm_div_r2; - struct fastdiv_values mm_div_r3; - - // Fields for scattered mapping & HMX support in MUL_MAT_ID - const uint32_t * matrix_row_counts; - const struct mmid_row_mapping * matrix_rows; - bool hmx_eligible; -}; - -// vdelta control to expand first 32 e8m0 values into 32 uint32 elements -static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { - 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00, - 0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, - 0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02, - 0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, - 0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48, - 
0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, - 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20, -}; - -// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue -// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 -static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = { - 0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0, - 0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -}; - -static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { - 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0, - 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -}; - -static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... 
- - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i = 0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, 
v0, -1); // zip even:odd:... - r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); - r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); - } - - return r; -} - -// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales - -static inline size_t q8x4x2_row_size(uint32_t ne) { - // ensures perfect alignment of quants and full row - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (ne + qk - 1) / qk; - return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); -} - -static inline size_t q8_1x4x2_row_size(uint32_t ne) { - // ensures perfect alignment of quants and full row - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (ne + qk - 1) / qk; - return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128); -} - -static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - // Convert uint4 to int4 (i.e. 
x - 8) - v0 = Q6_Vb_vsub_VbVb(v0, i8); - v1 = Q6_Vb_vsub_VbVb(v1, i8); - v2 = Q6_Vb_vsub_VbVb(v2, i8); - v3 = Q6_Vb_vsub_VbVb(v3, i8); - v4 = Q6_Vb_vsub_VbVb(v4, i8); - v5 = Q6_Vb_vsub_VbVb(v5, i8); - v6 = Q6_Vb_vsub_VbVb(v6, i8); - v7 = Q6_Vb_vsub_VbVb(v7, i8); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8); - r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8); - r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... 
- - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = v0; - r.v[i*2+1] = v1; - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i*2+0] = Q6_V_lo_W(v0_1_p); - r.v[i*2+1] = Q6_V_hi_W(v0_1_p); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... 
- - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // 
zip even:odd:... - r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); - r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0 = vptr[0]; // first 128 vals - HVX_Vector v1 = vptr[1]; // ... - HVX_Vector v2 = vptr[2]; // ... - HVX_Vector v3 = vptr[3]; // ... - HVX_Vector v4 = vptr[4]; // ... - HVX_Vector v5 = vptr[5]; // ... - HVX_Vector v6 = vptr[6]; // ... - HVX_Vector v7 = vptr[7]; // ... - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) { - return hvx_vec_load_q8x4x8_full(ptr); -} - -// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). -// Accumulate each block into a single int32 value. -// Return a single HVX vector with 32x int32 accumulators. -// This version is parameterized to support less than 1024 elements. -// if() checks are optimized out at compile time -- make sure to pass N as a constexpr. 
- -static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - HVX_Vector r0 = Q6_V_vzero(); - HVX_Vector r1 = Q6_V_vzero(); - HVX_Vector r2 = Q6_V_vzero(); - HVX_Vector r3 = Q6_V_vzero(); - HVX_Vector r4 = Q6_V_vzero(); - HVX_Vector r5 = Q6_V_vzero(); - HVX_Vector r6 = Q6_V_vzero(); - HVX_Vector r7 = Q6_V_vzero(); - - HVX_VectorPair p3; - HVX_VectorPair p2; - HVX_VectorPair p1; - HVX_VectorPair p0; - - if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); } - if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); } - if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); } - if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); } - if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); } - if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); } - if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); } - if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } - if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); } - if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); } - - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } - if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); } - if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } - - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - - return r0; -} - -static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) { - HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); - HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], 
y.v[1]); - HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); - HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); - HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); - HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); - HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); - HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); - - HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4); - HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4); - HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4); - HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4); - - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); - r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); - r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); - - p0 = Q6_W_vdeal_VVR(r1, r0, -4); - p1 = Q6_W_vdeal_VVR(r3, r2, -4); - - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); - - p0 = Q6_W_vdeal_VVR(r1, r0, -4); - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - - return r0; -} - -static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - if (n >= 512) - return hvx_vec_rmpy_x8_full(x, y); - - return hvx_vec_rmpy_x8_partial(x, y, 512); -} - -static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const 
uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = 
Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const 
uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r0_fa 
= Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - // Zero out unused elements - HVX_VectorPred 
bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r1_ms = Q6_V_vand_QV(bmask, r1_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const 
uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - 
HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); - HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); - - HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); - HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); - - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); - - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); - - r0_sum = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); - HVX_Vector r2_d = 
Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); - HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); - - HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); - HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); - - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r1_ms = Q6_V_vand_QV(bmask, r1_ms); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r2_ms = Q6_V_vand_QV(bmask, r2_ms); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r3_ms = Q6_V_vand_QV(bmask, r3_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); - - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); 
- HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t 
*) vy1) + y_qrow_size; // then scales/sums - - // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); - HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); - - HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); - HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = 
Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); - - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); - - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); - - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); - HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); - HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); - HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - 
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); - HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); - - HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); - HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); - - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); - - HVX_Vector r1_c0_dd 
= Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); - - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); - - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); - HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); - HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); - HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector 
r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). 
- - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - 
assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). 
- - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + 
x_qrow_size; - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = 
((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales - - // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = 
Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector 
vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Zero out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -static void 
vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). 
- - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 
== 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). 
- - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first 
- const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q8_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 
8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales - - // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const 
HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = 
Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - 
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -// ======== IQ4_NL x Q8_0 vec_dot kernels ======== -// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue). -// Scale format is identical to Q4_0 (fp16 scales). - -static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q 
= ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * 
x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vx2, - const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // 
then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd 
= Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, - float * restrict s0, - float * restrict s1, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vy0, - const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const 
uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; - - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - 
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk 
/ 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... 
- // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... 
- // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - - // Zero-out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then 
scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (f32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... 
- // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... 
- // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - - // Zero-out unused values - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const 
uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = 
*(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - r2_d = Q6_V_vdelta_VV(r2_d, expand); - r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); - r2_d = Q6_Vw_vasl_VwR(r2_d, 23); - r3_d = Q6_V_vdelta_VV(r3_d, expand); - r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); - r3_d = Q6_Vw_vasl_VwR(r3_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 
vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - r2_d = Q6_V_vdelta_VV(r2_d, expand); - r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); - r2_d = Q6_Vw_vasl_VwR(r2_d, 23); - r3_d = Q6_V_vdelta_VV(r3_d, expand); - r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); - r3_d = Q6_Vw_vasl_VwR(r3_d, 23); - - HVX_Vector r0_dd = 
Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); - - // Zero-out unused values - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; +#include "matmul-ops.h" +#include "vtcm-utils.h" + +typedef struct { + float *dst; + const float *activation; + const __fp16 *weight; + int m; + int k; + int n; + int act_stride; + int weight_stride; 
+ int dst_stride; + int ne02; + int ne03; + int ne12; + int ne13; + size_t src0_nb2; + size_t src0_nb3; + size_t src1_nb2; + size_t src1_nb3; + size_t dst_nb2; + size_t dst_nb3; +} hmx_mm_f16_f32_batched_params_t; + +struct htp_mm_context { + const char * type; + struct htp_ops_context * octx; - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + void (*vec_dot_1x1)(const uint32_t n, float * restrict s0, + const void * restrict vx0, + const void * restrict vy0); - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + void (*vec_dot_2x1)(const uint32_t n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0); - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + void (*vec_dot_2x2)(const uint32_t n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1); - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + void (*vec_dot_32x1)(const uint32_t n, float * restrict s, + const void * restrict vx, + const void * restrict vy, uint32_t valid_rows); - // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = 
Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); + // Precomputed values + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements + struct fastdiv_values mm_div_ne12_ne1; + struct fastdiv_values mm_div_ne1; + struct fastdiv_values mm_div_r2; + struct fastdiv_values mm_div_r3; + struct fastdiv_values mm_div_ne11; - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + // Precomputed block-parallel quantization values + uint32_t quant_ib_first[MAX_NUM_WORKERS]; + uint32_t quant_ib_last[MAX_NUM_WORKERS]; + uint32_t quant_r[MAX_NUM_WORKERS]; + uint32_t quant_c[MAX_NUM_WORKERS]; - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); + // Fields for scattered mapping & HMX support in MUL_MAT_ID + const uint32_t * matrix_row_counts; + const struct mmid_row_mapping * matrix_rows; - // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while 
applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); - vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); - vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); - vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } + // Dynamic VTCM pointers allocated sequentially + uint8_t * vtcm_src0; + uint8_t * 
vtcm_src1; + uint8_t * vtcm_src2; + uint8_t * vtcm_src3; + uint8_t * vtcm_dst; + + // Cached strides + uint32_t vtcm_src0_stride; + uint32_t vtcm_src1_stride; + uint32_t vtcm_src2_stride; + uint32_t vtcm_src3_stride; + + // Cached thread offsets/sizes + uint32_t vtcm_src0_size_per_thread; + uint32_t vtcm_src1_size_per_thread; + uint32_t vtcm_src2_size_per_thread; + uint32_t vtcm_src3_size_per_thread; + uint32_t vtcm_dst_size_per_thread; +}; - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); - vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); - vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); - vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... 
- // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); - - // Zero out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } +// vdelta control to expand first 32 e8m0 values into 32 uint32 elements +static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { + 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 
0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00, + 0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, + 0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02, + 0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, + 0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48, + 0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, + 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20, +}; - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); +// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue +// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 +static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = { + 0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0, + 0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} +static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0, + 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; #if __HVX_ARCH__ < 79 #define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) @@ -2926,7 +141,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif -static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -2954,7 +169,7 @@ static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); } -static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, +static void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -2996,7 +211,7 @@ static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, s0[1] = va.fp32[1]; } -static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3054,7 +269,7 @@ static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * res s1[1] = va1.fp32[1]; } -static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { 
+static void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -3088,7 +303,7 @@ static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -3115,7 +330,7 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum)); } -static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, +static void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3152,7 +367,7 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3212,7 +427,7 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void 
vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_UVector * restrict x = (const HVX_UVector *) vx; const HVX_UVector * restrict y = (const HVX_UVector *) vy; @@ -3242,7 +457,7 @@ static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -3295,65 +510,58 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -#define htp_matmul_tensors_preamble \ - const struct htp_tensor * restrict src0 = octx->src[0]; \ - const struct htp_tensor * restrict src1 = octx->src[1]; \ - const struct htp_tensor * restrict src2 = octx->src[2]; \ - const struct htp_tensor * restrict dst = octx->dst; \ - struct htp_spad * restrict src0_spad = &octx->src0_spad; \ - struct htp_spad * restrict src1_spad = &octx->src1_spad; \ - struct htp_spad * restrict dst_spad = &octx->dst_spad; \ - \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne20 = src2->ne[0]; \ - const uint32_t ne21 = src2->ne[1]; \ - const uint32_t ne22 = src2->ne[2]; \ - const uint32_t ne23 = src2->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const 
uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_matmul_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict src2 = octx->src[2]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne20 = src2->ne[0]; \ + const uint32_t ne21 = src2->ne[1]; \ + const uint32_t ne22 = src2->ne[2]; \ + const uint32_t ne23 = src2->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_matmul_preamble \ - struct htp_matmul_context * mmctx = data; \ - struct htp_ops_context * octx = mmctx->octx; \ - htp_matmul_tensors_preamble; \ - dma_queue *dma_queue = 
octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; +#define htp_matmul_preamble \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; \ + htp_matmul_tensors_preamble; // *** matmul with support for 4d tensors and full broadcasting -static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mm_4d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -3388,7 +596,9 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { return; } - // block-tiling attempt + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0_start); + const uint32_t blck_0 = 64; const uint32_t blck_1 = 64; @@ -3412,28 +622,606 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, iir0); for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { const uint8_t * restrict src0_row = src0_base + ir0 * nb01; mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col); } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, iir0); } } } - t2 = HAP_perf_get_qtimer_count(); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0_start); +} + +#include "hmx-mm-kernels-tiled.h" +#include "hvx-mm-kernels-tiled.h" +#include "hvx-mm-kernels-flat.h" + +// Specialized repacked matmul macros +#define MATMUL_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void 
hvx_mm_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + htp_matmul_preamble; \ + \ + const uint32_t src0_nrows = ne01 * ne02 * ne03; \ + const uint32_t src1_nrows = ne11 * ne12 * ne13; \ + \ + const uint32_t src0_start_row = src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); \ + \ + if (src0_start_row >= src0_end_row) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const size_t dst_row_size = nb1; \ + const size_t src1_row_size = nb11; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + 
for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)ne0 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); \ + \ + float * dst_ptr0 = &dst_row0[ct * 32]; \ + float * dst_ptr1 = &dst_row1[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0, dst_ptr1, w_tile, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); \ + float * dst_ptr = &dst_row[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr, w_tile, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ +} + +#define MATVEC_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X1) \ +static void hvx_mv_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + htp_matmul_preamble; \ + \ + const uint32_t src0_nrows = ne01; \ + \ + const uint32_t src0_start_row = src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); \ + \ + if (src0_start_row >= src0_end_row) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const size_t dst_row_size = nb1; \ + const size_t src1_row_size = nb11; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; \ + uint8_t * vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * src1_data = mmctx->vtcm_src1; \ + \ + float * tmp = (float *) vtcm_dst_ptr; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src1_col = (const uint8_t *) src1_data; \ + float * restrict dst_col = (float *) dst->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; \ + \ + float * dst_ptr = &tmp[ct * 32 - src0_start_row]; \ + int valid_rows = (int)ne0 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + DOT_2X1(ne10, 
dst_ptr, w_tile, src1_col, valid_rows); \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + \ + int copy_cnt = (int)MIN(src0_end_row, ne0) - (int)src0_start_row; \ + if (copy_cnt > 0) { \ + hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, copy_cnt); \ + } \ +} + +#define MATMUL_QKV_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void hvx_mm_qkv_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + \ + const struct htp_tensor * restrict src0 = octx->src[0]; /* Wk */ \ + const struct htp_tensor * restrict src1 = octx->src[1]; /* x */ \ + const struct htp_tensor * restrict src2 = octx->src[2]; /* Wv */ \ + const struct htp_tensor * restrict src3 = octx->src[3]; /* Wq */ \ + const struct htp_tensor * restrict dst_k = octx->dsts[0]; \ + const struct htp_tensor * restrict dst_v = octx->dsts[1]; \ + const struct htp_tensor * restrict dst_q = octx->dsts[2]; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; \ + \ + const size_t dst_k_row_size = dst_k->nb[1]; /* K and V share output width */ \ + const size_t dst_q_row_size = dst_q->nb[1]; /* Q may be wider (GQA) */ \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; \ + uint8_t * restrict vtcm_src3_ptr = mmctx->vtcm_src3 + mmctx->vtcm_src3_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + struct htp_thread_trace * 
tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; \ + const uint8_t * restrict src3_row = (const uint8_t *) src3->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + dma_queue * dma_queue = octx->ctx->dma[ith]; \ + \ + /* 1. Process K and V together */ \ + const uint32_t src0_nrows_kv = src0->ne[1] * src0->ne[2] * src0->ne[3]; /* src0 is Wk */ \ + uint32_t src0_nrows_per_thread_kv = (src0_nrows_kv + nth - 1) / nth; \ + src0_nrows_per_thread_kv = hex_round_up(src0_nrows_per_thread_kv, 32); \ + \ + const uint32_t start_row_kv = src0_nrows_per_thread_kv * ith; \ + const uint32_t end_row_kv = MIN(start_row_kv + src0_nrows_per_thread_kv, src0_nrows_kv); \ + \ + if (start_row_kv < end_row_kv) { \ + uint32_t ct_start_kv = start_row_kv / 32; \ + uint32_t ct_end_kv = (end_row_kv + 31) / 32; \ + \ + uint32_t push_ct = ct_start_kv; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end_kv; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + d * tile_row_transfer_size_aligned, \ + src2_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, 
n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start_kv; ct < ct_end_kv; ct++) { \ + const uint8_t * w_tile_k = dma_queue_pop(dma_queue).dst; \ + const uint8_t * w_tile_v = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)src0->ne[1] - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ith); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_k = (float *) (dst_k->data + ((ir1+0) * dst_k_row_size)); \ + float * restrict dst_row1_k = (float *) (dst_k->data + ((ir1+1) * dst_k_row_size)); \ + float * dst_ptr0_k = &dst_row0_k[ct * 32]; \ + float * dst_ptr1_k = &dst_row1_k[ct * 32]; \ + \ + float * restrict dst_row0_v = (float *) (dst_v->data + ((ir1+0) * dst_k_row_size)); \ + float * restrict dst_row1_v = (float *) (dst_v->data + ((ir1+1) * dst_k_row_size)); \ + float * dst_ptr0_v = &dst_row0_v[ct * 32]; \ + float * dst_ptr1_v = &dst_row1_v[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_k, dst_ptr1_k, w_tile_k, src1_col0, src1_col1, valid_rows); \ + DOT_2X2(ne10, dst_ptr0_v, dst_ptr1_v, w_tile_v, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); \ + float * dst_ptr_k = &dst_row_k[ct * 32]; \ + \ + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); \ + float * dst_ptr_v = &dst_row_v[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_k, w_tile_k, src1_col, valid_rows); \ + DOT_2X1(ne10, dst_ptr_v, w_tile_v, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ith); \ + \ + if (push_ct < ct_end_kv) { \ + 
dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_k, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_v, src2_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + } \ + \ + /* 2. Process Q separately */ \ + const uint32_t src0_nrows_q = src3->ne[1] * src3->ne[2] * src3->ne[3]; /* src3 is Wq */ \ + uint32_t src0_nrows_per_thread_q = (src0_nrows_q + nth - 1) / nth; \ + src0_nrows_per_thread_q = hex_round_up(src0_nrows_per_thread_q, 32); \ + \ + const uint32_t start_row_q = src0_nrows_per_thread_q * ith; \ + const uint32_t end_row_q = MIN(start_row_q + src0_nrows_per_thread_q, src0_nrows_q); \ + \ + if (start_row_q < end_row_q) { \ + uint32_t ct_start_q = start_row_q / 32; \ + uint32_t ct_end_q = (end_row_q + 31) / 32; \ + \ + uint32_t push_ct = ct_start_q; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end_q; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + d * tile_row_transfer_size_aligned, \ + src3_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start_q; ct < ct_end_q; ct++) { \ + const uint8_t * w_tile_q = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)src3->ne[1] - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_q = (float *) (dst_q->data + ((ir1+0) * dst_q_row_size)); \ + float * restrict dst_row1_q = (float *) (dst_q->data + ((ir1+1) * dst_q_row_size)); \ + float * dst_ptr0_q = 
&dst_row0_q[ct * 32]; \ + float * dst_ptr1_q = &dst_row1_q[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_q, dst_ptr1_q, w_tile_q, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); \ + float * dst_ptr_q = &dst_row_q[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_q, w_tile_q, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end_q) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_q, src3_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + } \ +} + +#define MATMUL_FFN_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void hvx_mm_ffn_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + \ + const struct htp_tensor * restrict src0 = octx->src[0]; /* Wgate */ \ + const struct htp_tensor * restrict src1 = octx->src[1]; /* y */ \ + const struct htp_tensor * restrict src2 = octx->src[2]; /* Wup */ \ + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; \ + const struct htp_tensor * restrict dst_up = octx->dsts[1]; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; \ + \ + const size_t dst_row_size = dst_gate->nb[1]; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + struct htp_thread_trace * tr = 
octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; \ + \ + const uint32_t src0_nrows = ne01 * src0->ne[2] * src0->ne[3]; \ + const uint32_t src0_start_row = mmctx->src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + mmctx->src0_nrows_per_thread, src0_nrows); \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + d * tile_row_transfer_size_aligned, \ + src2_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile_gate = dma_queue_pop(dma_queue).dst; \ + const uint8_t * w_tile_up = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)ne01 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, 
HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_gate = (float *) (dst_gate->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1_gate = (float *) (dst_gate->data + ((ir1+1) * dst_row_size)); \ + float * dst_ptr0_gate = &dst_row0_gate[ct * 32]; \ + float * dst_ptr1_gate = &dst_row1_gate[ct * 32]; \ + \ + float * restrict dst_row0_up = (float *) (dst_up->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1_up = (float *) (dst_up->data + ((ir1+1) * dst_row_size)); \ + float * dst_ptr0_up = &dst_row0_up[ct * 32]; \ + float * dst_ptr1_up = &dst_row1_up[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_gate, dst_ptr1_gate, w_tile_gate, src1_col0, src1_col1, valid_rows); \ + DOT_2X2(ne10, dst_ptr0_up, dst_ptr1_up, w_tile_up, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); \ + float * dst_ptr_gate = &dst_row_gate[ct * 32]; \ + \ + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); \ + float * dst_ptr_up = &dst_row_up[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_gate, w_tile_gate, src1_col, valid_rows); \ + DOT_2X1(ne10, dst_ptr_up, w_tile_up, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_gate, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_up, src2_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, 
tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ +} + +MATMUL_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) +MATMUL_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + +#define QUANTIZE_IMPL(name, log_name, kernel_fn, dst_row_size_expr) \ +static void name(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + const struct htp_tensor * src = octx->src[1]; \ + const uint32_t ne0 = src->ne[0]; \ + const uint32_t ne1 = src->ne[1]; \ + const uint32_t ne2 = src->ne[2]; \ + const uint32_t ne3 = src->ne[3]; \ + const uint32_t nrows = ne1 * ne2 * ne3; \ + const uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; \ + \ + const uint32_t ir_first = nrows_per_thread * ith; \ + if (ir_first >= nrows) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); \ + \ + uint8_t * restrict dst = mmctx->vtcm_src1; \ + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); \ + const size_t src_row_size = src->nb[1]; \ + const size_t dst_row_size = (dst_row_size_expr); \ + const uint8_t * restrict src_data = (const uint8_t *) src->data + (src_row_size * ir_first); \ + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); \ + uint8_t * restrict tmp_data = (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith); \ + kernel_fn(src_data, dst_data, tmp_data, ne0, ir_last - ir_first, src_row_size, dst_row_size); \ + \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); \ +} + +QUANTIZE_IMPL(quantize_f32_q8_0_tiled, "quantize-f32-q8_0_tiled", quantize_f32_q8_0_tiled_kernel, htp_mm_q8_0_tiled_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_1_tiled, "quantize-f32-q8_1_tiled", quantize_f32_q8_1_tiled_kernel, htp_mm_q8_1_tiled_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_0_flat, "quantize-f32-q8_0_flat", quantize_f32_q8_0_flat_kernel, htp_mm_q8_0_flat_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_1_flat, "quantize-f32-q8_1_flat", quantize_f32_q8_1_flat_kernel, htp_mm_q8_1_flat_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_f32_flat, "quantize-f32-f32", quantize_f32_f32_flat_kernel, mmctx->vtcm_src1_stride) +QUANTIZE_IMPL(quantize_f32_f16_flat, "quantize-f32-f16", quantize_f32_f16_flat_kernel, mmctx->vtcm_src1_stride) +QUANTIZE_IMPL(quantize_f16_f16_flat, "quantize-f16-f16", quantize_f16_f16_flat_kernel, mmctx->vtcm_src1_stride) + +static void quantize_f32_q8_0_tiled_block(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); + + const struct htp_tensor * src = octx->src[1]; + + quantize_f32_q8_0_tiled_block_kernel( + (const float *) src->data, + mmctx->vtcm_src1, + (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith), + src->ne[0], + mmctx->quant_ib_first[ith], + mmctx->quant_ib_last[ith], + src->nb[1], + htp_mm_q8_0_tiled_row_size(src->ne[0]), + mmctx->quant_r[ith], + mmctx->quant_c[ith] + ); - FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); } -// src1 tensor is already in VTCM spad -static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { - htp_matmul_preamble; +static void quantize_f32_q8_1_tiled_block(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); + + const struct htp_tensor * src = octx->src[1]; + + quantize_f32_q8_1_tiled_block_kernel( + (const float *) src->data, + mmctx->vtcm_src1, + (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith), + src->ne[0], + mmctx->quant_ib_first[ith], + mmctx->quant_ib_last[ith], + src->nb[1], + htp_mm_q8_1_tiled_row_size(src->ne[0]), + mmctx->quant_r[ith], + mmctx->quant_c[ith] + ); + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); +} + +MATVEC_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x1) +MATVEC_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x1) +MATVEC_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x1) +MATVEC_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x1) +MATVEC_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x1) + +MATVEC_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x1) +MATVEC_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x1) +MATVEC_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x1) +MATVEC_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x1) +MATVEC_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x1) + + +MATMUL_QKV_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_QKV_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) 
+MATMUL_QKV_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + + +MATMUL_FFN_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_FFN_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + +static void hvx_mm_2d(unsigned int nth, unsigned int ith, void * data) { + htp_matmul_preamble; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows @@ -3447,34 +1235,31 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { return; } + struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; + const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; const size_t src1_row_size = nb11; - const size_t src0_stride = src0_spad->stride; - const size_t src1_stride = src1_spad->stride; + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; - - volatile uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; const uint8_t * restrict src0_row = (const uint8_t *) src0->data; - // Prefill spad with src0 rows + // Prefill vtcm with src0 rows #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { + if (is0 >= (int)n_prefetch) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } // Process src0 rows @@ -3482,7 +1267,6 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Process src1 columns in pairs (2×2 
tiling) uint32_t ir1 = 0; for (; ir1 + 1 < src1_nrows; ir1 += 2) { @@ -3499,24 +1283,23 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + // Prefetch next (n + vtcm_nrows) row + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } } // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); @@ -3528,19 +1311,10 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { } htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], 
src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// q8x4x2 src1 tensor is already in VTCM spad -static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mv_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; const uint32_t src0_nrows = ne01; @@ -3552,1282 +1326,1929 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; const size_t src1_row_size = nb11; - const size_t src0_stride = src0_spad->stride; - const size_t src1_stride = src1_spad->stride; - - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * src1_data = src1_spad->data; + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + // Per-thread VTCMs for all tensors + uint8_t * vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; + uint8_t * vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * src1_data = mmctx->vtcm_src1; - float * tmp = (float *) spad_dst; + float * tmp = (float *) vtcm_dst_ptr; const uint8_t * restrict src0_row = (const uint8_t *) src0->data; const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_col = (float *) dst->data; - if (mmctx->vec_dot_4x1 != NULL) { - 
const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U); - - // Prefill spad with 4x src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 4); - } + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x4) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 4); - } - } + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; - // Process leftovers - uint32_t ir0 = src0_end_row_x4; - if (ir0 + 2 <= src0_end_row) { - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - 
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - ir0 += 2; - } - if (ir0 < src0_end_row) { - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - ir0 += 1; + // Prefill vtcm with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= n_prefetch) { + break; } - } else { - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + } - // Prefill spad with 2x src0 rows - #pragma unroll(2) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - } + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < 
src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 2); - } + // Prefetch next (n + vtcm_nrows) row + const uint32_t pr0 = (ir0 + n_prefetch); + const uint32_t is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } + } - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - } + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, 
ir0); + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)] -struct mmid_row_mapping { - uint32_t i1; - uint32_t i2; -}; - -// src1 tensor is already in VTCM spad -static void matmul_id(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mm_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; const struct htp_tensor * restrict ids = octx->src[2]; - struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint32_t src0_nrows = ne01; // src0 rows per expert - const uint32_t src1_nrows = ne11; - + const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src1_nrows = ne11; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { return; } + struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert - const uint32_t * matrix_row_counts = mmctx->matrix_row_counts; - const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; + const uint32_t * matrix_row_counts = mmctx->matrix_row_counts; + const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; + + const size_t dst_row_size = nb1; + const size_t src1_row_size = htp_mm_q8_0_tiled_row_size(ne10); + + const size_t src1_stride = mmctx->vtcm_src1_stride; + + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; + + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + + const uint8_t * src0_row = (const uint8_t *) src0->data + cur_a * nb02; + + const uint32_t tile_size = htp_mm_get_weight_tile_size(src0->type); + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(src0->type); + const uint32_t n_k_tiles_w = ne00 / 32; + const uint32_t n_k_tiles_a = ne10 / 32; + const uint32_t tile_row_stride = n_k_tiles_w * tile_size; + const uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; + + const uint32_t ct_start = src0_start_row / 32; + const uint32_t ct_end = (src0_end_row + 31) / 32; + + uint32_t push_ct = ct_start; + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, 
n_k_tiles_a); + } + + for (uint32_t ct = ct_start; ct < ct_end; ct++) { + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; + + int valid_rows = (int)ne01 - (int)(ct * 32); + valid_rows = MIN(32, MAX(0, valid_rows)); + + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); + for (uint32_t cid = 0; cid < cne1; ++cid) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); + const int rm1 = row_mapping.i1; // expert idx + const int rm2 = row_mapping.i2; // token idx + + const uint32_t ir1 = fastmodulo(rm1, ne11, &mmctx->mm_div_ne11); // src1 row idx + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_stride); + float * restrict dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + + mmctx->vec_dot_32x1(ne10, &dst_row[ct * 32], w_tile, src1_col, valid_rows); + } + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); + + if (push_ct < ct_end) { + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); + push_ct++; + } + } + } +} + +static void hvx_mv_id(unsigned int nth, unsigned int ith, void * data) { + htp_matmul_preamble; + + const struct htp_tensor * restrict ids = octx->src[2]; + + const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + + assert(ne13 % ne03 == 0); + + const size_t dst_row_size = nb1; + const size_t src1_row_size = htp_mm_q8_0_tiled_row_size(ne10); + + const uint32_t n_aids = src2->ne[0]; // num activated experts + const uint32_t n_ids = ne02; // num experts + + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; + + for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert + const int32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]); + if (eid < 0) { + continue; + } + assert(eid < (int32_t) n_ids); + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02; + const uint8_t * restrict src1_col = (const uint8_t *) src1_data; + float * restrict dst_row = (float *) (dst->data + ie1 * nb1); + + const uint32_t tile_size = htp_mm_get_weight_tile_size(src0->type); + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(src0->type); + const uint32_t n_k_tiles_w = ne00 / 32; + const uint32_t n_k_tiles_a = ne10 / 32; + const uint32_t tile_row_stride = n_k_tiles_w * tile_size; + const uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; + + const uint32_t ct_start = src0_start_row / 32; + const uint32_t ct_end = (src0_end_row + 31) / 32; + + uint32_t push_ct = ct_start; + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); + } + + for (uint32_t ct = ct_start; ct < 
ct_end; ct++) { + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; + + int valid_rows = (int)ne01 - (int)(ct * 32); + valid_rows = MIN(32, MAX(0, valid_rows)); + + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); + mmctx->vec_dot_32x1(ne10, &dst_row[ct * 32], w_tile, src1_col, valid_rows); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); + + if (push_ct < ct_end) { + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); + push_ct++; + } + } + } +} + +static int hvx_mm_init_vec_dot(struct htp_mm_context * mmctx, enum htp_data_type type) { + switch (type) { + case HTP_TYPE_Q4_0: + mmctx->type = "q4_0_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q4_0_32x1; + return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q4_1_32x1; + return 0; + case HTP_TYPE_Q8_0: + mmctx->type = "q8_0_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q8_0_32x1; + return 0; + case HTP_TYPE_IQ4_NL: + mmctx->type = "iq4nl_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_iq4nl_32x1; + return 0; + case HTP_TYPE_MXFP4: + mmctx->type = "mxfp4_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_mxfp4_32x1; + return 0; + default: + return -1; + } +} + +static int hvx_mm_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; + + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src1_nrows = ne11 * ne12 * ne13; + + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + 
octx->n_threads - 1) / octx->n_threads; + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); + } else { + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + } - const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t dst_row_size = nb1; + size_t src1_row_size = nb11; const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + size_t src1_row_size_padded; + + worker_callback_t quant_job_func; + worker_callback_t matmul_job_func; + uint32_t n_quant_jobs = 1; + if (src1_nrows > 1) { + if (is_repacked) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + matmul_job_func = hvx_mm_2d; + } + } else { + if (is_repacked) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mv_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mv_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mv_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mv_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mv_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + matmul_job_func = hvx_mv_2d; + } + } - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data 
+ src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; + bool need_quant = true; - for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { - const int32_t cne1 = matrix_row_counts[cur_a]; + switch (kparams->kernel_type) { + case HTP_MM_KERNEL_HVX_F16_F16_VTCM: + quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16_flat : quantize_f16_f16_flat; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + src1_row_size = hex_round_up(ne10 * 2, 128); + break; - if (cne1 == 0) { - continue; - } + case HTP_MM_KERNEL_HVX_F16_F32_DDR: + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = hvx_mm_4d; + mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + need_quant = false; + quant_job_func = NULL; + src1_row_size = nb11; + break; - if (mmctx->hmx_eligible) { - continue; - } + case HTP_MM_KERNEL_HVX_F16_F16_DDR: + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = hvx_mm_4d; + mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + src1_row_size = nb11; + need_quant = false; + quant_job_func = NULL; + break; - const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); + case HTP_MM_KERNEL_HVX_F32_F32_VTCM: + quant_job_func = quantize_f32_f32_flat; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + src1_row_size = hex_round_up(ne10 * 4, 128); + break; + + case HTP_MM_KERNEL_HVX_F32_F32_DDR: + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + 
mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + src1_row_size = nb11; + need_quant = false; + matmul_job_func = hvx_mm_4d; + break; - // Prefill spad with src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; + case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + + if (src1_nrows > 1) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mv_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mv_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mv_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mv_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mv_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + break; } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 
+= 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + case HTP_MM_KERNEL_HVX_QUANT_BLOCK: + case HTP_MM_KERNEL_HVX_QUANT_ROW: + default: + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - for (uint32_t cid = 0; cid < cne1; ++cid) { - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); - const int rm1 = row_mapping.i1; // expert idx - const int rm2 = row_mapping.i2; // token idx + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne10 + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; + + if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + break; + } - const uint32_t ir1 = src1_nrows == 1 ? 
0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); - float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + size_t src0_sz = 0, src1_sz = 0, dst_sz = 0; + if (kparams->vtcm_src0_size > 0 || kparams->vtcm_src1_size > 0 || kparams->vtcm_dst_size > 0) { + src0_sz = kparams->vtcm_src0_size; + src1_sz = kparams->vtcm_src1_size; + dst_sz = kparams->vtcm_dst_size; + } else { + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, src0->type, ne10, src1_nrows, octx->n_threads, + dst_row_size, src0_row_size, src1_row_size, n_prefetch, + &src0_sz, &src1_sz, &dst_sz + ); + } + + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || + kparams->kernel_type == HTP_MM_KERNEL_HVX_F32_F32_VTCM || + kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW || + kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) { + mmctx->vtcm_src1_size_per_thread = src1_sz; + } else { + mmctx->vtcm_src1_size_per_thread = src1_sz / octx->n_threads; + } - mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); - } - } + mmctx->vtcm_src0_size_per_thread = src0_sz / octx->n_threads; + mmctx->vtcm_dst_size_per_thread = dst_sz / octx->n_threads; - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % 
MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + size_t vtcm_size = kparams->vtcm_size > 0 ? (size_t)kparams->vtcm_size : (src1_sz + src0_sz + dst_sz); - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - for (uint32_t cid = 0; cid < cne1; ++cid) { - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); - const int rm1 = row_mapping.i1; // expert idx - const int rm2 = row_mapping.i2; // token idx + FARF(HIGH, "matmul-%s : src0-vtcm-size %zu src1-vtcm-size %zu dst-vtcm-size %zu (%zu)\n", mmctx->type, + src0_sz, src1_sz, dst_sz, vtcm_size); + + FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); + + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, + octx->ctx->vtcm_size, vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } - const uint32_t ir1 = src1_nrows == 1 ? 
0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); - float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_dst = vtcm_seq_alloc(&vtcm_ptr, dst_sz); - mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - } + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; + + mmctx->vtcm_src0_stride = src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + if (need_quant) { + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } - t2 = HAP_perf_get_qtimer_count(); + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); - FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], - dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + return HTP_STATUS_OK; } -// src1 tensor is already in VTCM spad -static void matvec_id(unsigned int nth, unsigned int ith, void * data) { - htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; +static void hvx_mm_qkv_2d(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * restrict ids = octx->src[2]; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict src0 = octx->src[0]; // Wk + const struct htp_tensor * restrict src1 = octx->src[1]; // x + const struct htp_tensor * restrict src2 = octx->src[2]; // Wv + const struct htp_tensor * restrict src3 = octx->src[3]; // Wq + const struct htp_tensor * restrict dst_k = octx->dsts[0]; + const struct htp_tensor * restrict dst_v = octx->dsts[1]; + const struct htp_tensor * restrict dst_q = octx->dsts[2]; - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + + const uint32_t ne11 = src1->ne[1]; + const uint32_t ne12 = src1->ne[2]; + const uint32_t ne13 = src1->ne[3]; - const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src1_nrows = ne11 * ne12 * ne13; + const uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - // no work for this thread if (src0_start_row >= src0_end_row) { return; } - assert(ne13 % ne03 == 0); - - const size_t dst_row_size = nb1; - const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); - - const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); - - const uint32_t n_aids = src2->ne[0]; // num activated experts - const uint32_t n_ids = ne02; // num experts + const size_t dst_k_row_size = dst_k->nb[1]; 
// K and V share output width + const size_t dst_q_row_size = dst_q->nb[1]; // Q may be wider (GQA) + const size_t src0_row_size = src0->nb[1]; + const size_t src2_row_size = src2->nb[1]; + const size_t src3_row_size = src3->nb[1]; - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src2_stride = mmctx->vtcm_src2_stride; + const size_t src3_stride = mmctx->vtcm_src3_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; - for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert - const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]); - assert(eid < n_ids); + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; + uint8_t * restrict vtcm_src3_ptr = mmctx->vtcm_src3 + mmctx->vtcm_src3_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; - const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02; - const uint8_t * restrict src1_col = (const uint8_t *) src1_data; - float * restrict dst_row = (float *) (dst->data + ie1 * nb1); + dma_queue * dma_queue = octx->ctx->dma[ith]; - // Prefill spad with src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, 
src0_row_size, 2); - } + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); - } - } + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; + const uint8_t * restrict src3_row = (const uint8_t *) src3->data; - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + // Prefill spad with src0, src2, src3 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= 
(int)n_prefetch) { + break; } + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + ir0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 2); } - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], - dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} - -// *** dynamic quant - -static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vzero(); - - // Use reduce max fp32 to find max(abs(e)) first - HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); - HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); - HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); - HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); - - // Load and convert into QF32 - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert to QF32 - HVX_Vector vmax0_qf = 
Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); - HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); - HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); - HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); - - // Combine and convert to fp16 - HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); - HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); - - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; - - // --- Sum calculation --- - const HVX_Vector ones = Q6_Vb_vsplat_R(1); - HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements - // Sum 8 elements: - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); - - // Copy to stack to extract sums and vmaxes - float vmax0[32] __attribute__((aligned(128))); - float 
vmax1[32] __attribute__((aligned(128))); - float vmax2[32] __attribute__((aligned(128))); - float vmax3[32] __attribute__((aligned(128))); - int32_t sums[32] __attribute__((aligned(128))); - - hvx_vec_store_u(vmax0, 128, vmax0_sf); - hvx_vec_store_u(vmax1, 128, vmax1_sf); - hvx_vec_store_u(vmax2, 128, vmax2_sf); - hvx_vec_store_u(vmax3, 128, vmax3_sf); - hvx_vec_store_u(sums, 128, v_sums); - - float d0 = vmax0[0] / 127.0f; - float d1 = vmax1[0] / 127.0f; - float d2 = vmax2[0] / 127.0f; - float d3 = vmax3[0] / 127.0f; - - __fp16 * y_d_half = (__fp16 *) y_d; - y_d_half[0] = d0; - y_d_half[1] = (float) sums[0] * d0; - y_d_half[2] = d1; - y_d_half[3] = (float) sums[8] * d1; - y_d_half[4] = d2; - y_d_half[5] = (float) sums[16] * d2; - y_d_half[6] = d3; - y_d_half[7] = (float) sums[24] * d3; -} + // Process rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss3 = dma_queue_pop(dma_queue).dst; -static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vzero(); - - // Use reduce max fp32 to find max(abs(e)) first - HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); - HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); - HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); - HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); - // Load and convert into QF32 - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert 
to QF32 - HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes - HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes - HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes - HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes - - // Combine and convert to fp16 - HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); - HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); - - hvx_vec_store_u(y_d + 0, 2, vd01_hf); - HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64); - hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf); - - hvx_vec_store_u(y_d + 4, 2, vd23_hf); - rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64); - hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf); - - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; -} + // Process src1 columns in pairs (2×2 tiling) + 
uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); -static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); + float * restrict dst_row0_k = (float *) (dst_k->data + ((ir1+0) * dst_k_row_size)); + float * restrict dst_row1_k = (float *) (dst_k->data + ((ir1+1) * dst_k_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_k[ir0], &dst_row1_k[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); - HVX_Vector * vx = (HVX_Vector *) x; + float * restrict dst_row0_v = (float *) (dst_v->data + ((ir1+0) * dst_k_row_size)); + float * restrict dst_row1_v = (float *) (dst_v->data + ((ir1+1) * dst_k_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_v[ir0], &dst_row1_v[ir0], ss2, ss2 + src2_stride, src1_col0, src1_col1); - // Load and convert into QF32 - HVX_Vector zero = Q6_V_vzero(); - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + float * restrict dst_row0_q = (float *) (dst_q->data + ((ir1+0) * dst_q_row_size)); + float * restrict dst_row1_q = (float *) (dst_q->data + ((ir1+1) * dst_q_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_q[ir0], &dst_row1_q[ir0], ss3, ss3 + src3_stride, src1_col0, src1_col1); + } - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { + const uint8_t 
* restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); - // Compute max and scale - HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes - HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_k[ir0], ss0, ss0 + src0_stride, src1_col); - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_v[ir0], ss2, ss2 + src2_stride, src1_col); - hvx_vec_store_u(y_d + 0, 4, vd01_hf); - hvx_vec_store_u(y_d + 4, 4, vd23_hf); + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_q[ir0], ss3, ss3 + src3_stride, src1_col); + } - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + // Prefetch next (n + vtcm_nrows) rows + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + pr0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + dma_queue_push(dma_queue, 
dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + pr0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 2); + } + } - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + // Process last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + ir0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 1); - *(HVX_Vector *) y_q = vx_i8; -} + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss3 = dma_queue_pop(dma_queue).dst; -static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); - HVX_Vector * vx = (HVX_Vector *) x; + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_k[ir0], ss0, src1_col); - // Load and convert into QF32 - HVX_Vector zero = Q6_V_vzero(); - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = 
Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_v[ir0], ss2, src1_col); - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_q[ir0], ss3, src1_col); + } + } +} - // Compute max and scale - HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); - vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes +static void hvx_mm_ffn_2d(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; - HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + const struct htp_tensor * restrict src0 = octx->src[0]; // Wgate + const struct htp_tensor * restrict src1 = octx->src[1]; // y + const struct htp_tensor * restrict src2 = octx->src[2]; // Wup + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; + const struct htp_tensor * restrict dst_up = octx->dsts[1]; - *(HVX_UVector *) y_d = vd_hf; + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; - // Divide input by the scale - HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + const uint32_t ne11 = src1->ne[1]; + const uint32_t ne12 = src1->ne[2]; + const uint32_t ne13 = src1->ne[3]; - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - 
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src1_nrows = ne11 * ne12 * ne13; - *(HVX_Vector *) y_q = vx_i8; -} + const uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); -// Overrides input x -static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { - assert(k % 32 == 0); - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (k + qk - 1) / qk; - - const uint32_t qrow_size = k; // int8 - - const uint32_t dblk_size = 8 * 2; // 8x __fp16 - const uint32_t qblk_size = QK_Q8_0x4x2; // int8 - - uint8_t * restrict y_q = (y + 0); // quants first - uint8_t * restrict y_d = (y + qrow_size); // then scales - - // Temp scales override input since we're working off of the aligned temp buffer in VTCM - uint8_t * restrict t_d = (uint8_t *) x; - - for (uint32_t i = 0; i < nb; i++) { -#if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); 
-#else -#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" -#endif + if (src0_start_row >= src0_end_row) { + return; } - // now copy the scales into final location - hvx_copy_f16_ua(y_d, t_d, nb * 8); -} + const size_t dst_row_size = dst_gate->nb[1]; + const size_t src0_row_size = src0->nb[1]; + const size_t src2_row_size = src2->nb[1]; -static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src2_stride = mmctx->vtcm_src2_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - struct htp_spad * spad = &octx->src0_spad; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; - uint64_t t1 = HAP_perf_get_qtimer_count(); + dma_queue * dma_queue = octx->ctx->dma[ith]; - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; - const uint32_t ir_first = nrows_per_thread * ith; // 
first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + // Prefill spad with src0, src2 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= (int)n_prefetch) { + break; + } + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + } - const size_t src_row_size = src->nb[1]; - const size_t dst_row_size = q8x4x2_row_size(ne0); + // Process rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; - uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); - uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); - const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); - memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + float * restrict dst_row0_gate = (float *) (dst_gate->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1_gate = (float *) (dst_gate->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_gate[ir0], &dst_row1_gate[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); - for 
(uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_row_size, 2); - hvx_copy_f32_aa(tmp_data, src_data, ne0); + float * restrict dst_row0_up = (float *) (dst_up->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1_up = (float *) (dst_up->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_up[ir0], &dst_row1_up[ir0], ss2, ss2 + src2_stride, src1_col0, src1_col1); + } - // FARF(HIGH, "quantize-q8x4-row: %u\n", i); - quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0); - dst_data += dst_row_size; - src_data += src_row_size; - } + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); - uint64_t t2 = HAP_perf_get_qtimer_count(); + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_gate[ir0], ss0, ss0 + src0_stride, src1_col); - FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_up[ir0], ss2, ss2 + src2_stride, src1_col); + } -static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { - assert(k % 32 == 0); - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (k + qk - 1) / qk; + // Prefetch next rows + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * 
src2_stride, src2_row + pr0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + } + } - const uint32_t qrow_size = k; // int8 + // Process last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 1); - const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes - const uint32_t qblk_size = QK_Q8_0x4x2; // int8 + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; - uint8_t * restrict y_q = (y + 0); // quants first - uint8_t * restrict y_d = (y + qrow_size); // then scales/sums + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); - // Temp scales override input since we're working off of the aligned temp buffer in VTCM - uint8_t * restrict t_d = (uint8_t *) x; + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_gate[ir0], ss0, src1_col); - for (uint32_t i = 0; i < nb; i++) { - quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_up[ir0], ss2, src1_col); + } } +} - // now copy the scales/sums into final location - hvx_copy_f16_ua(y_d, t_d, nb * 16); +#define DEQUANTIZE_WORKER_LOOP_IMPL(SUFFIX) \ +static void dequantize_tiled_worker_loop_##SUFFIX(unsigned int 
n, unsigned int i, void *data) { \ + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; \ + struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ + int start = task_id * state->n_tiles_per_task; \ + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ + dequantize_tiled_weight_to_fp16_task_##SUFFIX(state, start, end); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ } -static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; +DEQUANTIZE_WORKER_LOOP_IMPL(q4_0) +DEQUANTIZE_WORKER_LOOP_IMPL(q4_1) +DEQUANTIZE_WORKER_LOOP_IMPL(iq4_nl) +DEQUANTIZE_WORKER_LOOP_IMPL(mxfp4) +DEQUANTIZE_WORKER_LOOP_IMPL(q8_0) - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - struct htp_spad * spad = &octx->src0_spad; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; + struct htp_thread_trace * tr = state->traces ? 
&state->traces[i] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); +} - uint64_t t1 = HAP_perf_get_qtimer_count(); +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; + struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, i); - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, i); +} - const size_t src_row_size = src->nb[1]; - const size_t dst_row_size = q8_1x4x2_row_size(ne0); +static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; - uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); - uint8_t * restrict tmp_data = (uint8_t *) spad->data + 
(spad->size_per_thread * ith); + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); - memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + int start_chunk_idx = i * st->n_chunks_per_task; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, start_chunk_idx); - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_row_size, 2); - hvx_copy_f32_aa(tmp_data, src_data, ne0); + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); - dst_data += dst_row_size; - src_data += src_row_size; + float *dst = st->dst + chunk_idx * st->dst_stride; + transfer_output_chunk_fp16_to_fp32(dst, st->vtcm_src, chunk_idx, chunk_size, st->n_cols, st->dst_stride, st->dst_cols); } - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, start_chunk_idx); } -static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - uint64_t t1 = HAP_perf_get_qtimer_count(); + int start_chunk_idx = i * st->n_chunks_per_task; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, start_chunk_idx); - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + if (st->vtcm_f32_act) { + float *thread_f32_act = st->vtcm_f32_act + i * HTP_MM_DMA_ACT_MULTIPLIER * st->k_block; + transfer_activation_chunk_fp32_to_fp16_dma_pipelined( + st->ctx->dma[i], dst, src, chunk_size, st->k_block, st->k_stride, st->k_valid, thread_f32_act + ); + } else { + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride, st->k_valid); + } + } - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, start_chunk_idx); +} + +static void 
transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + transfer_activation_chunk_fp32_to_fp16_gathered( + st->dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1, st->k_valid); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + } +} + +static void transfer_activation_chunk_gathered_worker_flat_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + transfer_activation_chunk_fp32_to_fp16_gathered_flat( + st->dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->nb12, st->cne1, st->k_valid); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + } +} + +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? 
&st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, chunk_idx); + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, st->vtcm_src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, chunk_idx); + } +} + +// --- HMX Dispatchers & Entry Points --- + +static void dequantize_tiled_weight_chunk_to_fp16_tiles( + struct htp_context *ctx, __fp16 *vtcm_dst, + const void *weight_src_ddr, + int n_cols, int k_block, + size_t row_stride, int weight_type, + int n_k_tiles, struct fastdiv_values n_k_tiles_div, + worker_callback_t dequant_worker_fn, int n_threads) { + + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + assert(k_block % HTP_MM_HMX_TILE_N_COLS == 0); + + size_t n_col_tiles = n_cols / HTP_MM_HMX_TILE_N_COLS; + size_t n_tot_tiles = n_col_tiles * n_k_tiles; + + size_t n_tiles_per_task = (n_threads == 1) ? 
n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); + + tiled_dequantize_state_t state; + state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; + state.n_tot_tiles = n_tot_tiles; + state.n_tiles_per_task = n_tiles_per_task; + state.dst = vtcm_dst; + state.src = (const uint8_t *)weight_src_ddr; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; + state.n_k_tiles = n_k_tiles; + state.n_k_tiles_div = n_k_tiles_div; + state.traces = ctx->trace; + state.ctx = ctx; + + state.tile_size = htp_mm_get_weight_tile_size(weight_type); + state.aligned_tile_size = htp_mm_get_weight_aligned_tile_size(weight_type); + + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + int n_tasks = hex_smin((int) state.n_tasks, n_threads); + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_tasks); + } +} - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); +static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, + int n_rows, int n_cols, int dst_stride, int dst_cols, int n_threads) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f32_au(dst_data, src_data, ne0); + if (n_rows <= 0) return; - dst_data += dst_stride; - src_data += src_stride; - } + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = (n_threads == 1) ? 
n_tot_chunks : hmx_ceil_div(n_rows, n_threads); + n_chunks_per_task = hex_align_up(n_chunks_per_task, 2); - uint64_t t2 = HAP_perf_get_qtimer_count(); + int actual_threads = hmx_ceil_div(n_rows, n_chunks_per_task); - FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} + output_transfer_task_state_t state; + state.n_tasks = actual_threads; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.vtcm_src = vtcm_src; + state.n_cols = n_cols; + state.dst_stride = dst_stride; + state.dst_cols = dst_cols; + state.traces = ctx->trace; -static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + if (actual_threads <= 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, actual_threads); + } +} + +static void transfer_activation_chunk_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int n_rows, + int k_block, + int k_stride, + int n_threads, + int k_valid, + float *vtcm_f32_act) { + assert(k_block % HTP_MM_HMX_TILE_N_COLS == 0 && k_stride % HTP_MM_HMX_TILE_N_COLS == 0); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = (n_threads == 1) ? 
n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address + + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; + state.k_valid = k_valid; + state.traces = ctx->trace; + state.ctx = ctx; + state.vtcm_f32_act = vtcm_f32_act; + + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + int n_tasks = hex_smin((int) state.n_tasks, n_threads); + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_tasks); + } +} + +static int hmx_mm_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *weight, + int m, int k, int n, + int act_stride, + int weight_stride, + int weight_type, + int k_valid, + int dst_stride, + int dst_cols, + int m_chunk, + int n_chunk, + int pipeline, + int n_threads, + int act_threads, + int tile_size, + int aligned_tile_size, + int vtcm_size) { + if (k % 32 != 0 || n % 32 != 0) { return -1; } + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN)) { return -1; } + + size_t row_stride = htp_mm_get_tiled_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_tiled_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_tiled_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_tiled_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_tiled_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_tiled_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case 
HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; + const int n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - uint64_t t1 = HAP_perf_get_qtimer_count(); + const bool is_quant = (weight_type != HTP_TYPE_F16 && weight_type != HTP_TYPE_F32); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; + size_t m_chunk_n_rows = m_chunk; + size_t n_chunk_n_cols = n_chunk; + size_t vtcm_used = vtcm_size; - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + const size_t act_f32_size = hex_align_up((size_t)act_threads * HTP_MM_DMA_ACT_MULTIPLIER * k * sizeof(float), HTP_MM_HMX_TILE_SIZE); - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; + const size_t weight_area_size = is_quant + ? 
hex_align_up((n_chunk_n_cols / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE) + : hex_align_up(n_chunk_n_cols * row_stride, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + size_t scratch0_size, scratch1_size, scratch2_size; + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); // dequant buf 0 + scratch1_size = pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = pipeline ? output_area_size : 0; // output buf 1 - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f16_f32_au(dst_data, src_data, ne0); + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight_raw[2] = { NULL, NULL }; + if (weight_area_size) { + if (pipeline) { + vtcm_weight_raw[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + vtcm_weight_raw[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + } else { + vtcm_weight_raw[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + } + } + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + float *vtcm_f32_act = (float *) vtcm_seq_alloc(&vtcm_ptr, act_f32_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; + void *vtcm_scratch2 = scratch2_size ? 
vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - dst_data += dst_stride; - src_data += src_stride; + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-2d-precomputed: VTCM overflow: used %zu budget %zu, m %d k %d n %d mc %zu nc %zu", + vtcm_used, vtcm_budget, m, k, n, m_chunk_n_rows, n_chunk_n_cols); + return -1; } - uint64_t t2 = HAP_perf_get_qtimer_count(); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} + FARF(HIGH, "hmx-mm-2d-precomputed: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); -// TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? 
&octx->ctx->trace[ith] : NULL; + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; + if (pipeline) { + // --- Asynchronous Pipelined Loop --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - uint64_t t1 = HAP_perf_get_qtimer_count(); + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; + transfer_activation_chunk_threaded(ctx, vtcm_f16_act, activation + mr * act_stride, n_rows, k, act_stride, act_threads, k_valid, vtcm_f32_act); - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + // Prologue: push A0 and optionally A1 (if n_chunk_cnt > 1) + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight), aligned_tile_size, tile_size, tile_size, (n_cols_A0 / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight), row_stride, weight_stride, row_stride, n_cols_A0); + } - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + if (1 < n_chunk_cnt) { + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[1], weight + n_chunk_n_cols * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols_A1 / 32) * 
n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[1], weight + n_chunk_n_cols * weight_stride), row_stride, weight_stride, row_stride, n_cols_A1); + } + } - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; + // pop A0 -> dequantize A0 -> submit C0 + dma_queue_pop(ctx->dma[0]); + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_weight_bufs[0], vtcm_weight_raw[0], + n_cols_A0, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_f16_act, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HTP_MM_HMX_TILE_N_COLS), k / HTP_MM_HMX_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // Main loop: pop/dequantize A_{i+1} -> push A_{i+2} -> submit C_{i+1} -> wait C_i and store D_i + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // 1. pop A_{i+1} and dequantize it (if i+1 < n_chunk_cnt) + if (i + 1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_weight_bufs[(i + 1) % 2], vtcm_weight_raw[(i + 1) % 2], + n_cols_p1, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + } - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + // 2. 
push A_{i+2} (if i+2 < n_chunk_cnt) + if (i + 2 < n_chunk_cnt) { + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[(i + 2) % 2], weight + nc_p2 * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols_p2 / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[(i + 2) % 2], weight + nc_p2 * weight_stride), row_stride, weight_stride, row_stride, n_cols_p2); + } + } - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f16_au(dst_data, src_data, ne0); + // 3. submit C_{i+1} (if i+1 < n_chunk_cnt) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_f16_act, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HTP_MM_HMX_TILE_N_COLS), k / HTP_MM_HMX_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } - dst_data += dst_stride; - src_data += src_stride; + // 4. 
wait C_i and store D_i (multi-thread HVX, parallel with C_{i+1}) + hmx_queue_pop(ctx->hmx_queue); + float *output_chunk = dst + (mr * dst_stride + nc); + int chunk_dst_cols = dst_cols - (int)nc; + if (chunk_dst_cols > 0) { + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, dst_stride, chunk_dst_cols, n_threads); + } + } + } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Un-pipelined loop (m <= 32 or fallback) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + transfer_activation_chunk_threaded(ctx, vtcm_f16_act, activation + mr * act_stride, n_rows, k, act_stride, act_threads, k_valid, vtcm_f32_act); + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HTP_MM_HMX_TILE_N_COLS); + + // A: Weight DMA (Synchronous) + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight + nc * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight + nc * weight_stride), row_stride, weight_stride, row_stride, n_cols); + } + dma_queue_pop(ctx->dma[0]); + + // B: Weight Dequantize (Threaded) + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_scratch0, vtcm_weight_raw[0], + n_cols, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_f16_act, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HTP_MM_HMX_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * dst_stride + nc); + int chunk_dst_cols = dst_cols - (int)nc; + if (chunk_dst_cols > 0) { + 
transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, dst_stride, chunk_dst_cols, n_threads); + } + } + } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); + return 0; } - -static inline bool htp_is_permuted(const struct htp_tensor * t) { - return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; +static inline int hmx_mm_batch_r2(const hmx_mm_f16_f32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; } -static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) { - switch (type) { - case HTP_TYPE_Q4_0: - mmctx->type = "q4x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1; - return 0; - case HTP_TYPE_Q4_1: - mmctx->type = "q4_1x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1; - return 0; - case HTP_TYPE_Q8_0: - mmctx->type = "q8x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1; - return 0; - case HTP_TYPE_IQ4_NL: - mmctx->type = "iq4nlx4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1; - 
return 0; - case HTP_TYPE_MXFP4: - mmctx->type = "mxfp4x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1; - return 0; - default: - return -1; - } +static inline int hmx_mm_batch_r3(const hmx_mm_f16_f32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; } -static void htp_mminit_spad(struct htp_ops_context * octx, - size_t dst_row_size, - size_t src0_row_size_padded, - size_t src1_row_size, - uint32_t src1_nrows, - size_t src2_spad_size_per_thread) { - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - if (src2_spad_size_per_thread > 0) { - octx->src2_spad.size_per_thread = src2_spad_size_per_thread; - octx->src2_spad.size = octx->src2_spad.size_per_thread; - } +static inline const __fp16 *hmx_mm_weight_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_mm_batch_r2(params); + const int r3 = hmx_mm_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} - // src0 spad is also used in dynamic quantizer to store padded src1 rows - size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } +static inline const float *hmx_mm_activation_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + 
(size_t) dst_b3 * params->src1_nb3); +} - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; +static inline float *hmx_mm_dst_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); } -static int op_matmul_hvx(struct htp_ops_context * octx) { - htp_matmul_tensors_preamble; +static int hmx_mm_f16_f32_batched_simple(struct htp_context *ctx, + const hmx_mm_f16_f32_batched_params_t *params, + int m_chunk, int n_chunk, int pipeline, int n_threads, int act_threads, int vtcm_size) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mm_2d_f32(ctx, hmx_mm_dst_batch_ptr(params, b2, b3), + hmx_mm_activation_batch_ptr(params, b2, b3), + (const uint8_t *)hmx_mm_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride * (int)sizeof(__fp16), + HTP_TYPE_F16, params->k, params->n, params->n, + m_chunk, n_chunk, pipeline, n_threads, act_threads, + 0, 0, vtcm_size); + } + } + return ret; +} + +static int hmx_mm_f16_f32_batched(struct htp_context *ctx, const hmx_mm_f16_f32_batched_params_t *params, + int m_chunk, int n_chunk, int pipeline, int n_threads, int act_threads, int vtcm_size) { + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } + if (!hex_is_aligned(params->dst, VLEN) || 
!hex_is_aligned(params->activation, VLEN)) { return -1; } + + const int group_size = hmx_mm_batch_r2(params); + const size_t vtcm_budget = ctx->vtcm_size; + + // Check if the precomputed parameters are grouped or simple. + // If simple, or if group_size <= 1, we use simple fallback loop. + // Grouped path is only valid if group_size > 1 and it fits within VTCM budget. + bool run_grouped = (group_size > 1 && (size_t)vtcm_size <= vtcm_budget); + if (!run_grouped) { + return hmx_mm_f16_f32_batched_simple(ctx, params, m_chunk, n_chunk, pipeline, n_threads, act_threads, vtcm_size); + } + + const size_t vec_dot_size = params->k * sizeof(__fp16); + + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up((size_t)act_threads * HTP_MM_DMA_ACT_MULTIPLIER * (size_t) params->k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + + size_t m_chunk_n_rows = m_chunk; + size_t n_chunk_n_cols = n_chunk; + size_t vtcm_used = vtcm_size; + + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 
*vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to simple batched loop", __func__); + return hmx_mm_f16_f32_batched_simple(ctx, params, m_chunk, n_chunk, pipeline, n_threads, act_threads, vtcm_size); + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_mm_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HTP_MM_HMX_TILE_N_ROWS); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. 
+ for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_mm_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_f16_act + (size_t) g * act_head_stride; + if (use_dma_activation) { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride, act_threads, params->k, vtcm_f32_act); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride, act_threads, params->k, NULL); + } + } - struct htp_matmul_context mmctx_struct = {0}; - struct htp_matmul_context * mmctx = &mmctx_struct; - mmctx->octx = octx; + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; - const uint32_t src0_nrows = ne01 * ne02 * ne03; - const uint32_t src1_nrows = ne11 * ne12 * ne13; + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } - // Compute src0_nrows_per_thread - mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HTP_MM_HMX_TILE_N_COLS); + + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + 
hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, 0, n_cols); + hex_swap_ptr(&buf_curr, &buf_next); + } + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + struct htp_thread_trace * tr = &ctx->trace[HTP_MAX_NTHREADS]; + htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, g); + { + const __fp16 * vtcm_act_g = vtcm_f16_act + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); + } + htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, g); + + { + float *output = hmx_mm_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + int chunk_dst_cols = params->n - (int)nc; + if (chunk_dst_cols > 0) { + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, chunk_dst_cols, ctx->n_threads); + } + } + } + } + } + } + } - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - size_t src1_row_size = nb11; + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); - size_t src1_row_size_padded; + return 0; +} - worker_callback_t quant_job_func; - worker_callback_t matmul_job_func = src1_nrows > 1 ? 
matmul_2d : matvec_2d; +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads, + int k_valid) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, 2); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = ne11 > 1 ? init_fastdiv_values(ne11) : (struct fastdiv_values){0, 0}, + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + .k_valid = k_valid, + .traces = ctx->trace, + }; + + worker_callback_t worker_fn = ne11 == 1 ? 
transfer_activation_chunk_gathered_worker_flat_fn : + transfer_activation_chunk_gathered_worker_fn; + + if (actual_threads <= 1) { + worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, worker_fn, &state, actual_threads); + } +} + +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, 2); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + .traces = ctx->trace, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} + +static int hmx_mm_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *weight, + int m, int k, int n, + int k_valid, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); + + if (k % 32 != 0 || n % 32 != 0) { return -1; } + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN)) { return -1; } + + size_t 
row_stride = htp_mm_get_tiled_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_tiled_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_tiled_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_tiled_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_tiled_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_tiled_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } - bool need_quant = true; + const int n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - if (src0->type == HTP_TYPE_F16) { - // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - - // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). - // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); - - if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { - // Optimized path - quant_job_func = (src1->type == HTP_TYPE_F32) ? 
quantize_f32_f16 : quantize_f16_f16; - mmctx->type = "f16-f16"; - mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; - mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; - mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; - - src1_row_size = f16_src1_row_size; // row size post quantization - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required - quant_job_func = NULL; - if (src1->type == HTP_TYPE_F32) { - mmctx->type = "f16-f32"; - mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; - matmul_job_func = matmul_4d; - } else { - mmctx->type = "f16-f16"; - mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; - matmul_job_func = matmul_4d; - } + const int n_threads = ctx->n_threads; + const bool is_quant = (weight_type != HTP_TYPE_F16 && weight_type != HTP_TYPE_F32); - src1_row_size = nb11; // original row size in DDR + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + int tile_size = htp_mm_get_weight_tile_size(weight_type); + int aligned_tile_size = htp_mm_get_weight_aligned_tile_size(weight_type); - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - 
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; + const size_t weight_row_stride = is_quant ? qweight_row_stride : row_stride; - // Init fastdiv for matmul_4d (supports broadcasting) - mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + size_t size_per_n = 0, size_per_m = 0, size_per_mn = 0; + htp_mm_hmx_get_2d_chunk_costs(weight_type, k, /*pipeline=*/false, aligned_tile_size, + &size_per_n, &size_per_m, &size_per_mn); - need_quant = false; - } - } else if (src0->type == HTP_TYPE_F32) { - // Try optimized f32-f32 path first (src1 in VTCM) - const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); - const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); - const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; - - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); - - if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { - // Optimized path - quant_job_func = quantize_f32_f32; - mmctx->type = "f32-f32"; - mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; - mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; - mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; - - src1_row_size = f32_src1_row_size; - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * 
src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to DDR / broadcasting - quant_job_func = NULL; - mmctx->type = "f32-f32"; - mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; - matmul_job_func = matmul_4d; + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (htp_mm_hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, size_per_m, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * HTP_MM_HMX_COST_W_DEQUANT, + /*n_block_cost=*/(size_t) m_padded * HTP_MM_HMX_COST_A_CONVERT, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(ERROR, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } - src1_row_size = nb11; + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * weight_row_stride, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = weight_area_size ? 
(__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size) : NULL; + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - // Init fastdiv for matmul_4d (supports broadcasting) - mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + return -1; + } - need_quant = false; - } - } else { - if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { - return HTP_STATUS_NO_SUPPORT; - } + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); - if (src0->type == HTP_TYPE_Q4_1) { - quant_job_func = quantize_f32_q8_1x4x2; - src1_row_size = q8_1x4x2_row_size(ne10); - } else { - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); - } - htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); - } + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - // VTCM scratchpads for all tensors - size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS); - FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, - octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_f16_act, 
activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, n_threads, k_valid); - FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], - src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], - dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HTP_MM_HMX_TILE_N_COLS); - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, - octx->ctx->vtcm_size, spad_size); - return HTP_STATUS_VTCM_TOO_SMALL; + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, weight + nc * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, weight + nc * weight_stride), row_stride, weight_stride, row_stride, n_cols); + } + dma_queue_pop(ctx->dma[0]); + + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_scratch0, vtcm_weight, + n_cols, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads + ); + + struct htp_thread_trace * tr = &ctx->trace[HTP_MAX_NTHREADS]; + htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, nc); + core_dot_chunk_fp16(vtcm_output, vtcm_f16_act, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HTP_MM_HMX_TILE_N_ROWS); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, nc); + + transfer_output_chunk_scattered_threaded( + ctx, dst + nc, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, n_threads); + } } - // Place src1 spad first. 
We use it for dyn.quant and may reuse between ops - octx->src1_spad.data = octx->ctx->vtcm_base; - octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + return 0; +} - octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; - octx->src0_spad.src = NULL; - octx->dst_spad.src = NULL; - octx->src0_spad.stride = src0_row_size_padded; - octx->src1_spad.stride = src1_row_size; +// --- Dispatchers and Public Entry Points --- - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) - return HTP_STATUS_OK; +static int hmx_mm_op_matmul(struct htp_ops_context * octx, const struct htp_mm_kernel_params * kparams) { + htp_matmul_tensors_preamble; - if (need_quant && !octx->src1_spad.src) { - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - octx->src1_spad.src = src1; + int k = (int) src0->ne[0]; + int n = (int) src0->ne[1]; + const int m_total = (int) src1->ne[1]; + const int act_stride = (int)(src1->nb[1] / sizeof(float)); + const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; } - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + int ret = -1; + const int n_threads = MIN(kparams->n_threads, (int) octx->n_threads); + if (kparams->kernel_type == HTP_MM_KERNEL_HMX_F16_BATCHED) { + hmx_mm_f16_f32_batched_params_t batch_params = { + .dst = (float *) dst->data, + .activation = (float *) src1->data, + .weight = (const __fp16 *) src0->data, + .m = m_total, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = wgt_stride, + .dst_stride = (int) (dst->nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = 
ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = src0->nb[2], + .src0_nb3 = src0->nb[3], + .src1_nb2 = src1->nb[2], + .src1_nb3 = src1->nb[3], + .dst_nb2 = dst->nb[2], + .dst_nb3 = dst->nb[3], + }; + ret = hmx_mm_f16_f32_batched(octx->ctx, &batch_params, + kparams->m_chunk, kparams->n_chunk, + kparams->pipeline, n_threads, + kparams->n_act_threads, + kparams->vtcm_size); + } else { + ret = hmx_mm_2d_f32( + octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type, (int) src1->ne[0], + (int)(dst->nb[1] / sizeof(float)), (int)dst->ne[0], + kparams->m_chunk, kparams->n_chunk, kparams->pipeline, n_threads, + kparams->n_act_threads, + kparams->tile_size, kparams->aligned_tile_size, kparams->vtcm_size + ); + } + if (ret != 0) { + FARF(ERROR, "HMX matmul failed (ret=%d)\n", ret); + return HTP_STATUS_INTERNAL_ERR; + } return HTP_STATUS_OK; } int op_matmul(struct htp_ops_context * octx) { - htp_matmul_tensors_preamble; - -#ifndef HTP_HAS_HMX - return op_matmul_hvx(octx); -#else - if (!octx->ctx->hmx_enabled) { - return op_matmul_hvx(octx); - } + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; - // HMX weight tile requires N to be 32-aligned. - if (src0->ne[1] % 32 != 0) { - return op_matmul_hvx(octx); + if (kparams->n_hmx) { + return hmx_mm_op_matmul(octx, kparams); } - // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. - // Other types fall back to HVX. - uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { - return op_matmul_hvx(octx); - } + return hvx_mm_matmul(octx); +} - // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 and F32 HMX paths require K aligned to 32 (tile width). 
- if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { - return op_matmul_hvx(octx); - } +static int hmx_mm_op_matmul_id( + struct htp_ops_context * octx, + struct htp_mm_context * mmctx, + const uint32_t * matrix_row_counts, + const struct mmid_row_mapping * matrix_rows, + void * mapping_buf, + bool must_free_mapping +) { + htp_matmul_tensors_preamble; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const int n_ids = octx->src[2]->ne[0]; + const int n_as = ne02; - if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { - return op_matmul_hvx(octx); + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_mm_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne10, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * octx->src[2]->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } } - const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; +} - // Quantised HMX kernels only handle flat 2D matmul (host already rejects - // batched quantised, but guard here too). F16 batched matmul is handled - // by the dedicated wrapper in hmx-matmul-ops.c. 
- if (is_batched && src0->type != HTP_TYPE_F16) { - return op_matmul_hvx(octx); - } +static int hvx_mm_op_matmul_id( + struct htp_ops_context * octx, + struct htp_mm_context * mmctx, + size_t src0_row_size_padded, + uint32_t src1_nrows, + worker_callback_t matmul_id_job_func, + void * mapping_buf, + bool must_free_mapping +) { + htp_matmul_tensors_preamble; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const struct htp_tensor * restrict ids = octx->src[2]; + const size_t src0_row_size = nb01; - // HMX assumes contiguous row-major layout. Fall back for permuted - // tensors where strides are non-monotonic (e.g. transposed KV cache). - if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { - return op_matmul_hvx(octx); - } + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne10 + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; - // M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows) - // is handled by HMX itself; when M < 32 fall back to HVX. - const int m_total = (int) src1->ne[1]; - const int m_hmx = m_total & ~31; // 0 when M < 32 - if (m_hmx == 0) { - return op_matmul_hvx(octx); + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? 
quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; } + size_t src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); - // Always re-quantize src1 since HMX kernel overwrites vtcm/spad, - // so any previously cached quantized data is invalid. - octx->src1_spad.src = NULL; + // Scratchpad sizes are computed on the host (htp_mm_hvx_id_get_vtcm_sizes) and passed in. + // The ID layout is routing-independent, so the host has exact visibility -- consume it here + // rather than recomputing, to keep host budgeting and device allocation in lockstep. + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = 0; // mapping lives in DDR + size_t dst_sz = 0; // ID kernels scatter straight to DDR + size_t vtcm_size = kparams->vtcm_size; - int k = (int) src0->ne[0]; // inner dimension - int n = (int) src0->ne[1]; // weight columns + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = 0; + size_t dst_sz_per_thread = 0; - int ret = -1; + FARF(HIGH, "matmul-id-%s : src0-spad-size %zu src1-spad-size %zu src2-spad-size %zu dst-spad-size %zu (%zu)\n", mmctx->type, + src0_sz, src1_sz, src2_sz, dst_sz, vtcm_size); - // Row strides in elements. For compact tensors these equal k; for - // permuted attention views they can be larger, so pass the real stride. 
- const int act_stride = (int)(src1->nb[1] / sizeof(float)); - const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, + src1->data, dst->data); - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { - return HTP_STATUS_OK; + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, vtcm_size); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_VTCM_TOO_SMALL; } - if (is_batched) { - if (src0->type == HTP_TYPE_F16) { - hmx_matmul_f16_f32_batched_params_t batch_params = { - .dst = (float *) dst->data, - .activation = (float *) src1->data, - .permuted_weight = (const __fp16 *) src0->data, - .m = m_total, - .k = k, - .n = n, - .act_stride = act_stride, - .weight_stride = wgt_stride, - .dst_stride = (int) (dst->nb[1] / sizeof(float)), - .ne02 = ne02, - .ne03 = ne03, - .ne12 = ne12, - .ne13 = ne13, - .src0_nb2 = src0->nb[2], - .src0_nb3 = src0->nb[3], - .src1_nb2 = src1->nb[2], - .src1_nb3 = src1->nb[3], - .dst_nb2 = dst->nb[2], - .dst_nb3 = dst->nb[3], - }; - ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); - } else { - return op_matmul_hvx(octx); - } - } else { - ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); - } + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + 
mmctx->vtcm_dst = vtcm_seq_alloc(&vtcm_ptr, dst_sz); - if (ret != 0) { - FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); - return op_matmul(octx); - } + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->dst_spad.src = NULL; - return 0; -#endif // HTP_HAS_HMX + mmctx->vtcm_src0_stride = src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + mmctx->vtcm_dst_size_per_thread = dst_sz_per_thread; + + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; } int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - struct htp_matmul_context mmctx_struct = {0}; - struct htp_matmul_context * mmctx = &mmctx_struct; + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; mmctx->octx = octx; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const struct htp_tensor * restrict ids = octx->src[2]; const size_t src0_row_size = nb01; @@ -4839,14 +3260,11 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t src1_nrows = ne11 * ne12 * ne13; worker_callback_t quant_job_func; - worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? 
hvx_mm_id : hvx_mv_id; // Compute src0_nrows_per_thread mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even - - size_t src1_row_size; - size_t src1_row_size_padded; + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); // row groups const int n_ids = ids->ne[0]; // n_expert_used @@ -4875,130 +3293,324 @@ int op_matmul_id(struct htp_ops_context * octx) { mmctx->matrix_row_counts = matrix_row_counts; mmctx->matrix_rows = matrix_rows; + mmctx->mm_div_ne11 = kparams->div_ne11; - if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { if (must_free_mapping) free(mapping_buf); return HTP_STATUS_NO_SUPPORT; } - if (src0->type == HTP_TYPE_Q4_1) { - quant_job_func = quantize_f32_q8_1x4x2; - src1_row_size = q8_1x4x2_row_size(ne10); + if (src1_nrows > 1) { + // initialize matrix_row_counts and map + memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); + + // group rows by src0 matrix + for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx + for (uint32_t id = 0; id < n_ids; ++id) { // expert idx + const int32_t i02 = *(const int32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + if (i02 < 0) { + continue; + } + assert(i02 < n_as); + + matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; + matrix_row_counts[i02] += 1; + } + } + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; + } + + if (kparams->n_hmx) { + return hmx_mm_op_matmul_id(octx, mmctx, matrix_row_counts, matrix_rows, mapping_buf, must_free_mapping); + } + + return hvx_mm_op_matmul_id(octx, mmctx, src0_row_size_padded, src1_nrows, matmul_id_job_func, mapping_buf, must_free_mapping); +} + +int op_matmul_qkv(struct htp_ops_context * octx) { + const struct 
htp_tensor * restrict src0 = octx->src[0]; // Wk + const struct htp_tensor * restrict src1 = octx->src[1]; // x + const struct htp_tensor * restrict src2 = octx->src[2]; // Wv + const struct htp_tensor * restrict src3 = octx->src[3]; // Wq + const struct htp_tensor * restrict dst_k = octx->dsts[0]; + const struct htp_tensor * restrict dst_v = octx->dsts[1]; + const struct htp_tensor * restrict dst_q = octx->dsts[2]; + + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); } else { - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even } - const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! 
- htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); - size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } - FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, - octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (src1->ne[0] + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; - FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, - src1->data, dst->data); + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + } else if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? 
quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); - if (must_free_mapping) free(mapping_buf); + size_t src1_row_size; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(src1->ne[0]) : htp_mm_q8_0_flat_row_size(src1->ne[0]); + } else { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(src1->ne[0]) : htp_mm_q8_0_tiled_row_size(src1->ne[0]); + } + + // Set up scratchpads using precomputed sizes from the host + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = kparams->vtcm_src2_size; + size_t src3_sz = kparams->vtcm_src3_size; + size_t vtcm_size = kparams->vtcm_size; + + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = src2_sz / octx->n_threads; + size_t src3_sz_per_thread = src3_sz / octx->n_threads; + + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-qkv: current VTCM reservation %zu is too small, needed %zu\n", + octx->ctx->vtcm_size, vtcm_size); return HTP_STATUS_VTCM_TOO_SMALL; } - // Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops. 
- octx->src1_spad.data = octx->ctx->vtcm_base; - octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + mmctx->vtcm_src3 = vtcm_seq_alloc(&vtcm_ptr, src3_sz); - octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src1_spad.src = NULL; octx->src0_spad.src = NULL; octx->src2_spad.src = NULL; - octx->dst_spad.src = NULL; + octx->src3_spad.src = NULL; - octx->src0_spad.stride = src0_row_size_padded; - octx->src1_spad.stride = src1_row_size; + mmctx->vtcm_src0_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src2_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src3_stride = is_repacked ? 
0 : src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; - if (src1_nrows > 1) { - // initialize matrix_row_counts and map - memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + mmctx->vtcm_src3_size_per_thread = src3_sz_per_thread; - // group rows by src0 matrix - for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx - for (uint32_t id = 0; id < n_ids; ++id) { // expert idx - const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; - assert(i02 >= 0 && i02 < n_as); + // Run quantization once + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; - matrix_row_counts[i02] += 1; + // Run fused matmul + const uint32_t n_matmul_jobs = octx->n_threads; + worker_callback_t matmul_job_func; + if (is_repacked) { + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_qkv_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_qkv_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_1; break; + case 
HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_qkv_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_qkv_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; } } + } else { + matmul_job_func = hvx_mm_qkv_2d; } + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { - if (must_free_mapping) free(mapping_buf); - return HTP_STATUS_OK; + return HTP_STATUS_OK; +} + +int op_matmul_ffn(struct htp_ops_context * octx) { + const struct htp_tensor * restrict src0 = octx->src[0]; // Wgate + const struct htp_tensor * restrict src1 = octx->src[1]; // y + const struct htp_tensor * restrict src2 = octx->src[2]; // Wup + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; + const struct htp_tensor * restrict dst_up = octx->dsts[1]; + + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); + } else { + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even } - bool hmx_eligible = false; -#ifdef HTP_HAS_HMX - if (octx->ctx->hmx_enabled && src1_nrows > 1) { - uint32_t wtype = src0->type; - if (ne01 % 32 == 0 && - (wtype == HTP_TYPE_F16 || wtype == 
HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { - if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { - hmx_eligible = true; - } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { - hmx_eligible = true; - } - } + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; } -#endif - mmctx->hmx_eligible = hmx_eligible; - - if (hmx_eligible) { - for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { - const int32_t cne1 = matrix_row_counts[cur_a]; - if (cne1 == 0) continue; - - int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, - (const uint8_t *) src0->data + cur_a * nb02, - cne1, ne00, ne01, - ne11, - nb11, nb12, - nb1, nb2, - (int) src0->nb[1], (int) src0->type, - matrix_rows, cur_a, n_ids * ids->ne[1]); - if (ret != 0) { - FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); - if (must_free_mapping) free(mapping_buf); - return HTP_STATUS_NO_SUPPORT; - } - } + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (src1->ne[0] + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; - // HMX has overwritten VTCM, so force dynamic quantization cache to clear - octx->src1_spad.src = NULL; + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + } else if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? 
quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * (ith + 0)) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } - if (must_free_mapping) free(mapping_buf); - return HTP_STATUS_OK; + size_t src1_row_size; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(src1->ne[0]) : htp_mm_q8_0_flat_row_size(src1->ne[0]); + } else { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(src1->ne[0]) : htp_mm_q8_0_tiled_row_size(src1->ne[0]); } - if (octx->src1_spad.src != src1) { - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - octx->src1_spad.src = src1; + // Set up scratchpads using precomputed sizes from the host + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = kparams->vtcm_src2_size; + size_t vtcm_size = kparams->vtcm_size; + + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = src2_sz / octx->n_threads; + + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-ffn: current VTCM reservation %zu is too small, needed %zu\n", octx->ctx->vtcm_size, vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; } + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, 
src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + + mmctx->vtcm_src0_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src2_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + // Run quantization once + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + + // Run fused matmul const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + worker_callback_t matmul_job_func; + if (is_repacked) { + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_ffn_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_ffn_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_ffn_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = 
hvx_mm_ffn_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } + } else { + matmul_job_func = hvx_mm_ffn_2d; + } + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); - if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.h b/ggml/src/ggml-hexagon/htp/matmul-ops.h new file mode 100644 index 00000000000..a94d5430dab --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.h @@ -0,0 +1,508 @@ +#ifndef HTP_MATMUL_OPS_H +#define HTP_MATMUL_OPS_H + +#include +#include +#include "htp-ops.h" +#include "hex-fastdiv.h" +#include "hex-common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// --- HMX Tile Constraints --- +#define HTP_MM_HMX_TILE_N_COLS 32 +#define HTP_MM_HMX_TILE_N_ROWS 32 +#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes +#define HTP_MM_HMX_TILE_N_ELMS 1024 +#define HTP_MM_HMX_MIN_NROWS 4 + +// --- Weight Repacked Tile Sizes --- +#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576 +#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640 +#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088 +#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576 +#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544 + +// --- Weight Repacked Aligned Tile Sizes --- +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640 + +// --- Activation Tiled Block Sizes (including padding) --- +#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152 +#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280 + +#define HTP_MM_MAX_PREFETCH 16 + +// --- Solver Cost Model Penalty Weights (HMX-specific) --- +#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization +#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion + +// --- DMA Activation Transfer Configuration --- +#define 
HTP_MM_DMA_ACT_ROWS_PER_STEP 2 +#define HTP_MM_DMA_ACT_MULTIPLIER 4 + +enum htp_mm_kernel_type { + HTP_MM_KERNEL_UNSUPPORTED = 0, + + // HMX paths + HTP_MM_KERNEL_HMX_2D, + HTP_MM_KERNEL_HMX_F16_BATCHED, + + // HVX floating-point paths + HTP_MM_KERNEL_HVX_F16_F16_VTCM, + HTP_MM_KERNEL_HVX_F16_F16_DDR, + HTP_MM_KERNEL_HVX_F16_F32_DDR, + + HTP_MM_KERNEL_HVX_F32_F32_VTCM, + HTP_MM_KERNEL_HVX_F32_F32_DDR, + HTP_MM_KERNEL_HVX_F32_F16_DDR, + + // HVX quantized paths + HTP_MM_KERNEL_HVX_QUANT_ROW, // standard row-wise parallel quantization + HTP_MM_KERNEL_HVX_QUANT_BLOCK, // parallel block-wise quantization + HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT, // row-wise fallback flat quantization +}; + +// Op-specific struct for precomputed matmul params +struct htp_mm_kernel_params { + int32_t kernel_type; // enum htp_mm_kernel_type + int32_t pipeline; // 1 = pipelined execution, 0 = standard + int32_t m_chunk; // Row chunk size (M chunk) + int32_t n_chunk; // Col chunk size (N chunk) + int32_t n_threads; // Number of threads to spawn + int32_t n_act_threads; // Number of threads for activation preparation + int32_t n_hmx; // 1 = use HMX, 0 = use HVX + int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM + int32_t tile_size; // Weight tile size + int32_t aligned_tile_size; // Aligned weight tile size (padded to 128) + int32_t src1_row_size; // Row size for quantized activation + int32_t vtcm_size; // Total required scratchpad size in VTCM + int32_t vtcm_src0_size; // src0 scratchpad size in VTCM + int32_t vtcm_src1_size; // src1 scratchpad size in VTCM + int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only) + int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only) + int32_t vtcm_dst_size; // dst scratchpad size in VTCM + + // Precomputed division values + struct fastdiv_values div_ne12_ne1; + struct fastdiv_values div_ne1; + struct fastdiv_values div_r2; + struct fastdiv_values div_r3; + struct fastdiv_values div_ne11; +}; + +#if 
defined(__cplusplus) +static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_mm_kernel_params is too large for kernel_params blob"); +#else +_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_mm_kernel_params is too large for kernel_params blob"); +#endif + +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +// Search for optimal (mc, nc) chunk sizes within VTCM budget. +static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total, + size_t overhead, + size_t per_n_cost, + size_t per_m_cost, + size_t per_mn_cost, + size_t m, + size_t n, + size_t m_block_cost, + size_t n_block_cost, + size_t * m_chunk_out, + size_t * n_chunk_out, + size_t * total_out) { + if (m == 0 || n == 0) return -1; + if (vtcm_total <= overhead) return -1; + if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; + + const size_t usable = vtcm_total - overhead; + + size_t best_cost = SIZE_MAX; + size_t best_mn = 0; + size_t best_m = 0, best_n = 0; + + const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS); + for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) { + size_t n_fixed = 0, ncmn = 0, mc_denom = 0; + if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue; + if (n_fixed >= usable) goto next_nc; + + if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; + if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; + + { + size_t remain = usable - n_fixed; + size_t mc = remain / mc_denom; + mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS); + mc = hex_smin(mc, m); + + if (mc == 0) { + goto next_nc; + } + + size_t mblocks = ((size_t) m + mc - 1) / mc; + size_t nblocks = ((size_t) n + nc - 1) / nc; + size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; + size_t mn = mc * nc; + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_m = mc; + best_n = nc; + } + } + +next_nc: + if (nc ==
HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow + } + + if (best_m == 0 || best_n == 0) return -1; + + // Compute exact total (with overflow checks) + size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; + if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1; + if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1; + if (hex_mul_overflow(best_m, best_n, &mn)) return -1; + if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1; + if (hex_add_overflow(t0, t1, &total)) return -1; + if (hex_add_overflow(total, t2, &total)) return -1; + if (hex_add_overflow(total, overhead, &total)) return -1; + + *m_chunk_out = best_m; + *n_chunk_out = best_n; + *total_out = total; + return 0; +} + +// --- Tile Size Helpers --- +static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) { + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + case HTP_TYPE_Q4_1: + return HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + case HTP_TYPE_Q8_0: + return HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + case HTP_TYPE_MXFP4: + return HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + default: + return 0; + } +} + +static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) { + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0; + case HTP_TYPE_Q4_1: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1; + case HTP_TYPE_Q8_0: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0; + case HTP_TYPE_MXFP4: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4; + default: + return 0; + } +} + +// --- Activation/Row Size Helpers --- +static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) { + const uint32_t ne_padded = ((ne + 127) / 128) * 128; + const uint32_t nb_32 = ne_padded / 32; + return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0; +} + +static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) { + const uint32_t ne_padded = ((ne + 127) / 128) * 128; + const uint32_t nb_32 = ne_padded / 32; + return nb_32 * 
HTP_MM_ACT_TILE_SIZE_Q8_1; +} + +static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) { + const uint32_t quants_size = hex_align_up(ne, 128); + const uint32_t num_scales = (ne + 31) / 32; + const uint32_t scales_size = hex_align_up(num_scales * 2, 128); + return quants_size + scales_size; +} + +static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) { + const uint32_t quants_size = hex_align_up(ne, 128); + const uint32_t num_scales = (ne + 31) / 32; + const uint32_t scales_size = hex_align_up(num_scales * 4, 128); + return quants_size + scales_size; +} + +static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) { + uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED; + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + case HTP_TYPE_Q4_1: + case HTP_TYPE_Q8_0: + case HTP_TYPE_MXFP4: + return (size_t) nb * htp_mm_get_weight_tile_size(weight_type); + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); + default: + return 0; + } +} + +static inline size_t htp_mm_round_up(size_t n, size_t m) { + return ((n + m - 1) / m) * m; +} + +static inline bool htp_mm_hmx_pipeline(uint32_t m) { + return m > 32; +} + +static inline void htp_mm_hmx_get_2d_chunk_costs( + int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size, + size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out +) { + const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32); + const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k); + const size_t vec_dot_size = k * sizeof(uint16_t); + const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; + + *size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) + + (pipeline ? 2 * vec_dot_size : vec_dot_size); + *size_per_m_out = vec_dot_size; + *size_per_mn_out = (pipeline ? 
2 : 1) * sizeof(uint16_t); +} + +static inline void htp_mm_hmx_get_batched_chunk_costs( + uint32_t k, uint32_t group_size, + size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out +) { + const size_t vec_dot_size = k * sizeof(uint16_t); + *size_per_n_out = 3 * vec_dot_size; + *size_per_m_out = group_size * vec_dot_size; + *size_per_mn_out = sizeof(uint16_t); +} + +static inline size_t htp_mm_hmx_get_2d_vtcm_size( + int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size +) { + const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32); + const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k); + const size_t vec_dot_size = k * sizeof(uint16_t); + + const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE); + size_t weight_area_size = is_quant + ? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE) + : htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE); + if (pipeline) { + weight_area_size *= 2; + } + const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + + size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + size_t scratch1_size = pipeline ? scratch0_size : 0; + size_t scratch2_size = pipeline ? output_area_size : 0; + + return weight_area_size + act_area_size + act_f32_size + output_area_size + + scratch0_size + scratch1_size + scratch2_size + 256; +} + +static inline size_t htp_mm_hmx_get_batched_vtcm_size( + int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) { + (void)wtype; + (void)pipeline; + const size_t vec_dot_size = k * sizeof(uint16_t); + const size_t f32_scratch_size = use_dma_activation + ? 
htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + + const size_t act_head_stride = mc * k; + const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + + return weight_area_size + act_area_size + output_area_size + + 2 * scratch_area_size + 256 + f32_scratch_size; +} + +static inline size_t htp_mm_hvx_get_vtcm_sizes( + int kernel_type, + int wtype, + uint32_t ne10, // k + uint32_t src1_nrows, // m_total (or act_nrows) + uint32_t n_threads, + size_t dst_row_size, + size_t src0_row_size, + size_t src1_row_size, + uint32_t n_prefetch, + size_t * vtcm_src0_size_out, + size_t * vtcm_src1_size_out, + size_t * vtcm_dst_size_out +) { + size_t vtcm_src0_size = 0; + size_t vtcm_src1_size = 0; + size_t vtcm_dst_size = 0; + + const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || + wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || + wtype == HTP_TYPE_MXFP4); + + const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128); + const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1; + + switch (kernel_type) { + case HTP_MM_KERNEL_HVX_F16_F16_VTCM: { + size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128); + vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256); + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? 
htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_F16_F32_DDR: + case HTP_MM_KERNEL_HVX_F16_F16_DDR: + case HTP_MM_KERNEL_HVX_F32_F32_DDR: + case HTP_MM_KERNEL_HVX_F32_F16_DDR: { + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads; + vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_F32_F32_VTCM: { + size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128); + vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256); + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_QUANT_BLOCK: + case HTP_MM_KERNEL_HVX_QUANT_ROW: { + size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + + vtcm_dst_size = dst_nrows > 0 ? 
htp_mm_round_up(dst_row_size, 128) : 0; + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float)); + if (vtcm_src0_size < src1_row_size_padded) { + vtcm_src0_size = src1_row_size_padded; + } + + vtcm_src0_size = vtcm_src0_size * n_threads; + vtcm_dst_size = vtcm_dst_size * n_threads; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = ne10 / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + vtcm_src0_size = repacked_vtcm_size * n_threads; + } + break; + } + case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: { + size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + + vtcm_dst_size = dst_nrows > 0 ? 
htp_mm_round_up(dst_row_size, 128) : 0; + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256); + + size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256); + if (vtcm_src0_size < src1_row_size_padded) { + vtcm_src0_size = src1_row_size_padded; + } + + vtcm_src0_size = vtcm_src0_size * n_threads; + vtcm_dst_size = vtcm_dst_size * n_threads; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = ne10 / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + vtcm_src0_size = repacked_vtcm_size * n_threads; + } + break; + } + default: + break; + } + + *vtcm_src0_size_out = vtcm_src0_size; + *vtcm_src1_size_out = vtcm_src1_size; + *vtcm_dst_size_out = vtcm_dst_size; + + return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size; +} + +static inline size_t htp_mm_hvx_id_get_vtcm_sizes( + int wtype, + uint32_t ne10, // k + uint32_t src1_nrows, + uint32_t n_threads, + size_t src0_row_size, // nb01 + uint32_t n_prefetch, + size_t * vtcm_src0_size_out, + size_t * vtcm_src1_size_out +) { + const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || + wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || + wtype == HTP_TYPE_MXFP4); + + const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128); + const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) + : htp_mm_q8_0_tiled_row_size(ne10); + + size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256); + + // src0 spad also holds temporary transposed src1 columns during dynamic quantization. 
+ const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + if (src0_sz_per_thread < src1_row_size_padded) { + src0_sz_per_thread = src1_row_size_padded; + } + + if (is_repack) { + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + const uint32_t n_k_tiles = ne10 / 32; + const uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + } + + const size_t vtcm_src0_size = src0_sz_per_thread * n_threads; + + *vtcm_src0_size_out = vtcm_src0_size; + *vtcm_src1_size_out = src1_sz; + + return vtcm_src0_size + src1_sz; +} + +#ifdef __cplusplus +} +#endif + +#endif // HTP_MATMUL_OPS_H diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 39cefcdda38..874dde1b887 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -14,8 +14,6 @@ Drivers_Dir = 13 1 = %DiskId% [SourceDisksFiles] -libggml-htp-v68.so = 1 -libggml-htp-v69.so = 1 libggml-htp-v73.so = 1 libggml-htp-v75.so = 1 libggml-htp-v79.so = 1 @@ -28,8 +26,6 @@ ExcludeFromSelect = * CopyFiles=Drivers_Dir [Drivers_Dir] -libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE -libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE From 3d57322a6e152b5894410ca1296f0dcdd7ef0566 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 24 Jun 2026 19:21:25 -0700 Subject: [PATCH 17/30] opencl: support non-contig rows in norm (llama/24965) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 21 ++++++++------------- ggml/src/ggml-opencl/kernels/norm.cl | 7 +++++-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git 
a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 5ad8d76fa51..fb330e06251 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -10152,14 +10152,8 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; - - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); const int nth = MIN(64, ne00); @@ -10173,11 +10167,12 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float)*nth, NULL)); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl 
index 170f822787b..a5ccac24137 100644 --- a/ggml/src/ggml-opencl/kernels/norm.cl +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -24,6 +24,7 @@ kernel void kernel_norm( int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, @@ -43,7 +44,8 @@ kernel void kernel_norm( // parallel sum sum[get_local_id(0)] = 0.0f; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - sum[get_local_id(0)] += x[i00]; + // this kernel handles float, nb00/4 translates byte offset to element offset + sum[get_local_id(0)] += x[i00*nb00/4]; } // reduce barrier(CLK_LOCAL_MEM_FENCE); @@ -60,7 +62,8 @@ kernel void kernel_norm( global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; sum[get_local_id(0)] = 0.0f; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - y[i00] = x[i00] - mean; + // this kernel handles float, nb00/4 translates byte offset to element offset + y[i00] = x[i00*nb00/4] - mean; sum[get_local_id(0)] += y[i00] * y[i00]; } From bb6b2ae89f07d6b7ae38077fda0ac722c45dc386 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Thu, 25 Jun 2026 13:27:58 +0800 Subject: [PATCH 18/30] sycl : fix the failed UT cases of conv_3d (llama/24900) --- ggml/src/ggml-sycl/conv3d.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-sycl/conv3d.cpp b/ggml/src/ggml-sycl/conv3d.cpp index 2fa29f93057..3796562553c 100644 --- a/ggml/src/ggml-sycl/conv3d.cpp +++ b/ggml/src/ggml-sycl/conv3d.cpp @@ -103,8 +103,8 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { // allocate packed arrays: A_packed (k x m), B_packed (k x n) ggml_sycl_pool_alloc A_packed_alloc(ctx.pool()); ggml_sycl_pool_alloc B_packed_alloc(ctx.pool()); - A_packed_alloc.alloc((size_t) knl_n_total * patch_total * sizeof(float)); - B_packed_alloc.alloc((size_t) knl_n_total * oc * sizeof(float)); + A_packed_alloc.alloc((size_t) knl_n_total * patch_total); + B_packed_alloc.alloc((size_t) 
knl_n_total * oc); float * A_packed = A_packed_alloc.get(); float * B_packed = B_packed_alloc.get(); @@ -115,10 +115,16 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { // Combined kernel: im2col -> pack A, and pack B simultaneously const char * src1_base = (const char *) src1->data; + const char * src0_base = (const char *) src0->data; const int64_t src1_nb0 = src1->nb[0]; const int64_t src1_nb1 = src1->nb[1]; const int64_t src1_nb2 = src1->nb[2]; const int64_t src1_nb3 = src1->nb[3]; + const int64_t src1_w = src1->ne[0]; + const int64_t src1_h = src1->ne[1]; + const int64_t src1_d = src1->ne[2]; + + const bool src0_is_f32 = (src0->type == GGML_TYPE_F32); // Compute correct strides for src0 as (knl_n_total, oc) matrix const int64_t src0_packed_nb0 = kernel_type_size; @@ -165,7 +171,7 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const int64_t sz = dst_z * s2 + kz * d2 - p2; float val = 0.0f; - if (sx >= 0 && sx < src1->ne[0] && sy >= 0 && sy < src1->ne[1] && sz >= 0 && sz < src1->ne[2]) { + if (sx >= 0 && sx < src1_w && sy >= 0 && sy < src1_h && sz >= 0 && sz < src1_d) { const int64_t channel_idx = batch_idx * c + ic; const char * ptr = src1_base + sx * src1_nb0 + sy * src1_nb1 + sz * src1_nb2 + channel_idx * src1_nb3; val = *(const float *) ptr; @@ -184,9 +190,9 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const int64_t row = t % k; const int64_t col = t / k; - const char * src_ptr = (const char *) src0->data + row * src0_packed_nb0 + col * src0_packed_nb1; + const char * src_ptr = src0_base + row * src0_packed_nb0 + col * src0_packed_nb1; float v; - if (src0->type == GGML_TYPE_F32) { + if (src0_is_f32) { v = *(const float *) src_ptr; } else { v = sycl::vec(*(const sycl::half *) src_ptr).convert()[0]; From c1e9f248421e52f658690c2045ba9cd395e86317 Mon Sep 17 00:00:00 2001 From: David Spruill <62445444+Spruill-1@users.noreply.github.com> Date: Thu, 25 Jun 
2026 01:35:21 -0400 Subject: [PATCH 19/30] sycl : support --split-mode tensor (llama/24152) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Sycl tp stage1 (llama/1) * SYCL: tensor parallelism (--split-mode tensor) for dual-GPU Adds the comm_init/comm_free/comm_allreduce_tensor trio that the meta-backend queries via get_proc_address to enable backend-specific all-reduce, mirroring the pattern used by ggml-cuda.cu. For N=2 (the common dual-GPU case) implements a degenerate ring all-reduce with two size-branched paths: * Small (nelem < 32768): FP32 direct memcpy + per-device ADD kernel chained via depends_on(memcpy_event). 4 SYCL submissions/call. * Large (nelem >= 32768): BF16-compressed. Each device compresses FP32 -> BF16 in a local outbox, cross-device memcpys to the peer's inbox (HALF the PCIe bytes), then decompresses + adds into the local FP32 partial. 6 SYCL submissions/call but PCIe bytes halved -- wins for any tensor where PCIe dominates kernel time. Threshold and BF16 path pattern mirror the CUDA NCCL allreduce. Storage: ONE persistent uint8_t buffer per device, 4 * nelem bytes (matches both path layouts: FP32 nelem floats; BF16 outbox+inbox = 2 * nelem uint16_t each). Single alloc+free per device keeps the SYCL pool's strict-LIFO invariant trivial. Initial impl handles N=2 FP32 contiguous tensors. Other cases return false, causing the meta-backend to use its generic butterfly fallback. Per-call sync is intentionally omitted. SYCL in-order queue semantics ensure that the meta-backend's next compute on the same per-device queue waits for our final ADD, and the next allreduce's first op on the same persistent buffer waits via the same queue. Only comm_free does an explicit final wait. OneCCL is NOT used: OneCCL 2021.17 hardcodes single-device-per-process in communicator_impl.hpp:47 (condition devices.size() == 1), which is incompatible with llama.cpp's single-process multi-GPU model. 
Measured on dual Intel Arc Pro B70 (NEO 26.05.x, oneAPI 2025.3 + DPC++ nightly): Llama-3.3-70B Q4_K_M, -sm tensor -fa 1 -ctk f16 -ctv f16: pp512 = 377.08 t/s (vs 313.65 layer mode = +20.2%) tg128 = 17.40 t/s (vs 9.74 layer mode = +78.6%) Qwen3-Coder-Next-80B-A3B Q3_K_M (MoE): pp512 = 216.56 t/s (vs 156.58 meta-backend butterfly = +38.3%) tg128 = 17.60 t/s (vs 14.31 meta-backend butterfly = +23.0%) Qwen3-4B Q4_K_M: pp64 = 984.51 t/s, tg16 = 49.29 t/s Llama-3.3-70B in SYCL TP now comfortably beats production layer mode on both prefill and decode. Coder-Next-80B-A3B (MoE) also wins on both — the BF16 path is what unlocks the many-medium-allreduces prefill pattern. Build/CMake: no changes. No new dependencies. ~210 lines added across ggml-sycl.h and ggml-sycl.cpp. * Fix comments * documentation update to address PR feedback * Bring over my device-to-device memcpy changes * move the dev2dev_memcpy calls to the upstream 7-parameter variety * Fix a typo and remove a trailing whitespace --- ggml/include/ggml-sycl.h | 8 + ggml/src/ggml-sycl/ggml-sycl.cpp | 255 +++++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+) diff --git a/ggml/include/ggml-sycl.h b/ggml/include/ggml-sycl.h index 5ce349a880e..418a7ba978b 100644 --- a/ggml/include/ggml-sycl.h +++ b/ggml/include/ggml-sycl.h @@ -27,6 +27,14 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int de // split tensor buffer that splits matrices by rows across multiple devices GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split); +// Tensor parallelism (--split-mode tensor): comm_init/free/allreduce_tensor +// trio queried by the meta-backend via ggml_backend_reg_get_proc_address. +// See typedefs in ggml/include/ggml-backend.h. Mirrors the CUDA backend's +// pattern (ggml_backend_cuda_comm_*).
+GGML_BACKEND_API void * ggml_backend_sycl_comm_init(ggml_backend_t * backends, size_t n_backends); +GGML_BACKEND_API void ggml_backend_sycl_comm_free(void * comm_ctx); +GGML_BACKEND_API bool ggml_backend_sycl_comm_allreduce_tensor(void * comm_ctx, struct ggml_tensor ** tensors); + // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index d8b83d0e23c..41449db665e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -5859,6 +5859,250 @@ static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t re return ctx->devices[index]; } +// ========================================================================== +// Tensor parallelism (--split-mode tensor) for the SYCL backend. +// +// The meta-backend invokes these three entry points via get_proc_address: +// * ggml_backend_sycl_comm_init - one-time per-graph setup +// * ggml_backend_sycl_comm_allreduce_tensor - per-allreduce step +// * ggml_backend_sycl_comm_free - tear-down +// +// For N=2 (dual-GPU), this is a degenerate ring allreduce with dual paths +// chosen by tensor size: +// +// * Small (nelem < 32K): FP32 direct memcpy + per-device ADD +// kernel. The kernel depends_on() its corresponding memcpy event +// so it doesn't read partial data. Both devices run in parallel. +// +// * Large (nelem >= 32K): BF16-compressed. Each device compresses +// its FP32 partial to BF16 locally, cross-device memcpys +// to the peer (half the PCI bandwidth), where it is decompressed +// and added into the local FP32 partial. 6 SYCL submissions per +// allreduce (2 compress + 2 memcpy + 2 decompress-add) vs the +// 4 for the small path, but the bandwidth saving > 6 GB/s PCIe x 2 +// dominates for larger tensors. 
+// +// Storage: A persistent uint8_t buffer per device, sized to +// 4 * nelem bytes. Both paths reinterpret the same bytes (small path +// as nelem floats; large path as outbox + inbox = 2*nelem uint16_t +// each, using the full 4*nelem byte budget either way). Single +// alloc+free per device keeps the SYCL pool's strict-LIFO invariant +// trivial. +// +// For non-(N=2 FP32 contiguous) cases, comm_init or comm_allreduce_tensor +// returns null/false, causing the meta-backend to use its generic +// butterfly all-reduce fallback. +// ========================================================================== + +struct ggml_backend_sycl_comm_context { + std::vector backends; + // ONE persistent per-device byte buffer, 4*nelem bytes. Both the + // FP32 small-tensor path and the BF16 large-tensor path share it + // by reinterpreting. + std::unique_ptr> buf0; + std::unique_ptr> buf1; + int64_t buf_nelem = 0; +}; + +void * ggml_backend_sycl_comm_init(ggml_backend_t * backends, size_t n_backends) try { + for (size_t i = 0; i < n_backends; ++i) { + if (!ggml_backend_is_sycl(backends[i])) { + return nullptr; + } + } + + // Initial version: N=2 only. For N!=2, returning null makes the + // meta-backend skip this backend-specific allreduce entirely. + if (n_backends != 2) { + return nullptr; + } + + auto * ctx = new ggml_backend_sycl_comm_context; + ctx->backends.assign(backends, backends + n_backends); + auto * sctx0 = (ggml_backend_sycl_context *) backends[0]->context; + auto * sctx1 = (ggml_backend_sycl_context *) backends[1]->context; + ctx->buf0 = std::make_unique>(sctx0->pool()); + ctx->buf1 = std::make_unique>(sctx1->pool()); + return ctx; +} +catch (const sycl::exception &) { return nullptr; } +catch (...) 
{ return nullptr; } + +void ggml_backend_sycl_comm_free(void * comm_ctx_v) { + auto * comm_ctx = static_cast(comm_ctx_v); + if (comm_ctx == nullptr) { + return; + } + + // Sync both per-device queues so the pool_alloc destructors don't + // return memory still in use by the last kernel. + if (comm_ctx->backends.size() == 2) { + auto * sctx0 = (ggml_backend_sycl_context *) comm_ctx->backends[0]->context; + auto * sctx1 = (ggml_backend_sycl_context *) comm_ctx->backends[1]->context; + try { + sctx0->stream()->wait(); + sctx1->stream()->wait(); + } catch (...) { /* best effort during shutdown */ } + } + + delete comm_ctx; +} + +bool ggml_backend_sycl_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) try { + if (comm_ctx_v == nullptr) { + return false; + } + + auto * comm_ctx = static_cast(comm_ctx_v); + const size_t n_backends = comm_ctx->backends.size(); + + // Fast path: N=2, F32/F16, contiguous, matching shapes. + if (n_backends != 2) { + return false; + } + // Accept F32 or F16 inputs natively (types must match). F16 takes the + // direct 2-byte memcpy + add path below; other types return false so the + // meta-backend uses its generic all-reduce. + if (tensors[0]->type != tensors[1]->type) { + return false; + } + if (tensors[0]->type != GGML_TYPE_F32 && tensors[0]->type != GGML_TYPE_F16) { + return false; + } + if (!ggml_is_contiguous(tensors[0]) || !ggml_is_contiguous(tensors[1])) { + return false; + } + if (ggml_nelements(tensors[0]) != ggml_nelements(tensors[1])) { + return false; + } + + const int64_t nelem = ggml_nelements(tensors[0]); + const size_t nbytes = ggml_nbytes(tensors[0]); + if (nelem == 0) { + return true; + } + + auto * ctx0 = (ggml_backend_sycl_context *) comm_ctx->backends[0]->context; + auto * ctx1 = (ggml_backend_sycl_context *) comm_ctx->backends[1]->context; + queue_ptr q0 = ctx0->stream(); + queue_ptr q1 = ctx1->stream(); + + // Grow per-device byte buffers if needed (4 * nelem bytes each). 
+ if (comm_ctx->buf_nelem < nelem) { + comm_ctx->buf0->realloc(nelem * 4); + comm_ctx->buf1->realloc(nelem * 4); + comm_ctx->buf_nelem = nelem; + } + uint8_t * buf0 = comm_ctx->buf0->get(); + uint8_t * buf1 = comm_ctx->buf1->get(); + + // F16 native path: direct 2-byte cross-device copy + add, skipping the + // F32 round-trip the meta-backend fallback would force. Cross-device copies + // go through dev2dev_memcpy because the two devices are in separate SYCL + // contexts (a raw peer-USM q->memcpy would be a silent no-op). + if (tensors[0]->type == GGML_TYPE_F16) { + sycl::half * f16_out0 = (sycl::half *) tensors[0]->data; + sycl::half * f16_out1 = (sycl::half *) tensors[1]->data; + sycl::half * f16_tmp0 = (sycl::half *) buf0; + sycl::half * f16_tmp1 = (sycl::half *) buf1; + + q0->wait(); + q1->wait(); + dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, f16_tmp0, tensors[1]->data, nbytes); + dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, f16_tmp1, tensors[0]->data, nbytes); + + q0->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + f16_out0[i] = (sycl::half) ((float) f16_out0[i] + (float) f16_tmp0[i]); + }); + }); + q1->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + f16_out1[i] = (sycl::half) ((float) f16_out1[i] + (float) f16_tmp1[i]); + }); + }); + return true; + } + + float * out0 = (float *) tensors[0]->data; + float * out1 = (float *) tensors[1]->data; + + // BF16 threshold: above this, the PCIe savings from halving the + // cross-device bytes outweigh the 2 extra compress kernels. + // Below: stay on the FP32 fast path. Threshold mirrors the CUDA + // NCCL allreduce pattern for n_backends=2. + static constexpr int64_t BF16_THRESHOLD = 32768; + + if (nelem < BF16_THRESHOLD) { + // FP32 small path: 4 SYCL submissions per allreduce. 
+ float * tmp0 = (float *) buf0; + float * tmp1 = (float *) buf1; + + // COMM-D2D-FIX: the two devices are in SEPARATE SYCL contexts, so a raw + // q->memcpy of a peer USM pointer is a silent no-op. Route cross-device + // copies through dev2dev_memcpy (L0 direct copy / host staging). It is + // synchronous, so wait for the local partials to be produced first. + q0->wait(); + q1->wait(); + dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, tmp0, tensors[1]->data, nbytes); + dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, tmp1, tensors[0]->data, nbytes); + + q0->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + out0[i] += tmp0[i]; + }); + }); + q1->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + out1[i] += tmp1[i]; + }); + }); + return true; + } + + // BF16 large path: 6 SYCL submissions per allreduce, but the + // cross-device memcpy is HALF the bytes. Pure bit-shift + // conversion (no rounding) — matches ggml's truncating fp32->bf16. + uint16_t * outbox0 = (uint16_t *) buf0; + uint16_t * inbox0 = outbox0 + nelem; + uint16_t * outbox1 = (uint16_t *) buf1; + uint16_t * inbox1 = outbox1 + nelem; + + // Phase A: compress each device's local partial in parallel. + sycl::event c0 = q0->parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + outbox0[i] = (uint16_t) (sycl::bit_cast(out0[i]) >> 16); + }); + + sycl::event c1 = q1->parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + outbox1[i] = (uint16_t) (sycl::bit_cast(out1[i]) >> 16); + }); + + // Phase B: COMM-D2D-FIX-BF16 cross-device copy of compressed bytes via + // dev2dev_memcpy (separate SYCL contexts; sync copy after compress). 
+ const size_t bf16_bytes = nelem * sizeof(uint16_t); + c0.wait(); + c1.wait(); + dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, inbox0, outbox1, bf16_bytes); + dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, inbox1, outbox0, bf16_bytes); + + // Phase C: decompress + add into local FP32 partial. + q0->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + out0[i] += sycl::bit_cast(((uint32_t) inbox0[i]) << 16); + }); + }); + + q1->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + out1[i] += sycl::bit_cast(((uint32_t) inbox1[i]) << 16); + }); + }); + + return true; +} +catch (const sycl::exception &) { return false; } +catch (...) { return false; } + static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) { GGML_UNUSED(reg); @@ -5866,6 +6110,17 @@ static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, cons return (void *)ggml_backend_sycl_split_buffer_type; } + // Tensor parallelism (--split-mode tensor) entry points. 
+ if (strcmp(name, "ggml_backend_comm_init") == 0) { + return (void *)ggml_backend_sycl_comm_init; + } + if (strcmp(name, "ggml_backend_comm_free") == 0) { + return (void *)ggml_backend_sycl_comm_free; + } + if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) { + return (void *)ggml_backend_sycl_comm_allreduce_tensor; + } + // SYCL doesn't support registering host memory, left here for reference // "ggml_backend_register_host_buffer" // "ggml_backend_unregister_host_buffer" From 26fef3fac9c02023af1b3856228125799c2184a8 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Thu, 25 Jun 2026 10:06:44 +0200 Subject: [PATCH 20/30] ggml : address integer overflows in binary ops CUDA implementation (llama/24706) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml : address integer overflows in binary ops CUDA implementation * ggml : add size_t casts to avoid integer overflows * ggml : add more asserts checking integer overflows in binary ops CUDA implementation --------- Co-authored-by: Stanisław Szymczyk --- ggml/src/ggml-cuda/binbcast.cu | 136 ++++++++++++++++++++++----------- 1 file changed, 90 insertions(+), 46 deletions(-) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index c25f42b32bb..2e38077bf67 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -34,26 +34,26 @@ template = (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) { + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) { return; } @@ -69,25 +69,32 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint32_t i12 = fastmodulo(i2, ne12); const uint32_t i13 = fastmodulo(i3, ne13); - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01; + 
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11; + const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1; const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; + const uint32_t s0 = blockDim.x * gridDim.x; + ggml_cuda_pdl_sync(); - for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { + for (uint32_t i0 = i0s; i0 < ne0; i0 += s0) { const uint32_t i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; + float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10*s10]); + result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]); } dst_row[i0] = (dst_t) result; + + // protect i0 from overflow + if (ne0 - i0 <= s0) { + break; + } } } @@ -110,19 +117,19 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const uint3 ne12, const uint3 ne13, /*const int s0,*/ - const int s1, - const int s2, - const int s3, - const int s00, - const int s01, - const int s02, - const int s03, - const int s10, - const int s11, - const int s12, - const int s13, + const uint32_t s1, + const uint32_t s2, + const uint32_t s3, + const uint32_t s00, + const uint32_t s01, + const uint32_t s02, + const uint32_t s03, + const uint32_t s10, + const uint32_t s11, + const uint32_t s12, + const uint32_t s13, src1_ptrs... 
src1s) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; + const uint32_t i = blockDim.x*blockIdx.x + threadIdx.x; const uint32_t i3 = fastdiv(i, prod_012); const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01); @@ -133,25 +140,25 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, return; } - const int i11 = fastmodulo(i1, ne11); - const int i12 = fastmodulo(i2, ne12); - const int i13 = fastmodulo(i3, ne13); + const uint32_t i11 = fastmodulo(i1, ne11); + const uint32_t i12 = fastmodulo(i2, ne12); + const uint32_t i13 = fastmodulo(i3, ne13); - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01; + const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11; + const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1; const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - const int i10 = fastmodulo(i0, ne10); + const uint32_t i10 = fastmodulo(i0, ne10); ggml_cuda_pdl_sync(); - float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; + float result = src0_row ? 
(float) src0_row[size_t(i0)*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10*s10]); + result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]); } dst_row[i0] = (dst_t) result; @@ -248,6 +255,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * size_t s02 = nb02 / sizeof(src0_t); size_t s03 = nb03 / sizeof(src0_t); + GGML_ASSERT(ne0 <= std::numeric_limits::max()); + GGML_ASSERT(ne1 <= std::numeric_limits::max()); + GGML_ASSERT(ne2 <= std::numeric_limits::max()); + GGML_ASSERT(ne3 <= std::numeric_limits::max()); + + //GGML_ASSERT(s0 <= std::numeric_limits::max()); + GGML_ASSERT(s1 <= std::numeric_limits::max()); + GGML_ASSERT(s2 <= std::numeric_limits::max()); + GGML_ASSERT(s3 <= std::numeric_limits::max()); + + GGML_ASSERT(s00 <= std::numeric_limits::max()); + GGML_ASSERT(s01 <= std::numeric_limits::max()); + GGML_ASSERT(s02 <= std::numeric_limits::max()); + GGML_ASSERT(s03 <= std::numeric_limits::max()); + + GGML_ASSERT(s10 <= std::numeric_limits::max()); + GGML_ASSERT(s11 <= std::numeric_limits::max()); + GGML_ASSERT(s12 <= std::numeric_limits::max()); + GGML_ASSERT(s13 <= std::numeric_limits::max()); + + GGML_ASSERT(cne1[0] <= std::numeric_limits::max()); + GGML_ASSERT(cne1[1] <= std::numeric_limits::max()); + GGML_ASSERT(cne1[2] <= std::numeric_limits::max()); + GGML_ASSERT(cne1[3] <= std::numeric_limits::max()); + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); GGML_ASSERT(nb1 % sizeof(dst_t) == 0); GGML_ASSERT(nb2 % sizeof(dst_t) == 0); @@ -263,6 +295,8 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + GGML_ASSERT(ne2 * ne3 <= std::numeric_limits::max()); + const int block_size = 128; int64_t 
hne0 = std::max(ne0 / 2LL, 1LL); @@ -281,7 +315,13 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); if (block_nums.z > 65535 || block_nums.y > 65535) { - int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + int64_t block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + + GGML_ASSERT(block_num <= std::numeric_limits::max()); + GGML_ASSERT(block_num * block_size <= std::numeric_limits::max()); + GGML_ASSERT(ne0 * ne1 <= std::numeric_limits::max()); + GGML_ASSERT(ne0 * ne1 * ne2 <= std::numeric_limits::max()); + const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0); @@ -298,6 +338,10 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } } else { + GGML_ASSERT(int64_t(block_nums.x) * block_dims.x <= std::numeric_limits::max()); + GGML_ASSERT(int64_t(block_nums.y) * block_dims.y <= std::numeric_limits::max()); + GGML_ASSERT(int64_t(block_nums.z) * block_dims.z <= std::numeric_limits::max()); + const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); { const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); From f8b9ba4ce32c9131a06802aac05d0beb82796f85 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Thu, 25 Jun 2026 17:29:23 +0200 Subject: [PATCH 21/30] CUDA: Various fixes to `cpy.cu` (llama/25000) * Add failing test-case to test-backend-ops Extracted from https://github.com/ggml-org/llama.cpp/issues/24072 * Minimize repro with help of AI N = 8 * (65535 - 1) + 1 = 524273 * Port and adjust workaround from https://github.com/LostRuins/koboldcpp/commit/0ba798341e0c70517cb226cb63c966b086a3b5b3 Fall-back should share code, also relax y-z 
constraint to be inclusive * Add test-case + fallback also for y dim * Fix x-guards which is 2^{31}-1, so inclusive of INT_MAX * Fix overflow problems for transposed copy kernel --- ggml/src/ggml-cuda/cpy.cu | 64 +++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 121472ec228..1e625cc1cbe 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -53,10 +53,10 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const const int64_t nmat = ne / (ne00 * ne01); const int64_t n = ne00 * ne01; - const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x; - const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y; - const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset - const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + const int64_t x = (int64_t) blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x; + const int64_t y = (int64_t) blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + const int64_t tx = (int64_t) blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset + const int64_t ty = (int64_t) blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; int cur_tile_buf = 0; @@ -197,7 +197,7 @@ static void ggml_cpy_scalar_contiguous_cuda( cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); ggml_cuda_kernel_launch(cpy_scalar_contiguous, launch_params, cx, cdst, ne); } @@ -208,6 +208,14 @@ static void ggml_cpy_scalar_cuda( const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t
nb02, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + const auto launch_scalar_generic = [&]() { + const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks <= INT_MAX); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar>, launch_params, + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + }; + if (transposed) { GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed int64_t ne00n, ne01n, ne02n; @@ -224,20 +232,18 @@ static void ggml_cpy_scalar_cuda( int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM; - GGML_ASSERT(grid_x < UINT_MAX); - GGML_ASSERT(grid_y < USHRT_MAX); - GGML_ASSERT(grid_z < USHRT_MAX); - dim3 dimGrid(grid_x, grid_y, grid_z); - dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream); - ggml_cuda_kernel_launch(cpy_scalar_transpose, launch_params, - cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + GGML_ASSERT(grid_x <= INT_MAX); + if (grid_y > USHRT_MAX || grid_z > USHRT_MAX) { + launch_scalar_generic(); + } else { + dim3 dimGrid(grid_x, grid_y, grid_z); + dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_transpose, launch_params, + cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, 
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + } } else { - const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - GGML_ASSERT(num_blocks < UINT_MAX); - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); - ggml_cuda_kernel_launch(cpy_scalar>, launch_params, - cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + launch_scalar_generic(); } } @@ -248,7 +254,7 @@ static void ggml_cpy_f32_q8_0_cuda( GGML_ASSERT(ne % QK8_0 == 0); const int64_t num_blocks = ne / QK8_0; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -259,7 +265,7 @@ static void ggml_cpy_q8_0_f32_cuda( const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int64_t num_blocks = ne; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_q_f32<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -271,7 +277,7 @@ static void ggml_cpy_f32_q4_0_cuda( GGML_ASSERT(ne % QK4_0 == 0); const int64_t num_blocks = ne / QK4_0; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -284,7 +290,7 @@ static void ggml_cpy_q4_0_f32_cuda( const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int64_t num_blocks = ne; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); @@ -297,7 
+303,7 @@ static void ggml_cpy_f32_q4_1_cuda( GGML_ASSERT(ne % QK4_1 == 0); const int64_t num_blocks = ne / QK4_1; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -310,7 +316,7 @@ static void ggml_cpy_q4_1_f32_cuda( const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int64_t num_blocks = ne; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); @@ -323,7 +329,7 @@ static void ggml_cpy_f32_q5_0_cuda( GGML_ASSERT(ne % QK5_0 == 0); const int64_t num_blocks = ne / QK5_0; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -336,7 +342,7 @@ static void ggml_cpy_q5_0_f32_cuda( const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int64_t num_blocks = ne; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); @@ -349,7 +355,7 @@ static void ggml_cpy_f32_q5_1_cuda( GGML_ASSERT(ne % QK5_1 == 0); const int64_t num_blocks = ne / QK5_1; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -362,7 +368,7 @@ static void ggml_cpy_q5_1_f32_cuda( const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int64_t num_blocks = ne; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_q_f32, 
QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); @@ -375,7 +381,7 @@ static void ggml_cpy_f32_iq4_nl_cuda( GGML_ASSERT(ne % QK4_NL == 0); const int64_t num_blocks = ne / QK4_NL; - GGML_ASSERT(num_blocks < UINT_MAX); + GGML_ASSERT(num_blocks <= INT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } From 8005ef82af17d87e0b2a25c4d164897bac5f6508 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Thu, 25 Jun 2026 18:48:24 -0700 Subject: [PATCH 22/30] opencl: flush profiling batch at shutdown for incomplete batches (llama/25016) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index fb330e06251..00f20b09b8f 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -850,6 +850,7 @@ struct ggml_backend_opencl_context { ref_count--; if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING + flush_profiling_batch(); write_profiling_info(); profiling_results.clear(); #endif From 8be1e63e64fcf786a7ccf042540178c326776b22 Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Fri, 26 Jun 2026 13:51:25 +0800 Subject: [PATCH 23/30] CUDA: batch out_prod broadcast (dps2>1) path with cublasSgemmBatched (llama/24426) --- ggml/src/ggml-cuda/out-prod.cu | 67 ++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index 499903d09b1..46b9f3a67ee 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -2,6 +2,28 @@ #include +static __global__ void k_compute_out_prod_ptrs( + const float * src0_d, const float * src1_d, float * dst_d, + const float ** ptrs_a, const float ** ptrs_b, float ** ptrs_c, + const int64_t ne2, const int64_t ne3, + const int64_t dps2, const int64_t 
dps3, + const size_t s02, const size_t s03, + const size_t s12, const size_t s13, + const size_t s2, const size_t s3) { + const int64_t i2 = blockIdx.x*blockDim.x + threadIdx.x; + const int64_t i3 = blockIdx.y*blockDim.y + threadIdx.y; + + if (i2 >= ne2 || i3 >= ne3) { + return; + } + + const int64_t idx = i3*ne2 + i2; + + ptrs_a[idx] = src0_d + (i3/dps3)*s03 + (i2/dps2)*s02; + ptrs_b[idx] = src1_d + i3 *s13 + i2 *s12; + ptrs_c[idx] = dst_d + i3 *s3 + i2 *s2; +} + void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -67,18 +89,39 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { &beta, dst_d + i3 *s3, ldc, s2, batch_count)); } + } else if (ne2 > 1 || ne3 > 1) { + // dps2 > 1 (src0 broadcast along dim 2 with non-uniform stride) or multiple GEMMs + // along dim 3: compute per-GEMM pointers on the device and use a single batched GEMM. + GGML_ASSERT(ne3 > 0); + GGML_ASSERT(ne2 <= (int64_t) std::numeric_limits::max() / ne3); + const int batch_count = (int) (ne2 * ne3); + + ggml_cuda_pool_alloc ptrs_a(ctx.pool(), batch_count); + ggml_cuda_pool_alloc ptrs_b(ctx.pool(), batch_count); + ggml_cuda_pool_alloc< float *> ptrs_c(ctx.pool(), batch_count); + + const dim3 block_dims(16, 16); + const dim3 grid_dims((ne2 + block_dims.x - 1)/block_dims.x, (ne3 + block_dims.y - 1)/block_dims.y); + k_compute_out_prod_ptrs<<>>( + src0_d, src1_d, dst_d, + ptrs_a.get(), ptrs_b.get(), ptrs_c.get(), + ne2, ne3, dps2, dps3, s02, s03, s12, s13, s2, s3); + CUDA_CHECK(cudaGetLastError()); + + CUBLAS_CHECK( + cublasSgemmBatched(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, ptrs_a.get(), lda, + ptrs_b.get(), ldb, + &beta, ptrs_c.get(), ldc, + batch_count)); } else { - // Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2 - // with non-uniform stride; would need cublasSgemmBatched with pointer arrays). 
- for (int64_t i3 = 0; i3 < ne3; ++i3) { - for (int64_t i2 = 0; i2 < ne2; ++i2) { - CUBLAS_CHECK( - cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, - ne0, ne1, ne01, - &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, - src1_d + i3 *s13 + i2 *s12, ldb, - &beta, dst_d + i3 *s3 + i2 *s2, ldc)); - } - } + // ne2 == 1 && ne3 == 1: single GEMM + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d, lda, + src1_d, ldb, + &beta, dst_d, ldc)); } } From 9f0a6b6a1c4c4e8be818deba4028772ce0e000ac Mon Sep 17 00:00:00 2001 From: Jassieluo <130133492+Jassieluo@users.noreply.github.com> Date: Fri, 26 Jun 2026 15:02:42 +0800 Subject: [PATCH 24/30] sycl : clamp softmax input to avoid underflow (llama/24941) --- ggml/src/ggml-sycl/softmax.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 18bf379bbeb..67ea282b4b3 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -126,7 +126,7 @@ static void soft_max_f32(const float * x, break; } - const float val = sycl::native::exp(vals[col] - max_val); + const float val = sycl::native::exp(sycl::max(vals[col] - max_val, -80.0f)); tmp += val; vals[col] = val; } @@ -154,7 +154,7 @@ static void soft_max_f32(const float * x, tmp = warp_reduce_sum(tmp); } if (sinks) { - tmp += sycl::native::exp(sinks[i02] - max_val); + tmp += sycl::native::exp(sycl::max(sinks[i02] - max_val, -80.0f)); } const float inv_sum = 1.0f / tmp; From 96e90a853ee842a4fe3fbb52d6d84d955fba5c50 Mon Sep 17 00:00:00 2001 From: Tarek Dakhran Date: Fri, 26 Jun 2026 09:41:56 +0200 Subject: [PATCH 25/30] ggml-cpu: fix SVE leftover path in ggml_vec_dot_f32 (llama/24699) * ggml-cpu: fix SVE leftover path in ggml_vec_dot_f32 2D convolutions with kernel size 9 produced different results on SVE enabled ARM devices. After debugging it turned out that ggml_vec_dot_f32 was using data from inactive lanes. 
Use svmla_f32_m(pg, sum1, ax1, ay1) so inactive lanes retain sum1. * cont : clean-up --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/vec.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 67b6b05cac8..ff2b636df86 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -75,12 +75,12 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G ay1 = GGML_F32_VEC_LOAD(y + i); sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); } - // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmla on available elements only if (np2 < n) { svbool_t pg = svwhilelt_b32(np2, n); ax1 = svld1_f32(pg, x + np2); ay1 = svld1_f32(pg, y + np2); - sum1 = svmad_f32_m(pg, ax1, ay1, sum1); + sum1 = svmla_f32_m(pg, sum1, ax1, ay1); } // reduce sum1,sum2 to sum1 GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8); From c3281af61307ad8f7f784c65cd9f6d0fb5788ddc Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Fri, 26 Jun 2026 17:42:56 +0800 Subject: [PATCH 26/30] CUDA: add cublasSgemmBatched mapping for HIP/MUSA vendor headers (llama/25033) --- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 2 files changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index a6115cd80dc..d01f1533abb 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -48,6 +48,7 @@ #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream #define cublasSgemm hipblasSgemm +#define cublasSgemmBatched hipblasSgemmBatched #define cublasSgemmStridedBatched hipblasSgemmStridedBatched #define cublasStatus_t hipblasStatus_t #define cublasOperation_t 
hipblasOperation_t diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 99e8fa3703e..6d725c7ec19 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -32,6 +32,7 @@ #define cublasSetMathMode mublasSetMathMode #define cublasSetStream mublasSetStream #define cublasSgemm mublasSgemm +#define cublasSgemmBatched mublasSgemmBatched #define cublasSgemmStridedBatched mublasSgemmStridedBatched #define cublasStatus_t mublasStatus_t #define cublasOperation_t mublasOperation_t From 325c37a41af181001d06482ee97cf61966e4aac1 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 26 Jun 2026 04:53:32 -0500 Subject: [PATCH 27/30] vulkan: Workaround compiler bug in conv2d coopmat2 path (llama/24924) * vulkan: Workaround compiler bug in conv2d coopmat2 path * apply same workaround to CONV_3D * Apply suggestion from @jeffbolznv --- ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 1428ef68d81..99400098bf2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -158,7 +158,7 @@ const uint32_t Csh_stride = BS_NPQ; #ifdef COOPMAT const uint32_t Csh_len = BS_K * Csh_stride; #else -const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +const uint32_t Csh_len = csh_store != 0 ? 
BS_K * Csh_stride : 8; // 8 to workaround compiler bug #endif shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp index a9712eb3acf..f66f299f6da 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp @@ -144,7 +144,7 @@ const uint32_t Csh_stride = BS_NPQ; #ifdef COOPMAT const uint32_t Csh_len = BS_K * Csh_stride; #else -const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 8; // 8 to workaround compiler bug #endif shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ #endif From 8b2bb5cb1c98c7cdb8ccf0208e48806e809d61b6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 26 Jun 2026 14:37:43 +0300 Subject: [PATCH 28/30] ggml : bump version to 0.15.3 (ggml/1550) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index a0cd4e7158f..0ec62a3773d 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 15) -set(GGML_VERSION_PATCH 2) +set(GGML_VERSION_PATCH 3) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From b5f276c4a6a125b802b9f90d7f8569ae913d1035 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 26 Jun 2026 15:05:10 +0300 Subject: [PATCH 29/30] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 499be5a5856..27bab1a8ea6 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -707321c4cf6d21cb4bc831aa8b687dbf01a521ce +eced84c86f8b012c752c016f7fe789adea168e1e From 
fa343d2d5c39cfeb5661b983689ff1f33240de98 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 26 Jun 2026 15:06:05 +0300 Subject: [PATCH 30/30] talk-llama : sync llama.cpp --- examples/talk-llama/llama-context.cpp | 8 ++++ examples/talk-llama/llama-context.h | 1 + examples/talk-llama/llama-cparams.h | 2 + examples/talk-llama/llama-ext.h | 5 ++ examples/talk-llama/llama-graph.h | 11 ++++- examples/talk-llama/llama-model.cpp | 5 ++ examples/talk-llama/llama-model.h | 1 + examples/talk-llama/llama-quant.cpp | 6 +-- examples/talk-llama/llama-sampler.cpp | 2 - examples/talk-llama/llama.h | 17 +++---- examples/talk-llama/models/glm-dsa.cpp | 10 ++-- examples/talk-llama/models/lfm2.cpp | 19 ++++++-- examples/talk-llama/models/mamba-base.cpp | 1 - examples/talk-llama/models/mamba2.cpp | 13 +++--- examples/talk-llama/models/qwen35.cpp | 2 + examples/talk-llama/models/qwen35moe.cpp | 2 + examples/talk-llama/models/step35.cpp | 57 +++++++++++------------ 17 files changed, 102 insertions(+), 60 deletions(-) diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 529bc4a5e99..220240ea952 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) { sched_need_reserve = true; } +void llama_context::set_nextn_layer_offset(int32_t offset) { + cparams.nextn_layer_offset = offset; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu ctx->set_embeddings_layer_inp(lid, value); } +void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) { + ctx->set_nextn_layer_offset(offset); +} + llama_memory_t llama_get_memory(const struct llama_context * ctx) { if (!ctx) { return nullptr; diff --git a/examples/talk-llama/llama-context.h 
b/examples/talk-llama/llama-context.h index 853052be2ca..f8b7805871e 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -115,6 +115,7 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_nextn(bool value, bool masked); void set_embeddings_layer_inp(uint32_t lid, bool enable); + void set_nextn_layer_offset(int32_t offset); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 2b109f909c0..546ae1e2c12 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -18,6 +18,8 @@ struct llama_cparams { int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + int32_t nextn_layer_offset = 0; + float rope_freq_base; float rope_freq_scale; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 8b5679b690b..348bbae9577 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); +// Select which appended NextN block the DECODER_MTP graph runs (offset past +// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to +// chain multiple trained NextN heads. Default 0 (first head). 
+LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset); + // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 5e8a658350a..a6e8c3985ba 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -682,9 +682,16 @@ struct llm_graph_params { } } + // TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248 + if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) { + return false; + } + return - cparams.embeddings == other.cparams.embeddings && - cparams.causal_attn == other.cparams.causal_attn && + cparams.embeddings == other.cparams.embeddings && + cparams.embeddings_nextn == other.cparams.embeddings_nextn && + cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked && + cparams.causal_attn == other.cparams.causal_attn && arch == other.arch && gtype == other.gtype && cvec == other.cvec && diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index c5287553390..6cb0ec3791c 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -700,6 +700,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_160M: return "160M"; case LLM_TYPE_190M: return "190M"; case LLM_TYPE_220M: return "220M"; + case LLM_TYPE_230M: return "230M"; case LLM_TYPE_250M: return "250M"; case LLM_TYPE_256M: return "256M"; case LLM_TYPE_270M: return "270M"; @@ -2312,6 +2313,10 @@ int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer(); } +int32_t llama_model_n_layer_nextn(const llama_model * model) { + return model->hparams.n_layer_nextn; +} + int32_t llama_model_n_head(const llama_model * model) { return model->hparams.n_head(); } diff --git a/examples/talk-llama/llama-model.h 
b/examples/talk-llama/llama-model.h index f4718f6d584..77d8d3b6258 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -36,6 +36,7 @@ enum llm_type { LLM_TYPE_160M, LLM_TYPE_190M, LLM_TYPE_220M, + LLM_TYPE_230M, LLM_TYPE_250M, LLM_TYPE_256M, LLM_TYPE_270M, diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index cf92ce4bb8b..847e79f4655 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -847,7 +847,7 @@ static void init_quantize_state_counters(quantize_state_impl & qs, std::vectordata[i].logit = -INFINITY; } } - - llama_sampler_softmax_impl(cur_p, true); } static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 27e48067428..f723c9f60cf 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -558,14 +558,15 @@ extern "C" { LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); - LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_inp (const struct 
llama_model * model); + LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); diff --git a/examples/talk-llama/models/glm-dsa.cpp b/examples/talk-llama/models/glm-dsa.cpp index 11d91312def..32fe6def6f3 100644 --- a/examples/talk-llama/models/glm-dsa.cpp +++ b/examples/talk-llama/models/glm-dsa.cpp @@ -101,11 +101,11 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // DSA indexer - layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); - layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); - layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); - layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); - layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_proj = 
create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); if (i < (int) hparams.n_layer_dense_lead) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index 97da8a6abb8..70e837d6eb2 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -13,6 +13,7 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { hparams.n_layer_dense_lead = hparams.n_layer(); switch (hparams.n_ff()) { + case 2560: type = LLM_TYPE_230M; break; case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; case 8192: type = LLM_TYPE_1_2B; break; @@ -190,7 +191,15 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); - bx = ggml_concat(ctx0, conv, bx, 0); + // causal prepends the state, non-causal pads symmetrically for a centered window + if (hparams.causal_attn) { + bx = ggml_concat(ctx0, conv, bx, 0); + } else { + const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2; + auto * left = ggml_cont(ctx0, + ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0])); + bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0); + } GGML_ASSERT(bx->ne[0] > 
conv->ne[0]); // last d_conv columns is a new conv state @@ -266,10 +275,12 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur, model.output_s); - cb(cur, "result_output", -1); + if (!cparams.embeddings) { + cur = build_lora_mm(model.output, cur, model.output_s); + cb(cur, "result_output", -1); - res->t_logits = cur; + res->t_logits = cur; + } ggml_build_forward_expand(gf, cur); } diff --git a/examples/talk-llama/models/mamba-base.cpp b/examples/talk-llama/models/mamba-base.cpp index c37f29c487e..fd3fe3f0323 100644 --- a/examples/talk-llama/models/mamba-base.cpp +++ b/examples/talk-llama/models/mamba-base.cpp @@ -169,7 +169,6 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); GGML_ASSERT(d_inner % n_head == 0); - GGML_ASSERT(d_inner % d_state == 0); GGML_ASSERT(d_inner % n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); diff --git a/examples/talk-llama/models/mamba2.cpp b/examples/talk-llama/models/mamba2.cpp index c5951cf0f7f..d5c167cf056 100644 --- a/examples/talk-llama/models/mamba2.cpp +++ b/examples/talk-llama/models/mamba2.cpp @@ -39,10 +39,11 @@ void llama_model_mamba2::load_arch_tensors(llama_model_loader &) { const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + const int64_t dt_rank = hparams.ssm_dt_rank; + + const int64_t conv_dim = d_inner + 2 * n_group * d_state; + const int64_t d_in_proj = d_inner + conv_dim + dt_rank; - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -68,11 +69,11 @@ void 
llama_model_mamba2::load_arch_tensors(llama_model_loader &) { layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {dt_rank}, 0); // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, dt_rank}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, dt_rank}, 0); layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 6783d98ec20..d8ffe43ae76 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -156,6 +156,8 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index eb5e9a406a1..7b0876cbb04 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -179,6 +179,8 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. 
for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index e2218c58704..9b7b18a3678 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); }; - auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto load_block_mtp = [&](int i) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); @@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). - // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. - // - // Only the FIRST MTP block (i == n_main) is required for the - // single-block MTP runtime; trailing MTP blocks are always tolerated - // as missing so pruned GGUFs (block 0 only) load cleanly. Override - // mtp_flags to NOT_REQUIRED for those. - const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + // Multi-block MTP: every declared MTP block is required (the draft chain + // runs all n_layer_nextn heads), so each block uses the captured + // `mtp_flags` directly — already NOT_REQUIRED for a trunk-only GGUF, + // which keeps that path correct. 
+ + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags); layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags); // dense MLP (leading dense blocks) — present if the MTP block isn't MoE layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); // NextN-specific tensors that define the MTP block. 
- layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); @@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - // Only the first MTP block (i == n_main) is required at runtime — the - // single-block-MTP graph in build_arch_graph always uses that one. - // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with - // all MTP layers still works) but tolerated when absent via the pruning - // path. See scripts/prune_step35_extra_mtp.py for the pruner. + // All n_layer_nextn MTP blocks are required — the multi-block draft chain + // runs every head (head k at offset k). The GGUF declares the count via + // step35.nextn_predict_layers. 
for (int i = n_layer; i < n_layer_all; ++i) { - load_block_mtp(i, /*is_first_mtp=*/ i == n_layer); + load_block_mtp(i); } } @@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr : llm_graph_context(params) { GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0"); - // Single-block MTP only: always run the first trained MTP block (Qwen - // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to - // be a much deeper refactor than this PR justifies; the trailing MTP - // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just - // block 0) also work — see load_arch_tensors below and - // scripts/prune_step35_extra_mtp.py. - const int il = hparams.n_layer(); + // Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by + // cparams.nextn_layer_offset (0 = first trained head). The speculative driver + // bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps + // single-block behavior identical to before. + const int il = hparams.n_layer() + cparams.nextn_layer_offset; + GGML_ASSERT(cparams.nextn_layer_offset >= 0 && + cparams.nextn_layer_offset < (int) hparams.n_layer_nextn && + "nextn_layer_offset out of range [0, n_layer_nextn)"); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "mtp_post_ffn", il); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. cb(cur, "h_nextn", -1); res->t_h_nextn = cur;